Pytorch与张量裁剪。
张量裁剪运算就是对Tensor中元素进行范围过滤,解空间被约束后,网络会更加稳定,因此可以达到正则化和防止过拟合的效果。同时还可以防止梯度离散或者梯度爆炸。
import torch
print("==== 张量裁剪 ====")
a = torch.rand(2, 2) * 10
print(a)
a = a.clamp(2, 5) # 裁剪范围为2-5,对于小于2的直接取2,对于大于5的直接取5,中间值保持不变
print(a)
运行结果:
==== 张量裁剪 ====
tensor([[2.6644, 0.9549],
[8.9727, 9.0019]])
tensor([[2.6644, 2.0000],
[5.0000, 5.0000]])