开发者

Pytorch中torch.cat()函数举例解析

目录
  • 一. torch.cat()函数解析
    • 1. 函数说明
    • 2. 代码举例
  • 总结

    一. torch.cat()函数解析

    1. 函数说明

    1.1 官网:torch.cat(),函数定义及参数说明如下图所示:

    android

    Pytorch中torch.cat()函数举例解析

    1.2 函数功能

    函数将两个张量(tensor)按指定维度拼接在一起,注意:除拼接维数dim数值可不同外其余维数数值需相同,方能对齐,如下面例子所示。torch.cat()函数不会新增维度,而torch.stack()函数会新增一个维度,相同的是两个都是对张量进行拼接

    2. 代码举例

    2.1 输入两个二维张量(dim=0):dim=0对行进行拼接

    a = torch.randn(2,3)
    b =  torch.randn(3,3)
    c = torch.cat((a,b),dim=0)
    a,b,c
    

    输出结果如下:

    (tensor([[-0.90, -0.37,  1.96],

             [-2.65, -0.60,  0.05]]),

     tensor([[ 1.30,  0.24,  0.27],

             [-1.99, -1.09,  1.67],

             [-1.62,  1.54, -0.14]]),

     tensor([[-0.90, -0.37,  1.96],

             [-2.65, -0.60,  0.05],

             [ 1.30,  0.24,  0.27],

             [-1.99, -1.09,  1.67],

             [-1.62,  1.54, -0.14]]))

    2.2 输入两个二维张量(dim=1): dim=1对列进行拼接

    a = torch.randn(2,3)
    b =  torch.randn(2,4)
    c = torch.cat((a,b),dim=1)
    a,b,c
    

    输出结果如下:

    (tensor([[-0.55, -0.84, -1.60],

             [ 0.39, -0.96,  1.02]]),

     tensor([[-0.83, -0.09,  0.05,  0.17],

             [ 0.28, -0.7www.devze.com4, -0.27, -0.85]]),

     tensor([[-0.55, -0.84, -1.60, -0.83, -0.09,  0.05,  0.17],

             [ 0.39, -0.96,  1.02,  0.28, -0.74, -0.27, -0.85]]))

    2.3 输入两个三维张量:dim=0 对通道进行拼接

    a = torch.randn(2,3,4)
    b =  torch.randn(1,3,4)
    c = torch.cat((a,b),dim=0)
    a,b,c
    

    输出结果如下:

    (tensor([[[ 0.51, -0.72, -0.02,  0.76],

              [ 0.72,  1.01,  0.39, -0.13],

              [ 0.37, -0.63, -2.69,  0.74]],

     

             [[ 0.72, -0.31, -0.27,  0.10],

              [ 1.66, -0.06,  1.91, -0.66],

              [ 0.34, -0.23, -0.18, -1.22]]]),

     tensor([[[ 0.94,  0.77, -0.41, -1.20],

              [-0.23, -1.03, -0.25,  1.67],

              [-1.00, -0.68, -0.35, -0.50]]]),

     tensor([[[ 0.51, -0.72, -0.02,  0.76],

              [ 0.72,  1.01,  0.39, -0.13],

              [ 0.37, -0.63, -2.69,  0.74]],

     

             [[ 0.72, -0.31, -0.27,  0.10],

              [ 1.66, -0.06,  1.91, -0.66],

              [ 0.34, -0.23, -0.18, -1.22]],

     

             [[ 0.94,  0.77, -0.41, -1.20],

              [-0.23, -1.03, -0.25,  1.67],

              [-1.00, -0.68, -0.35, -0.50]]]))

    2.4 输入两个三维张量:dim=1对行进行拼接

    a = torch.randn(2,3,4)
    b =  torch.randn(2,4,4)
    c = torch.cat((a,b),dim=1)
    a,b,c
    

    输出结果如下:

    (tensor([[[-0.86,  0.00, -1.26,  1.20],

              [-0.46, -1.08, -0.82,  2.03],

              [-0.89,  0.43,  1.92,  0.49]],

     

        &编程客栈nbsp;    [[ 0.24, -0.02,  0.32,  0.97],

              [ 0.33, -1.34,  0.76, -1.55],

              [ 0.38,  1.45,  0.27, -0.64]]]),

     tensor([[[ 0.82,  0.85, -0.30, -0.58],

              [-0.09,  0.40,  0.02,  0.75],

              [-0.70,  0.67, -0.88, -0.50],

              [-0.62, -1.65, -1.10, -1.39]],

     

             [[-0.85, -1.61, -0.35, -0.56],

              [ 0.00,  1.40,  0.41,  0.39],

              [-0.01,  0.04,  0.80,  0.41],

              [-1.21, -0.64,  1.14,  1.64]]]),

     tensor([[[-0.86,  0.00, -1.26,  1.20],

              [-0.46, -1.08, -0.82,  2.03],

              [-0.89,  0.43,  1.92,  0.49],

           javascript;   [ 0.82,  0.85, -0.30, -0.58],

              [-0.09,  0.40,  0.02,  0.75],

              [-0.70,  0.67, -0.88, -0.50],

              [-0.62, -1.65, -1.10, -1.39]],

     

             [[ 0.24, -0.02,  0.32,  0.97],

              [ 0.33, -1.34,  0.76, -1.55],

              [ 0.38,  1.45,  0.27, -0.64],

              [-0.85, -1.61, -0.35, -0.56],

              [ 0.00,  1.40,  0.41,  0.39],

              [-0.01,  0.04,  0.80,  0.41],

              [-1.21, -0.64,  1.14,  1.64]]]))

    2.5 输入两个三维张量:dim=2对列进行拼接

    a = torch.randn(2,3,4)
    b =  torch.randn(2,3,5)
    c = torch.cat((a,b),dim=2)
    a,b,c
    

    输出结果如下:

    (tensor([[[ 0.13, -0.02,  0.13, -0.25],

              [ 1.42, -0.22, -0.87,  0.27],

              [-0.07,  1.04, -0.06,  0.91]],

     

     开发者_Go学习        [[ 0.88, -1.46,  0.04,  0.35],

              [ 1.36,  0.64,  0.75,  0.39],

              [ 0.36,  1.13,  0.83,  0.56]]]),

     tensor([[[-0.47, -2.30, -0.49, -1.02,  1.74],

              [ 0.71,  0.89,  0.80, -0.05, -1.35],

              [-0.40,  0.26, -0.78, -1.50, -0.92]],

     

             [[-0.77, -0.01,  1.23,  0.70, -0.66],

              [ 0.28, -0.18, -0.91,  2.23,  1.14],

              [-1.93, -0.17,  0.15,  0.40,  0.32]]]),

     tensor([[[ 0.13, -0.02,  0.13, -0.25, -0.47, -2.30, -0.49, -1.02,  1编程客栈.74],

              [ 1.42, -0.22, -0.87,  0.27,  0.71,  0.89,  0.80, -0.05, -1.35],

              [-0.07,  1.04, -0.06,  0.91, -0.40,  0.26, -0.78, -1.50, -0.92]],

     

             [[ 0.88, -1.46,  0.04,  0.35, -0.77, -0.01,  1.23,  0.70, -0.66],

              [ 1.36,  0.64,  0.75,  0.39,  0.28, -0.18, -0.91,  2.23,  1.14],

              [ 0.36,  1.13,  0.83,  0.56, -1.93, -0.17,  0.15,  0.40,  0.32]]]))

    总结

    到此这篇关于Pytorch中torch.cat()函数举例解析的文章就介绍到这了,更多相关Pytorch中torch.cat()函数内容请搜索我们以前的文章或继续浏览下面的相关文章希望大家以后多多支持我们!

    0

    上一篇:

    下一篇:

    精彩评论

    暂无评论...
    验证码 换一张
    取 消

    最新开发

    开发排行榜