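Converting between list, numpy, and torch.tensor in Python

While developing a model with PyTorch recently, a large share of the bugs I ran into came from converting between data formats. This post records the conversion methods for later reference.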

1.1 list to numpy

ndarray = np.array(lst)

```python
# -*- encoding:utf-8 -*-
import numpy as np

a = [1, 2, 3]
print(' type :{0}  value: {1}'.format(type(a), a))
ndarray = np.array(a)  # list -> ndarray
print(' type :{0}  value: {1}'.format(type(ndarray), ndarray))
```

Output:
```
 type :<class 'list'>  value: [1, 2, 3]
 type :<class 'numpy.ndarray'>  value: [1 2 3]
```
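np.array() also handles nested and mixed-type lists. The minimal sketch below (all variable names are illustrative) shows that equal-length nested lists become a multi-dimensional array, that mixing ints and floats upcasts every element to float, and that the dtype argument pins the element type explicitly:

```python
import numpy as np

# Nested lists of equal length become a multi-dimensional array
matrix = np.array([[1, 2, 3], [4, 5, 6]])
print(matrix.shape)  # (2, 3)

# Mixing ints and floats upcasts every element to float
mixed = np.array([1, 2.5, 3])
print(mixed.dtype)  # float64

# The dtype argument fixes the element type explicitly
as_float32 = np.array([1, 2, 3], dtype=np.float32)
print(as_float32.dtype, as_float32)  # float32 [1. 2. 3.]
```

Ragged nested lists (rows of different lengths) do not form a regular array; recent NumPy versions (1.24 and later) raise a ValueError for them unless dtype=object is passed.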

1.2 numpy to list

lst = ndarray.tolist()

```python
# -*- encoding:utf-8 -*-
import numpy as np

ndarray = np.array([1, 2, 3])  # list -> ndarray
print(' type :{0}  value: {1}'.format(type(ndarray), ndarray))
lst = ndarray.tolist()  # ndarray -> list (named lst so the built-in `list` is not shadowed)
print(' type :{0}  value: {1}'.format(type(lst), lst))
```

Output:
```
 type :<class 'numpy.ndarray'>  value: [1 2 3]
 type :<class 'list'>  value: [1, 2, 3]
```
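tolist() is also the safer choice compared with calling Python's built-in list() on an array: tolist() recurses through every dimension and returns Python-native numbers, while list() only splits the first axis. A minimal sketch with an illustrative 2-D array:

```python
import numpy as np

arr = np.array([[1, 2], [3, 4]])

# tolist() recurses through every dimension and returns Python-native ints
nested = arr.tolist()
print(nested)              # [[1, 2], [3, 4]]
print(type(nested[0][0]))  # <class 'int'>

# list() only splits the first axis: each item is still a NumPy array
shallow = list(arr)
print(type(shallow[0]))    # <class 'numpy.ndarray'>
```

This matters when the result is handed to code that expects plain Python numbers, for example json.dumps, which cannot serialize NumPy integer scalars.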

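2.1 list to torch.Tensor

tensor = torch.Tensor(lst)

Note: if the list elements are ints, torch.Tensor still produces a floating-point tensor (torch.float32 by default). To keep integers, specify the type explicitly.

Commonly used typed Tensor classes are the 32-bit float torch.FloatTensor, the 64-bit float torch.DoubleTensor, the 16-bit integer torch.ShortTensor, the 32-bit integer torch.IntTensor, and the 64-bit integer torch.LongTensor.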

```python
# -*- encoding:utf-8 -*-
import torch

a = [1, 2, 3]
print(' type :{0}  value: {1}'.format(type(a), a))

tensor = torch.Tensor(a)  # float32 by default
print(' type :{0}  value: {1}'.format(type(tensor), tensor))

tensor = torch.IntTensor(a)  # explicitly int32
print(' type :{0}  value: {1}'.format(type(tensor), tensor))
```

Output:
```
 type :<class 'list'>  value: [1, 2, 3]
 type :<class 'torch.Tensor'>  value: tensor([1., 2., 3.])
 type :<class 'torch.Tensor'>  value: tensor([1, 2, 3], dtype=torch.int32)
```
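To check the dtype behind each typed class listed above, and to compare with the lowercase torch.tensor() function, which infers the dtype from the data instead of always producing float, here is a short sketch (the legacy dict and variable names are only for illustration):

```python
import torch

# Each typed constructor corresponds to one dtype
legacy = {
    "FloatTensor": torch.FloatTensor([1, 2, 3]).dtype,    # torch.float32
    "DoubleTensor": torch.DoubleTensor([1, 2, 3]).dtype,  # torch.float64
    "ShortTensor": torch.ShortTensor([1, 2, 3]).dtype,    # torch.int16
    "IntTensor": torch.IntTensor([1, 2, 3]).dtype,        # torch.int32
    "LongTensor": torch.LongTensor([1, 2, 3]).dtype,      # torch.int64
}
print(legacy)

a = [1, 2, 3]
# torch.tensor (lowercase) infers the dtype from the data: Python ints -> int64
inferred = torch.tensor(a)
print(inferred.dtype)  # torch.int64

# Or request a dtype explicitly
explicit = torch.tensor(a, dtype=torch.int32)
print(explicit)  # tensor([1, 2, 3], dtype=torch.int32)
```

Passing dtype to torch.tensor() keeps the target type visible at the call site, which can make type-conversion bugs easier to spot than choosing among the typed classes.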

2.2 torch.Tensor to list

Convert to numpy first, then to list:

lst = tensor.numpy().tolist()

```python
# -*- encoding:utf-8 -*-
import torch

a = [1, 2, 3]
print(' type :{0}  value: {1}'.format(type(a), a))

tensor = torch.Tensor(a)
print(' type :{0}  value: {1}'.format(type(tensor), tensor))
lst = tensor.numpy().tolist()  # tensor -> ndarray -> list
print(' type :{0}  value: {1}'.format(type(lst), lst))
```

Output:
```
 type :<class 'list'>  value: [1, 2, 3]
 type :<class 'torch.Tensor'>  value: tensor([1., 2., 3.])
 type :<class 'list'>  value: [1.0, 2.0, 3.0]
```
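The floats in the last output line come from torch.Tensor(a) having already produced a float tensor. Tensors also expose tolist() directly, so the numpy round-trip is optional. A minimal sketch (variable names are illustrative):

```python
import torch

float_tensor = torch.tensor([1.0, 2.0, 3.0])
int_tensor = torch.tensor([1, 2, 3])

# Tensor.tolist() goes straight to Python numbers, no numpy round-trip
print(float_tensor.tolist())  # [1.0, 2.0, 3.0]
print(int_tensor.tolist())    # [1, 2, 3]

# To get ints back from a float tensor, cast first, then call tolist()
print(float_tensor.long().tolist())  # [1, 2, 3]

# A 0-dimensional tensor gives a plain Python scalar; item() does the same
scalar = torch.tensor(7)
print(scalar.tolist(), scalar.item())  # 7 7
```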

3.1 torch.Tensor to numpy

ndarray = tensor.numpy()

Note: a tensor on the GPU cannot be converted to numpy directly; copy it to the CPU first:

ndarray = tensor.cpu().numpy()
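Tensors that track gradients (for example, model outputs during training) hit a second restriction: .numpy() raises a RuntimeError until the tensor is detached from the autograd graph. A minimal sketch of the usual detach().cpu().numpy() chain; the tensor names are illustrative, and the GPU branch only runs when CUDA is available:

```python
import torch

# A tensor that requires grad refuses .numpy() until it is detached
weights = torch.ones(3, requires_grad=True)
numpy_failed = False
try:
    weights.numpy()
except RuntimeError as err:
    numpy_failed = True
    print("numpy() failed:", err)

arr = weights.detach().numpy()
print(arr)  # [1. 1. 1.]

# On a GPU tensor the full chain is detach() -> cpu() -> numpy()
if torch.cuda.is_available():
    gpu_tensor = torch.ones(3, device="cuda", requires_grad=True)
    gpu_arr = gpu_tensor.detach().cpu().numpy()
    print(gpu_arr)
```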

3.2 numpy to torch.Tensor

tensor = torch.from_numpy(ndarray)

```python
# -*- encoding:utf-8 -*-
import numpy as np
import torch

a = [1, 2, 3]
print(' type :{0}  value: {1}'.format(type(a), a))
ndarray = np.array(a)  # list -> ndarray
print(' type :{0}  value: {1}'.format(type(ndarray), ndarray))
tensor = torch.from_numpy(ndarray)  # ndarray -> tensor
print(' type :{0}  value: {1}'.format(type(tensor), tensor))
```

Output:
```
 type :<class 'list'>  value: [1, 2, 3]
 type :<class 'numpy.ndarray'>  value: [1 2 3]
 type :<class 'torch.Tensor'>  value: tensor([1, 2, 3])
```
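One detail worth knowing: torch.from_numpy() does not copy the data, so the tensor and the array share memory, while torch.tensor() makes an independent copy. The same sharing applies in the other direction to tensor.numpy() on a CPU tensor. A minimal sketch (variable names are illustrative):

```python
import numpy as np
import torch

arr = np.array([1, 2, 3])

shared = torch.from_numpy(arr)  # shares memory with arr
copied = torch.tensor(arr)      # independent copy of the data

arr[0] = 100
print(shared)  # first element is now 100
print(copied)  # first element is still 1

# The reverse direction also shares memory for CPU tensors
t = torch.zeros(3)
view = t.numpy()
t[1] = 5.0
print(view)  # [0. 5. 0.]
```

Sharing avoids a copy for large arrays, but an in-place edit on one side silently changes the other; use torch.tensor() (or .copy() on the array) when the two must stay independent.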