繁体   English   中英

使用numpy / pytorch广播计算矩阵产品的跟踪

[英]Compute trace of matrix product using numpy/pytorch broadcasting

设A为(nxm)矩阵,M为(mxm)矩阵。 为矩阵的轨迹编写tr(),我需要计算tr(AM(A ^ T))。 但是,最终的跟踪操作将丢弃大部分计算。 我可以使用numpy或pytorch的广播规则来仅计算AM(A ^ T)的必要对角线吗?

更新:这是我在PyTorch中计算对角线的解决方案:

torch.sum(torch.sum(At()[:,None,:]*M[:,:,None],0)*At(),0)

您将必须计算至少两个矩阵乘积之一。 随后,您可以在这里使用答案之一: 在numpy中计算矩阵乘积的迹线的最佳方法是什么?

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM