张量组合与拼接。
张量的拼接操作在神经网络搭建过程中是非常常用的方法,比如:残差网络、注意力机制中都使用到了张量拼接。使用cat函数可以将张量按照指定维度拼接起来,并不会升维;而stack函数可以将张量在指定维度叠加起来,会升维!
import torch
print("===torch.cat===")
a = torch.zeros((2, 4))
b = torch.ones((2, 4))
out = torch.cat((a,b),dim=0)
print(out)
out = torch.cat((a,b),dim=1)
print(out)
#torch.stack
print("===torch.stack===")
a = torch.linspace(1, 6, 6).view(2, 3)
b = torch.linspace(7, 12, 6).view(2, 3)
print(a, b)
out = torch.stack((a, b), dim=2)
print(out)
print(out.shape)
print(out[:, :, 0])
print(out[:, :, 1])
运行结果:
tensor([[0., 0., 0., 0.],
[0., 0., 0., 0.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]])
tensor([[0., 0., 0., 0., 1., 1., 1., 1.],
[0., 0., 0., 0., 1., 1., 1., 1.]])
torch.stack
tensor([[1., 2., 3.],
[4., 5., 6.]]) tensor([[ 7., 8., 9.],
[10., 11., 12.]])
tensor([[[ 1., 7.],
[ 2., 8.],
[ 3., 9.]],
[[ 4., 10.],
[ 5., 11.],
[ 6., 12.]]])
torch.Size([2, 3, 2])
tensor([[1., 2., 3.],
[4., 5., 6.]])
tensor([[ 7., 8., 9.],
[10., 11., 12.]])