← 返回首页
张量的索引与数据筛选
发表时间:2024-01-24 01:43:19
张量的索引与数据筛选

张量的索引与数据筛选。

1.张量的索引与数据筛选

import torch
# torch.where: build a tensor by choosing, element by element, between two
# source tensors according to a boolean condition.
print("====torch.where====")

a = torch.rand(4, 4)
b = torch.rand(4, 4)
print(a)
print(b)

# Take the element from a where it is greater than 0.5, otherwise from b.
condition = a > 0.5
out = torch.where(condition, a, b)
print(out)

# torch.index_select: pick whole slices along one dimension by index.
print("====torch.index_select====")

a = torch.rand(4, 4)
print(a)

# Rows 0, 3 and 2 of a, in that order (dim=0 selects rows).
row_ids = torch.tensor([0, 3, 2])
out = a.index_select(0, row_ids)

print(out, out.shape)
# torch.gather: collect individual elements along a dimension, with the
# source position of every output element given by an index tensor.
print("====torch.gather====")
a = torch.linspace(1, 16, 16).reshape(4, 4)
print(a)
gather_rows = torch.tensor([[0, 1, 1, 1],
                            [0, 1, 2, 2],
                            [0, 1, 3, 3]])
out = a.gather(0, gather_rows)
print(out)
print(out.shape)

# Indexing rule (written for a 3-D input):
# dim=0, out[i, j, k] = input[index[i, j, k], j, k]
# dim=1, out[i, j, k] = input[i, index[i, j, k], k]
# dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
# For the 2-D call above (dim=0): out[i, j] = a[gather_rows[i, j], j].

# torch.masked_select: return, as a 1-D tensor, the elements where the
# boolean mask is True.
# NOTE: the banner below still prints "torch.masked_index" so that it matches
# the recorded output; the function actually demonstrated is
# torch.masked_select.

print("====torch.masked_index====")
a = torch.linspace(1, 16, 16).reshape(4, 4)
mask = a > 8  # True where the element is greater than 8
print(a)
print(mask)
out = a.masked_select(mask)
print(out)

# torch.take: treat the input as if it were flattened to 1-D and pick the
# elements at the given flat positions.
print("====torch.take====")
a = torch.linspace(1, 16, 16).reshape(4, 4)
print(a)
flat_positions = torch.tensor([0, 15, 13, 10])
b = a.take(flat_positions)
print(b)

# torch.nonzero: list the coordinates of every non-zero element, one
# (row, column) pair per output row.
print("====torch.nonzero====")
a = torch.tensor([[0, 1, 2, 0],
                  [2, 3, 0, 1]])
print(a)
out = a.nonzero()
print(out)
# The coordinate list can serve as a sparse representation of the tensor.

运行结果(torch.where 和 torch.index_select 部分使用 torch.rand 生成随机数,每次运行的数值都会不同):

====torch.where====
tensor([[3.9137e-01, 9.5595e-01, 1.7861e-01, 3.7395e-01],
        [2.2352e-05, 7.4780e-01, 1.6343e-01, 1.6966e-01],
        [8.3604e-01, 4.8091e-01, 3.5203e-01, 5.9287e-01],
        [7.7489e-01, 9.7039e-01, 6.7575e-01, 3.3128e-01]])
tensor([[0.8857, 0.0945, 0.3468, 0.5015],
        [0.4236, 0.2081, 0.1400, 0.1066],
        [0.4563, 0.7172, 0.4498, 0.3887],
        [0.9632, 0.2181, 0.6813, 0.4860]])
tensor([[0.8857, 0.9559, 0.3468, 0.5015],
        [0.4236, 0.7478, 0.1400, 0.1066],
        [0.8360, 0.7172, 0.4498, 0.5929],
        [0.7749, 0.9704, 0.6757, 0.4860]])
====torch.index_select====
tensor([[0.2192, 0.6786, 0.1027, 0.8901],
        [0.2469, 0.8186, 0.0098, 0.5595],
        [0.7827, 0.7446, 0.5794, 0.5256],
        [0.2509, 0.3690, 0.4719, 0.4802]])
tensor([[0.2192, 0.6786, 0.1027, 0.8901],
        [0.2509, 0.3690, 0.4719, 0.4802],
        [0.7827, 0.7446, 0.5794, 0.5256]]) torch.Size([3, 4])
====torch.gather====
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[ 1.,  6.,  7.,  8.],
        [ 1.,  6., 11., 12.],
        [ 1.,  6., 15., 16.]])
torch.Size([3, 4])
====torch.masked_index====
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([[False, False, False, False],
        [False, False, False, False],
        [ True,  True,  True,  True],
        [ True,  True,  True,  True]])
tensor([ 9., 10., 11., 12., 13., 14., 15., 16.])
====torch.take====
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.],
        [13., 14., 15., 16.]])
tensor([ 1., 16., 14., 11.])
====torch.nonzero====
tensor([[0, 1, 2, 0],
        [2, 3, 0, 1]])
tensor([[0, 1],
        [0, 2],
        [1, 0],
        [1, 1],
        [1, 3]])