[PyTorch] torch.repeat和torch.expand之间的区别


我想知道torch.repeat和torch.expand之间的区别!

推出

例如,
复制一楼张量[0,0,0]
有时您想制作第二层张量,例如[[0,0,0],[0,0,0],[0,0,0]]。

Torch.repeat和torch.expand是在这种情况下使用的方法的候选者,但是这些方法有何不同?
我很好奇,所以我做了实验。

实验

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
>>> import torch
>>> a = torch.zeros(3, 1)
>>> b = a.repeat(1, 4)
>>> c = a.expand(3, 4)   # a.expand(-1, 4)でも可能
>>> b.shape
torch.Size([3, 4])
>>> c.shape
torch.Size([3, 4])
>>> b
tensor([[ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.]])
>>> c
tensor([[ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.]])
>>> b[:, 2] += 1
>>> c[:, 2] += 1
>>> b
tensor([[ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  1.,  0.],
        [ 0.,  0.,  1.,  0.]])
>>> c
tensor([[ 1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.],
        [ 1.,  1.,  1.,  1.]])

关键是这个。
区别在于是否使用数据复制原始张量a并将其分配给新的内存。
它也写在teka文档

结论

如果要单独处理重复的张量,则可以复制整个数据torch.repeat
如果不是,请使用torch.expand,它可以节省内存

另外请注意,torch.expand只能在具有一维的轴上扩展。
即使尺寸数大于1,torch.repeat也可以重复。
可能是由于内存限制