← 返回首页
张量组合与拼接
发表时间:2024-01-24 02:36:12
张量组合与拼接

张量组合与拼接。

1.张量组合与拼接

张量的拼接操作在神经网络搭建过程中是非常常用的方法,比如:残差网络、注意力机制中都使用到了张量拼接。使用cat函数可以将张量按照指定维度拼接起来,并不会升维;而stack函数可以将张量在指定维度叠加起来,会升维!

import torch

print("===torch.cat===")
a = torch.zeros((2, 4))
b = torch.ones((2, 4))

out = torch.cat((a,b),dim=0)
print(out)

out = torch.cat((a,b),dim=1)
print(out)

#torch.stack

print("===torch.stack===")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)

print(out[:, :, 0])
print(out[:, :, 1])

运行结果:

tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1., 1., 1.]])
torch.stack
tensor([[1., 2., 3.],
        [4., 5., 6.]]) tensor([[ 7.,  8.,  9.],
        [10., 11., 12.]])
tensor([[[ 1.,  7.],
         [ 2.,  8.],
         [ 3.,  9.]],

        [[ 4., 10.],
         [ 5., 11.],
         [ 6., 12.]]])
torch.Size([2, 3, 2])
tensor([[1., 2., 3.],
        [4., 5., 6.]])
tensor([[ 7.,  8.,  9.],
        [10., 11., 12.]])