本文整理了笔者在学习pytorch中经常遇到的一些函数,本篇博客会不断进行更新,并且会加上自己使用背景和使用经验。
1. torch.max()函数
笔者最近在学习目标检测的相关知识,无论是在计算多个bounding box之间的IOU还是确定bounding box的类别信息的时候,都会用到torch.max()函数。torch.max()可以得到一个tensor某个维度的最大值,可以的得到两个tensor之间对应元素之间的最大值。
这个函数的签名是这样的:
1 | torch.max(input, dim, keepdim=False, out=None) -> (Tensor, LongTensor) |
1.1 单个tensor
情况1:如果输入tensor是一维:
1 2 3 | # 输入一维的tensor a = torch.tensor([9,4,6,3,1,2,5,9]) print(torch.max(a)) |
输出结果就是这一维数据中,最大的元素
1 | tensor(9) |
情况2:如果输入tensor是二维:
这时候就涉及到要比较哪个维度上的数据,是按行比较呢?还是按列比较呢?还是就是求这个tensor中最大的哪个元素?
(1)按列比较
1 2 3 4 5 | # 求每一列最大的元素 b = torch.tensor([[4,8,2],[2,9,2]]) max_value, max_index = torch.max(b,dim=0) print(max_value) print(max_index) |
max_value表示得到的最终数据,max_index表示得到最大元素的每一列中索引。最终的数据结果是这样的:
1 2 3 4 5 | # max_value: 表示每一列的最大值 tensor([4, 9, 2]) # max_index: 表示最大值在每一列中的索引 tensor([0, 1, 1]) |
我们发现,最终的结果相对于输入来说,结果的维度比输入的少了一个维度,这是因为我们指定了要比较哪个维度上面的数据。如果通过一幅图片显示torch.max()函数的效果就是这样:

(2)按行比较
1 2 3 4 5 | # 求每一列最大的元素 b = torch.tensor([[4,8,2],[2,9,2]]) max_value, max_index = torch.max(b,dim=1) print(max_value) print(max_index) |
最终的结果就是这样的
1 2 3 4 5 | # max_value: 表示每一行的最大值 tensor([8, 9]) # max_index: 表示最大值在每一行中的索引 tensor([1, 1]) |

(3)求整个tensor中最大的元素
1 2 3 4 5 6 7 | # 求一个tensor中最大的元素 b = torch.tensor([[4,8,2],[2,9,2]]) # 输出结果 print(torch.max(b)) # tensor([9]) |
1.2 两个tensor
torch.max()还可以比较两个tensor之间的最大值
情况1:单元素和另外一个tensor进行比较
1 2 3 | c = torch.tensor([5]) d = torch.tensor([1,2,3]) print(torch.max(c,d)) |
输出的结果是:
1 | tensor([5, 5, 5]) |
情况2:两个不同tensor之间进行比较
这种情况下,必须要求两个tensor的尺寸是一样的,不然会报错
1 2 3 | c = torch.tensor([[1,2,3],[2,3,4]]) d = torch.tensor([[4,3,1],[3,6,3]]) print(torch.max(c,d)) |
输出结果是这样:
1 2 3 | # 输出结果 tensor([[4, 3, 3], [3, 6, 4]]) |
2. torch.argmax()函数
torch.argmax()函数和torch.max()函数功能差不多,只不多前者只会返回一个tensor最大元素的索引,并不会返回这个最大元素是什么。torch.max()函数不仅会返回最大值的索引,而且还会返回这个最大的元素是什么。
max()函数的函数签名是这样的:
1 | torch.argmax(input, dim=None, keepdim=False) -> Tensor |
其中dim的含义是这样的:the dimension to reduce(参考了这篇博客)
情况1:返回一个tensor中最大元素的索引
1 2 3 4 5 6 | # 返回一个tensor最大值的索引 a = torch.tensor([[1,2,3],[2,3,4]]) print(torch.argmax(b)) # 输出值 # tensor(5) |
情况2:返回指定维度的最大元素索引
1 2 3 4 5 6 7 8 9 10 11 12 13 | # 将第0维的数据进行比较 b = torch.tensor([[1,2,3],[2,3,4]]) print(torch.argmax(b,dim=0) # 输入结果 # tensor([1, 1, 1]) # 将第1维数据进行比较 print(torch.argmax(b,dim=1)) # 输出结果 # tensor([2,2]) |
3. torch.nonzero()函数
torch.nonzero()函数有很多功能,比如能够配合mask筛选出你想要的数据。比如在目标检测求confidence小于threshold的bbox,或者只选择某一类别的bbox,都可以使用nonzero()函数进行筛选。
torch.nonzero()函数签名如下:
1 | torch.nonzero(input, out=None) -> Tensor |
这个函数的主要就是找到tensor中所有不为零元素的索引。函数的返回值是一个z*n的tensor。z表示不为零的元素的个数,n表示输入tensor的维度。
1 2 3 4 | # 打印a中所有不为0的元素的索引 a = torch.tensor([[1,2,3],[1,2,0]]) index = torch.nonzero(a) print(index) |
输出结果是这样的:
1 2 3 4 5 | tensor([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]) |
但是,我目前主要用的就是确定某一列中不含0的索引,然后再将其对应的某些行全部过滤出来。
1 2 3 4 5 6 7 | # 第2列中有一个非零数据,我们想把非零列所对应的行过滤出来 a = torch.tensor([[1,2,3],[1,2,0]]) index = torch.nonzero(a[:,2]) print(a[index.squeeze()]) # 输出效果 tensor([1, 2, 3]) |