百度了一圈gather的用法,看的一知半解,结合了几篇博客的讲解,终于理解了这个的用法,记录下来,用于以后忘记的时候自己可以快速复习,同时不懂得小伙伴也可以参考下我这得理解,或许能帮助到你!!!
首先了解下函数中的参数:
torch.gather(input, dim, index, out=None) → Tensor
Parameters:
1 2 3 4 | input (Tensor) – The source tensor dim (int) – The axis along which to index index (LongTensor) – The indices of elements to gather out (Tensor, optional) – Destination tensor |
input :需要索引的 tensor
dim : 指索引的维度 (0代表横向 1代表纵向 以此类推)
index: 索引的下标
接下来直接上例子解释
1 2 3 4 5 6 7 | import torch b = torch.Tensor([[1,2,3],[4,5,6]]) index_1 = torch.LongTensor([[0,1],[2,0]]) index_2 = torch.LongTensor([[0,1,1],[0,0,0]]) print (torch.gather(b, dim=1, index=index_1)) print (torch.gather(b, dim=0, index=index_2)) |
输出:
1 2 3 4 | tensor([[1., 2.], [6., 4.]]) tensor([[1., 5., 6.], [1., 2., 3.]]) |
第一个式子 dim=1:torch.gather(b, dim=1, index=index_1)
input : b =
1 2 | 1,2,3 4,5,6 |
dim = 1 :代表的是维度1也就是列
index =
1 2 | 0,1 2,0 |
了解了输入后我们分步进行解析
- index 的指就是代表对应维度,这里dim=1 ,0就代表第0列,1就代表第一列,2就代表第二列,我们先把每一个输出的值在input中的坐标的列写出来,注意一点,输出的shape也就是index的shape
1 2 | (,0),(,1) (,2),(,0) |
这样我们就完成了每个输出所在input中的坐标的列的定位
- 接下来每个输出的定位横坐标。每个输出的横坐标,也就是所在输出的横坐标
1 2 | (0,0),(0,1) (1,2),(1,0) |
- 最后用我们上面得到的坐标去获取input中对应的值
1 2 | 1,2 6,4 |
第二个式子 dim=0 :torch.gather(b, dim=0, index=index_2)
input : b =
1 2 | 1,2,3 4,5,6 |
dim = 0 :代表的是维度0也就是行
index =
1 2 | 0,1,1 0,0,0 |
- 有了上个式子的经验,第一步当然是写出对应dim的坐标啦,这个式子dim=0,也就可以先写出横坐标,这里的横坐标,就是index对应的值
1 2 | (0,),(1,),(1,) (0,),(0,),(0,) |
- 接下来写出纵坐标,纵坐标也就是输出所对应的纵坐标
1 2 | (0,0),(1,1),(1,2) (0,0),(0,1),(0,2) |
- 最后写出对应input的值
1 2 | 1,5,6 1,2,3 |
有了上面两个式子的解释,现在可以总结出gather 的用法了
gather的用法就是index所提供要索引的dim维的位置,其余维度的位置也就是index对应的位置 ,也就是输出的坐标,把dim维的替换成index中对应的数字 。
还不理解的话,再举个官方的例子:
1 2 3 4 5 | >>> t = torch.Tensor([[1,2],[3,4]]) >>> torch.gather(t, 1, torch.LongTensor([[0,0],[1,0]])) 1 1 4 3 [torch.FloatTensor of size 2x2] |
index 的中每个元素的坐标为:
1 2 | (0,0),(0,1) (1,0),(1,1) |
dim= 1 ,也就是把第二个维度的坐标替换成index中的值
1 2 | (0,0),(0,0) (1,1),(1,0) |
最后写出对应input中的值
1 2 | 1,1 4,3 |
如果还不懂的话,推荐一个博客,看看别人的讲解吧:https://blog.csdn.net/edogawachia/article/details/80515038