张量的索引与数据筛选。

import torch
# ---- torch.where: element-wise choice between two tensors ----
print("====torch.where====")
a = torch.rand(4, 4)  # drawn before b, so RNG consumption order is unchanged
b = torch.rand(4, 4)
print(a)
print(b)
# Take the element of `a` wherever a > 0.5, otherwise the element of `b`.
# Tensor.where(cond, other) is the method form of torch.where(cond, self, other).
out = a.where(a > 0.5, b)
print(out)
# ---- torch.index_select: pick whole slices along one dimension ----
print("====torch.index_select====")
a = torch.rand(4, 4)
print(a)
# Rows 0, 3 and 2 of `a`, in that order -> result shape (3, 4).
out = a.index_select(0, torch.tensor([0, 3, 2]))
print(out, out.shape)
# ---- torch.gather: per-element lookup along one dimension ----
print("====torch.gather====")
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
# With dim=0 on this 2-D tensor: out[i][j] = a[index[i][j]][j],
# so the result takes the shape of the index tensor, (3, 4).
out = a.gather(0, torch.tensor([[0, 1, 1, 1],
                                [0, 1, 2, 2],
                                [0, 1, 3, 3]]))
print(out)
print(out.shape)
# General rule for a 3-D input:
# dim=0, out[i, j, k] = input[index[i, j, k], j, k]
# dim=1, out[i, j, k] = input[i, index[i, j, k], k]
# dim=2, out[i, j, k] = input[i, j, index[i, j, k]]
# ---- torch.masked_select: keep the elements where a boolean mask is True ----
# NOTE: the header printed below says "masked_index", but the function shown is
# torch.masked_select (PyTorch has no torch.masked_index); the label is left
# unchanged so the output still matches the recorded run.
print("====torch.masked_index====")
a = torch.linspace(1, 16, 16).view(4, 4)
mask = a > 8  # equivalent to torch.gt(a, 8): boolean tensor, True where a > 8
print(a)
print(mask)
# The result is 1-D: the selected elements in row-major order.
out = a.masked_select(mask)
print(out)
# ---- torch.take: index into the tensor as if it were flattened ----
print("====torch.take====")
a = torch.linspace(1, 16, 16).view(4, 4)
print(a)
# Flat (row-major) positions 0, 15, 13 and 10 -> values 1, 16, 14, 11.
b = a.take(torch.tensor([0, 15, 13, 10]))
print(b)
# ---- torch.nonzero: coordinates of the nonzero elements ----
print("====torch.nonzero====")
a = torch.tensor([[0, 1, 2, 0], [2, 3, 0, 1]])
print(a)
# One row per nonzero element of `a`, holding its (row, col) index.
# This is a sparse, coordinate-list view of `a`.
out = a.nonzero()
print(out)
运行结果:
====torch.where====
tensor([[3.9137e-01, 9.5595e-01, 1.7861e-01, 3.7395e-01],
[2.2352e-05, 7.4780e-01, 1.6343e-01, 1.6966e-01],
[8.3604e-01, 4.8091e-01, 3.5203e-01, 5.9287e-01],
[7.7489e-01, 9.7039e-01, 6.7575e-01, 3.3128e-01]])
tensor([[0.8857, 0.0945, 0.3468, 0.5015],
[0.4236, 0.2081, 0.1400, 0.1066],
[0.4563, 0.7172, 0.4498, 0.3887],
[0.9632, 0.2181, 0.6813, 0.4860]])
tensor([[0.8857, 0.9559, 0.3468, 0.5015],
[0.4236, 0.7478, 0.1400, 0.1066],
[0.8360, 0.7172, 0.4498, 0.5929],
[0.7749, 0.9704, 0.6757, 0.4860]])
====torch.index_select====
tensor([[0.2192, 0.6786, 0.1027, 0.8901],
[0.2469, 0.8186, 0.0098, 0.5595],
[0.7827, 0.7446, 0.5794, 0.5256],
[0.2509, 0.3690, 0.4719, 0.4802]])
tensor([[0.2192, 0.6786, 0.1027, 0.8901],
[0.2509, 0.3690, 0.4719, 0.4802],
[0.7827, 0.7446, 0.5794, 0.5256]]) torch.Size([3, 4])
====torch.gather====
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([[ 1., 6., 7., 8.],
[ 1., 6., 11., 12.],
[ 1., 6., 15., 16.]])
torch.Size([3, 4])
====torch.masked_index====
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([[False, False, False, False],
[False, False, False, False],
[ True, True, True, True],
[ True, True, True, True]])
tensor([ 9., 10., 11., 12., 13., 14., 15., 16.])
====torch.take====
tensor([[ 1., 2., 3., 4.],
[ 5., 6., 7., 8.],
[ 9., 10., 11., 12.],
[13., 14., 15., 16.]])
tensor([ 1., 16., 14., 11.])
====torch.nonzero====
tensor([[0, 1, 2, 0],
[2, 3, 0, 1]])
tensor([[0, 1],
[0, 2],
[1, 0],
[1, 1],
[1, 3]])