einsum
是 Einstein summation
的缩写,来源于爱因斯坦求和约定(Einstein summation convention)。这是物理学家阿尔伯特·爱因斯坦引入的一种简便记号,用于描述张量运算,特别是涉及多维数组的运算。
示例1:矩阵乘法
矩阵乘法 C=AB
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = torch.einsum('ik,kj->ij', A, B)
print(C.size()) # 输出: torch.Size([2, 4])
这里,'ik,kj->ij'
的含义是:
A
的形状为(2, 3)
,对应ik
,i
和k
分别表示第一个和第二个维度。B
的形状为(3, 4)
,对应kj
,k
和j
分别表示第一个和第二个维度。->ij
表示输出张量的模式,结果为(2, 4)
。
示例2:向量点积
向量点积 c=a⋅b
a = torch.randn(3)
b = torch.randn(3)
c = torch.einsum('i,i->', a, b)
print(c.size()) # 输出: torch.Size([])
这里,'i,i->'
的含义是:
a
和b
都是向量,对应模式i
。->
后面为空,表示结果是一个标量。
示例3:批量矩阵乘法
批量矩阵乘法
A = torch.randn(10, 2, 3)
B = torch.randn(10, 3, 4)
C = torch.einsum('bij,bjk->bik', A, B)
print(C.size()) # 输出: torch.Size([10, 2, 4])
这里,'bij,bjk->bik'
的含义是:
A
的形状为(10, 2, 3)
,对应bij
,b
表示批次维度,i
和j
分别表示矩阵的行和列。B
的形状为(10, 3, 4)
,对应bjk
,b
表示批次维度,j
和k
分别表示矩阵的行和列。->bik
表示输出张量的模式,结果为(10, 2, 4)
。
示例4:逐元素相乘(哈达玛积)A.B或A × B
A = torch.randn(3, 4)
B = torch.randn(3, 4)
C = torch.einsum('ij,ij->ij', A, B)
print(C.size()) # 输出: torch.Size([3, 4])
'ij,ij->ij'
表示:
A
和B
都是形状为[3, 4]
的矩阵,用ij
表示。- 结果
C
也是形状为[3, 4]
的矩阵。 - 没有重复索引,所以不进行求和。