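I've recently been developing a model with PyTorch, and a large share of the bugs I ran into came from converting between data formats. I'm recording the conversion methods here for future reference.
1.1 list to numpy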
ndarray = np.array(list)
```python
# -*- encoding:utf-8 -*-
import numpy as np

a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
ndarray = np.array(a)  # list -> ndarray
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
```

Output:

```
 type :<class 'list'> value: [1, 2, 3]
 type :<class 'numpy.ndarray'> value: [1 2 3]
```
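np.array also accepts a `dtype` argument and nested lists. The sketch below assumes you want a float32 array (the dtype PyTorch uses for `torch.Tensor` by default) and a 2-D array built from a list of lists; the variable names are only illustrative.

```python
import numpy as np

a = [1, 2, 3]
# Without dtype, numpy infers an integer type from the Python ints
inferred = np.array(a)
# Passing dtype forces the element type, e.g. float32 to match PyTorch's default
as_float32 = np.array(a, dtype=np.float32)
# A list of equal-length lists becomes a 2-D array
matrix = np.array([[1, 2, 3], [4, 5, 6]])

print(inferred.dtype.kind, as_float32.dtype, matrix.shape)
```

Choosing the dtype at this step avoids a second cast later when the array is handed to PyTorch.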
1.2 numpy to list
list = ndarray.tolist()
```python
# -*- encoding:utf-8 -*-
import numpy as np

ndarray = np.array([1, 2, 3])  # list -> ndarray
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
lst = ndarray.tolist()  # ndarray -> list (named lst to avoid shadowing the built-in list)
print(' type :{0} value: {1}'.format(type(lst), lst))
```

Output:

```
 type :<class 'numpy.ndarray'> value: [1 2 3]
 type :<class 'list'> value: [1, 2, 3]
```
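Why `tolist()` rather than `list(ndarray)`: `tolist()` converts every element to a native Python type (recursively for nested arrays), while `list()` only iterates the outer axis and leaves numpy scalars inside. That difference matters, for example, with `json.dumps`, which rejects numpy integer scalars. A small comparison with illustrative names:

```python
import numpy as np

arr = np.array([1, 2, 3])
# tolist() converts every element to a native Python type (int here)
native = arr.tolist()
# list() only iterates the array, so each element stays a numpy scalar
scalars = list(arr)
# Nested arrays become nested Python lists
nested = np.array([[1, 2], [3, 4]]).tolist()

print(type(native[0]), type(scalars[0]), nested)
```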
2.1 list to torch.Tensor
tensor=torch.Tensor(list)
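Note: even when the list elements are ints, the resulting tensor is float, because torch.Tensor is an alias for the default tensor type (torch.FloatTensor). To get an integer tensor, specify the type explicitly.
Commonly used tensor types are the 32-bit float torch.FloatTensor, the 64-bit float torch.DoubleTensor, the 16-bit integer torch.ShortTensor, the 32-bit integer torch.IntTensor, and the 64-bit integer torch.LongTensor.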
```python
# -*- encoding:utf-8 -*-
import torch

a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
tensor = torch.Tensor(a)  # defaults to float32
print(' type :{0} value: {1}'.format(type(tensor), tensor))
tensor = torch.IntTensor(a)  # 32-bit integer
print(' type :{0} value: {1}'.format(type(tensor), tensor))
```

Output:

```
 type :<class 'list'> value: [1, 2, 3]
 type :<class 'torch.Tensor'> value: tensor([1., 2., 3.])
 type :<class 'torch.Tensor'> value: tensor([1, 2, 3], dtype=torch.int32)
```
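The typed constructors work, but the PyTorch documentation points to the lowercase factory `torch.tensor()` for creating a tensor from existing data. It infers the dtype from the data (Python ints become int64) and takes a `dtype` argument. A minimal comparison, assuming the default dtype has not been changed with `torch.set_default_dtype`:

```python
import torch

a = [1, 2, 3]
# torch.Tensor(...) always uses the default dtype (float32 unless changed)
t_default = torch.Tensor(a)
# torch.tensor(...) infers the dtype from the data: Python ints -> int64
t_inferred = torch.tensor(a)
# An explicit dtype replaces the legacy typed constructors such as torch.IntTensor
t_int32 = torch.tensor(a, dtype=torch.int32)
t_long = torch.tensor(a, dtype=torch.long)

print(t_default.dtype, t_inferred.dtype, t_int32.dtype, t_long.dtype)
```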
2.2 torch.Tensor to list
Convert to numpy first, then to list:
list = tensor.numpy().tolist()
```python
# -*- encoding:utf-8 -*-
import torch

a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
tensor = torch.Tensor(a)
print(' type :{0} value: {1}'.format(type(tensor), tensor))
lst = tensor.numpy().tolist()  # tensor -> ndarray -> list
print(' type :{0} value: {1}'.format(type(lst), lst))
```

Output:

```
 type :<class 'list'> value: [1, 2, 3]
 type :<class 'torch.Tensor'> value: tensor([1., 2., 3.])
 type :<class 'list'> value: [1.0, 2.0, 3.0]
```
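Tensor also has its own `tolist()` method, so the detour through numpy is optional. According to the PyTorch docs, `Tensor.tolist()` moves the tensor to the CPU first if necessary, so it avoids the GPU restriction that `.numpy()` has. A short comparison with illustrative variable names:

```python
import torch

t_float = torch.Tensor([1, 2, 3])
t_int = torch.tensor([1, 2, 3])

# Going through numpy and calling Tensor.tolist() directly give the same list
via_numpy = t_float.numpy().tolist()
direct = t_float.tolist()
# Integer tensors come back as Python ints
ints = t_int.tolist()
# A 0-dim tensor gives a plain Python number, the same as .item()
scalar = torch.tensor(3.5).tolist()

print(via_numpy, direct, ints, scalar)
```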
3.1 torch.Tensor to numpy
ndarray = tensor.numpy()
Note: a tensor on the GPU cannot be converted to numpy directly; copy it to the CPU first:
ndarray = tensor.cpu().numpy()
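Besides the GPU case, `.numpy()` also refuses tensors that require grad, so a common pattern is `detach().cpu().numpy()`. The sketch wraps it in a hypothetical helper `to_numpy` (not a PyTorch API) and also shows that, for a CPU tensor, the resulting array shares memory with the tensor:

```python
import torch

def to_numpy(t: torch.Tensor):
    """Hypothetical helper: detach from autograd, move to CPU, then convert."""
    # detach() is needed when t.requires_grad is True, otherwise .numpy() raises
    # cpu() copies a CUDA tensor to host memory and returns the tensor itself if already on CPU
    return t.detach().cpu().numpy()

# On a CPU tensor, .numpy() shares memory with the tensor
t = torch.zeros(3)
arr = t.numpy()
arr[0] = 5.0   # writing to the array ...
print(t)       # ... is visible through the tensor

# A tensor that requires grad must be detached first
w = torch.ones(2, requires_grad=True)
w_np = to_numpy(w)
print(w_np)

# The same helper handles GPU tensors when CUDA is available
if torch.cuda.is_available():
    g = torch.arange(3, device="cuda")
    print(to_numpy(g))
```

For a CPU tensor the helper's result still shares memory with the original; call `.copy()` on it if you need an independent array.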
3.2 numpy to torch.Tensor
tensor = torch.from_numpy(ndarray)
```python
# -*- encoding:utf-8 -*-
import numpy as np
import torch

a = [1, 2, 3]
print(' type :{0} value: {1}'.format(type(a), a))
ndarray = np.array(a)  # list -> ndarray
print(' type :{0} value: {1}'.format(type(ndarray), ndarray))
tensor = torch.from_numpy(ndarray)  # ndarray -> tensor
print(' type :{0} value: {1}'.format(type(tensor), tensor))
```

Output:

```
 type :<class 'list'> value: [1, 2, 3]
 type :<class 'numpy.ndarray'> value: [1 2 3]
 type :<class 'torch.Tensor'> value: tensor([1, 2, 3])
```
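`torch.from_numpy` shares memory with the array and keeps its dtype, whereas `torch.tensor(ndarray)` always copies. Since numpy floats default to float64 while PyTorch layers default to float32 parameters, an explicit `.float()` is often needed; this is a common source of dtype-mismatch errors. A minimal sketch with illustrative names:

```python
import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])  # numpy floats default to float64
shared = torch.from_numpy(arr)   # shares memory with arr, keeps float64
copied = torch.tensor(arr)       # copies the data

arr[0] = 100.0
print(shared)  # reflects the change
print(copied)  # unchanged

# Convert explicitly when the consumer expects float32
as_float32 = torch.from_numpy(arr).float()
print(shared.dtype, as_float32.dtype)
```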