使用 PyTorch 进行矩阵运算时,总是需要去记忆形如 torch.dot
, torch.mm
的函数,当张量维度上升到四维及以上时,torch.einsum
是一种能够表达矩阵点积、外积、转置等运算,包括部分复杂张量运算在内的优雅方式。
对于矩阵 和 ,两个矩阵的乘积 的维度可以表示为 ,用爱因斯坦求和约定可以如下表示:
在代码中,上面的式子可表示为字符串:
'ik,kj->ij'
在 Pytorch 中使用 einsum 的代码实现为:
import torch
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum('ik,kj->ij', A, B)
C = torch.einsum('ij->', A)
C
tensor(1.5575)
C = torch.einsum('ij->j', A)
C
tensor([-0.4369, 1.1548, 0.8594, -0.0198])
C = torch.einsum('ij,ij->', A, A)
C = torch.einsum('ij->ji', A)
A = torch.randn(3, 4)
B = torch.randn(3, 4, 5)
C = torch.randn(4, 5)
D = torch.einsum('ij,ijk,jk->ik', A, B, C)
D.shape
torch.Size([3, 5])