Gridea

爱因斯坦求和约定 torch.einsum()

爱因斯坦求和约定 torch.einsum()
2022-05-23 · 2 min read
PyTorch Coding Python

使用 PyTorch 进行矩阵运算时,总是需要去记忆形如 torch.dot, torch.mm 的函数,当张量维度上升到四维及以上时,torch.einsum 是一种能够表达矩阵点积、外积、转置等运算,包括部分复杂张量运算在内的优雅方式。

矩阵乘法

对于矩阵 ARI×KA \in \mathbb{R}^{I \times K}BRK×JB \in \mathbb{R}^{K \times J},两个矩阵的乘积 CC 的维度可以表示为 RI×J\mathbb{R}^{I \times J},用爱因斯坦求和约定可以如下表示:

Cij=(AB)ij=kAikBkjC_{ij} = (AB)_{ij} = \sum_{k}^{}A_{ik}B_{kj}

在代码中,上面的式子可表示为字符串:

'ik,kj->ij'

在 Pytorch 中使用 einsum 的代码实现为:

import torch
A = torch.randn(3, 4)
B = torch.randn(4, 5)
C = torch.einsum('ik,kj->ij', A, B)

其他的一些例子

求和

C = torch.einsum('ij->', A)
C
tensor(1.5575)

列求和

C = torch.einsum('ij->j', A)
C
tensor([-0.4369,  1.1548,  0.8594, -0.0198])

点积求和

C = torch.einsum('ij,ij->', A, A)

转置

C = torch.einsum('ij->ji', A)

多维矩阵乘法

A = torch.randn(3, 4)
B = torch.randn(3, 4, 5)
C = torch.randn(4, 5)
D = torch.einsum('ij,ijk,jk->ik', A, B, C)

D.shape
torch.Size([3, 5])

官方文档

PyTorch 1.11.0 documentation