Pytorch简单编程技巧。
pytorch默认情况下只使用一个GPU,在多GPU情况下需要使用DataParallel。也就是所谓的支持单机多卡和多机多卡。
import torch
import numpy as np
import cv2
data = cv2.imread("cat.jpg")
print(data)
cv2.imshow("cat1", data)
# cv2.waitKey(0)
# a = np.zeros([2, 2])
out = torch.from_numpy(data)
#tensor创建在GPU上
out = out.to(torch.device("cuda"))
print(out.is_cuda)
#把图片翻转过来
out = torch.flip(out, dims=[0])
#tensor创建在cpu上
out = out.to(torch.device("cpu"))
print(out.is_cuda)
data = out.numpy()
cv2.imshow("cat2", data)
cv2.waitKey(0)