[英]Output shape of numpy.einsum
给定 einsum 的输入 arguments (不运行计算),是否有一种优雅的方法可以从np.einsum
预先计算结果的形状?
# Given a, b and signature with
# a.shape == (1, 2, 5)
# b.shape == (4, 5)
einsum_shape('ijk,mk->ik', a, b) # returns (1, 5)
这是适用于通用输入数量和相关 einsum 表达式的东西,也适用于特定的标量减少情况 -
def einsum_outshape(einsum_expr, inputs):
shps = np.concatenate([in_.shape for in_ in inputs])
p = einsum_expr.split(',')
s = p[:-1] + p[-1].split('->')
if s[-1]=='':
return ()
else:
inop = list(map(list,s))
return tuple(shps[(np.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)])
样品运行 -
In [42]: a = np.random.rand(1,2,5)
...: b = np.random.rand(4,5)
...: c = np.random.rand(5,7,8)
...: d = np.random.rand(7,9)
In [43]: einsum_outshape('ijk,mk,kpq,pr->ikpqr', inputs=(a,b,c,d))
Out[43]: (1, 5, 7, 8, 9)
# Reduction to a scalar
In [44]: einsum_outshape('ijk,mk,kpq,pr->', inputs=(a,b,c,d))
Out[44]: ()
根据@Divakar 的回答,我想出了以下内容,如果传递了不受支持的下标字符串,则更具可读性并引发错误。
def einsum_outshape(subscripts, *operants):
"""Compute the shape of output from `numpy.einsum`.
Does not support ellipses.
"""
if "." in subscripts:
raise ValueError(f'Ellipses are not supported: {subscripts}')
insubs, outsubs = subscripts.replace(",", "").split("->")
if outsubs == "":
return ()
insubs = np.array(list(insubs))
innumber = np.concatenate([op.shape for op in operants])
outshape = []
for o in outsubs:
indices, = np.where(insubs == o)
try:
outshape.append(innumber[indices].max())
except ValueError:
raise ValueError(f'Invalid subscripts: {subscripts}')
return tuple(outshape)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.