pytorch中矩阵各种乘法的使用
关于 @运算,*运算,torch.mul(), torch.mm(), torch.mv(), tensor.t(), torch.matmul
@ 和 *
代表矩阵的两种相乘方式: @
表示常规的数学上定义的矩阵相乘; *
表示两个矩阵对应位置处的两个元素相乘。
x.dot(y)
: 向量乘积,x,y均为一维向量。
*和torch.mul()
等同:表示相同shape矩阵点乘,即对应位置相乘,得到矩阵有相同的shape。
@和torch.mm(a, b)
等同:正常矩阵相乘,要求a的列数与b的行数相同。
torch.mv(X, w0)
:是矩阵和向量相乘.第一个参数是矩阵,第二个参数只能是一维向量,等价于 X乘以w0的转置 Y.t()
:矩阵Y的转置。
1 | a = torch.tensor([ |
1、 *和torch.mul()
等同,矩阵点乘
1 | res1 = a*b |
Output:
a*b=tensor([[3, 0, 3],
[0, 0, 0],
[3, 0, 3]])
torch.mul(a, b)=tensor([[3, 0, 3],
[0, 0, 0],
[3, 0, 3]])
2、 @, torch.matmul(a, b)和torch.mm(a, b)
等同,矩阵乘法
1 | res2 = a@b |
Output:
a@b=tensor([[6, 2, 6],
[1, 0, 1],
[6, 2, 6]])
torch.mm(a,b)=tensor([[6, 2, 6],
[1, 0, 1],
[6, 2, 6]])
3、 dot()
向量乘法
向量运算,参数不能是多维矩阵,否则报错: RuntimeError: 1D tensors expected, got 2D, 2D tensors at
.
1 | res3 = w1.dot(w2) |
Output:
w1.dot(w2)=14
4、 c.t()
矩阵转置
1 | res4 = c.t() |
Output:
c.t()=tensor([[1, 0, 0],
[1, 1, 0],
[1, 1, 1]])
5、 torch.mv(a, w1)
矩阵乘向量
1 | res5 = torch.mv(a, w1) |
Output: 向量w1是行向量和列向量都可以,结果一样。
torch.mv(a, w1)=tensor([2, 2, 2])
torch.mv(a, w1.t())=tensor([2, 2, 2])
如果命题拿不准就多测试两遍,相互验证。