← 返回首页
张量切片
发表时间:2024-01-24 02:36:54
张量切片

张量切片

1.张量切片

import torch

def _show_pieces(pieces):
    """Print each sub-tensor in *pieces* followed by its shape."""
    for piece in pieces:
        print(piece, piece.shape)


# --- torch.chunk: split into a requested NUMBER of pieces -----------------
print("===torch.chunk===")
a = torch.rand((3, 4))
print(a)
# Along dim=1 the 4 columns divide evenly, giving two (3, 2) pieces.
# Along dim=0 the 3 rows do not: the chunk size is ceil(3 / 2) = 2, so the
# pieces are (2, 4) and (1, 4).
for axis in (1, 0):
    out = torch.chunk(a, 2, dim=axis)
    _show_pieces(out)

# --- torch.split: split by piece SIZE -------------------------------------
print("===torch.split===")
a = torch.rand((10, 4))
print(a)
# An int is the size of every piece; what is left over becomes a shorter
# final piece: 10 rows -> 3 + 3 + 3 + 1, i.e. 4 pieces.
out = torch.split(a, 3, dim=0)
print(len(out))
_show_pieces(out)
print("---------------------")
print(a)
# A list gives the size of each piece (1, 3 and 6 rows), NOT split indices;
# the sizes must add up to a.size(0).
out = torch.split(a, [1, 3, 6], dim=0)
_show_pieces(out)

运行结果(注意:torch.rand 未设置随机种子,每次运行的具体数值都会不同,可复现的只是各切片的形状与切分方式):

===torch.chunk===
tensor([[0.7712, 0.0995, 0.6532, 0.0351],
        [0.0719, 0.8926, 0.5960, 0.2994],
        [0.2464, 0.3772, 0.0520, 0.2138]])
tensor([[0.7712, 0.0995],
        [0.0719, 0.8926],
        [0.2464, 0.3772]]) torch.Size([3, 2])
tensor([[0.6532, 0.0351],
        [0.5960, 0.2994],
        [0.0520, 0.2138]]) torch.Size([3, 2])
tensor([[0.7712, 0.0995, 0.6532, 0.0351],
        [0.0719, 0.8926, 0.5960, 0.2994]]) torch.Size([2, 4])
tensor([[0.2464, 0.3772, 0.0520, 0.2138]]) torch.Size([1, 4])
===torch.split===
tensor([[0.0902, 0.0042, 0.5589, 0.1715],
        [0.4750, 0.6471, 0.3872, 0.3494],
        [0.1124, 0.3425, 0.4445, 0.4682],
        [0.1485, 0.3883, 0.8868, 0.6445],
        [0.9101, 0.0865, 0.7651, 0.3049],
        [0.2239, 0.3355, 0.8833, 0.5462],
        [0.4075, 0.2409, 0.0298, 0.9346],
        [0.4887, 0.9803, 0.2442, 0.4688],
        [0.8249, 0.7240, 0.6846, 0.9174],
        [0.4382, 0.5654, 0.4757, 0.0778]])
4
tensor([[0.0902, 0.0042, 0.5589, 0.1715],
        [0.4750, 0.6471, 0.3872, 0.3494],
        [0.1124, 0.3425, 0.4445, 0.4682]]) torch.Size([3, 4])
tensor([[0.1485, 0.3883, 0.8868, 0.6445],
        [0.9101, 0.0865, 0.7651, 0.3049],
        [0.2239, 0.3355, 0.8833, 0.5462]]) torch.Size([3, 4])
tensor([[0.4075, 0.2409, 0.0298, 0.9346],
        [0.4887, 0.9803, 0.2442, 0.4688],
        [0.8249, 0.7240, 0.6846, 0.9174]]) torch.Size([3, 4])
tensor([[0.4382, 0.5654, 0.4757, 0.0778]]) torch.Size([1, 4])
---------------------
tensor([[0.0902, 0.0042, 0.5589, 0.1715],
        [0.4750, 0.6471, 0.3872, 0.3494],
        [0.1124, 0.3425, 0.4445, 0.4682],
        [0.1485, 0.3883, 0.8868, 0.6445],
        [0.9101, 0.0865, 0.7651, 0.3049],
        [0.2239, 0.3355, 0.8833, 0.5462],
        [0.4075, 0.2409, 0.0298, 0.9346],
        [0.4887, 0.9803, 0.2442, 0.4688],
        [0.8249, 0.7240, 0.6846, 0.9174],
        [0.4382, 0.5654, 0.4757, 0.0778]])
tensor([[0.0902, 0.0042, 0.5589, 0.1715]]) torch.Size([1, 4])
tensor([[0.4750, 0.6471, 0.3872, 0.3494],
        [0.1124, 0.3425, 0.4445, 0.4682],
        [0.1485, 0.3883, 0.8868, 0.6445]]) torch.Size([3, 4])
tensor([[0.9101, 0.0865, 0.7651, 0.3049],
        [0.2239, 0.3355, 0.8833, 0.5462],
        [0.4075, 0.2409, 0.0298, 0.9346],
        [0.4887, 0.9803, 0.2442, 0.4688],
        [0.8249, 0.7240, 0.6846, 0.9174],
        [0.4382, 0.5654, 0.4757, 0.0778]]) torch.Size([6, 4])