Pytorch:torch.mul() 、torch.mm()、torch.bmm()、torch.matmul()

1 torch.mul()

  • 用标量值value乘以输入input的每个元素,并返回一个新的结果张量。 \( out=tensor ? value \)。如果输入是FloatTensor or DoubleTensor类型,则value 必须为实数,否则须为整数。

1
torch.mul(input, value, out=None)
参数 描述
input (Tensor) 输入张量
value (Number) 乘到每个元素的数
out (Tensor) 可选,输出张量

栗子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> a = torch.randn(3)
>>> a

-0.9374
-0.5254
-0.6069
[torch.FloatTensor of size 3]

>>> torch.mul(a, 100)

-93.7411
-52.5374
-60.6908
[torch.FloatTensor of size 3]
  • 两个张量input,other按元素进行相乘,并返回到输出张量。即计算\( out_i=input_i ? other_i \)。两计算张量形状不须匹配,但总元素数须一致。

1
torch.mul(input, other, out=None)
参数 描述
input (Tensor) 第一个相乘张量
other (Tensor) 第二个相乘张量
out (Tensor) 可选,输出张量

栗子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
>>> a = torch.randn(4,4)
>>> a

-0.7280  0.0598 -1.4327 -0.5825
-0.1427 -0.0690  0.0821 -0.3270
-0.9241  0.5110  0.4070 -1.1188
-0.8308  0.7426 -0.6240 -1.1582
[torch.FloatTensor of size 4x4]

>>> b = torch.randn(2, 8)
>>> b

 0.0430 -1.0775  0.6015  1.1647 -0.6549  0.0308 -0.1670  1.0742
-1.2593  0.0292 -0.0849  0.4530  1.2404 -0.4659 -0.1840  0.5974
[torch.FloatTensor of size 2x8]

>>> torch.mul(a, b)

-0.0313 -0.0645 -0.8618 -0.6784
 0.0934 -0.0021 -0.0137 -0.3513
 1.1638  0.0149 -0.0346 -0.5068
-1.0304 -0.3460  0.1148 -0.6919
[torch.FloatTensor of size 4x4]

2 torch.mm()

处理二维矩阵的乘法,而且也只能处理二维矩阵,其他维度要用torch.matmul()。torch.mm(a, b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵

1
torch.mm(input, mat2, out=None)
  • 栗子
1
2
3
4
5
6
mat1 = torch.randn(2, 3)
mat2 = torch.randn(3, 3)
torch.mm(mat1, mat2)

tensor([[ 0.4851,  0.5037, -0.3633],
        [-0.0760, -3.6705,  2.4784]])

3 torch.bmm()

1
torch.bmm(input, mat2, out=None)

看函数名就知道,在torch.mm的基础上加了个batch计算,不能广播。

4 torch.matmul()

1
torch.matmul(input, other, out=None)

功能
适用性最多的,能处理batch、广播的矩阵:

  1. 如果第一个参数是一维,第二个是二维,那么给第一个提供一个维度
  2. 如果第一个是二维,第二个是一维,就是矩阵乘向量
  3. 带有batch的情况,可保留batch计算
  4. 维度不同时,可先广播,再batch计算

栗子:

  • vector x vector
1
2
3
4
tensor1 = torch.randn(3)
tensor2 = torch.randn(3)
torch.matmul(tensor1, tensor2).size()
torch.Size([])
  • matrix x vector
1
2
3
4
tensor1 = torch.randn(3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([3])
  • batched matrix x broadcasted vecto
1
2
3
4
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
  • batched matrix x batched matrix
1
2
3
4
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(10, 4, 5)
torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

总结:

对位相乘用torch.mul,二维矩阵乘法用torch.mm,batch二维矩阵用torch.bmm,batch、广播用torch.matmul

参考:

  • https://blog.csdn.net/McEason/article/details/104182648