我想知道torch.repeat和torch.expand之间的区别!
推出
例如,
复制一楼张量[0,0,0]
有时您想制作第二层张量,例如[[0,0,0],[0,0,0],[0,0,0]]。
Torch.repeat和torch.expand是在这种情况下使用的方法的候选者,但是这些方法有何不同?
我很好奇,所以我做了实验。
实验
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 | >>> import torch >>> a = torch.zeros(3, 1) >>> b = a.repeat(1, 4) >>> c = a.expand(3, 4) # a.expand(-1, 4)でも可能 >>> b.shape torch.Size([3, 4]) >>> c.shape torch.Size([3, 4]) >>> b tensor([[ 0., 0., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]) >>> c tensor([[ 0., 0., 0., 0.], [ 0., 0., 0., 0.], [ 0., 0., 0., 0.]]) >>> b[:, 2] += 1 >>> c[:, 2] += 1 >>> b tensor([[ 0., 0., 1., 0.], [ 0., 0., 1., 0.], [ 0., 0., 1., 0.]]) >>> c tensor([[ 1., 1., 1., 1.], [ 1., 1., 1., 1.], [ 1., 1., 1., 1.]]) |
关键是这个。
区别在于是否使用数据复制原始张量a并将其分配给新的内存。
它也写在teka文档
中
结论
如果要单独处理重复的张量,则可以复制整个数据torch.repeat
如果不是,请使用torch.expand,它可以节省内存
是
。
另外请注意,torch.expand只能在具有一维的轴上扩展。
即使尺寸数大于1,torch.repeat也可以重复。
可能是由于内存限制