[英]What is the best way to compute the trace of a matrix product in numpy?
[英]Compute trace of matrix product using numpy/pytorch broadcasting
设A为(nxm)矩阵,M为(mxm)矩阵。 为矩阵的轨迹编写tr(),我需要计算tr(AM(A ^ T))。 但是,最终的跟踪操作将丢弃大部分计算。 我可以使用numpy或pytorch的广播规则来仅计算AM(A ^ T)的必要对角线吗?
更新:这是我在PyTorch中计算对角线的解决方案:
torch.sum(torch.sum(At()[:,None,:]*M[:,:,None],0)*At(),0)
您将必须计算至少两个矩阵乘积之一。 随后,您可以在这里使用答案之一: 在numpy中计算矩阵乘积的迹线的最佳方法是什么?
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.