简体   繁体   English

Python: numpy.dot / numpy.tensordot 用于多维 ZA3CBC3F9D0CE1F2C159CZD71D71B

[英]Python: numpy.dot / numpy.tensordot for multidimensional arrays

I'm optimising my implementation of the back-propagation algorithm to train a neural network.我正在优化反向传播算法的实现以训练神经网络。 One of the aspects I'm working on is performing the matrix operations on the set of datapoints (input/output vector) as a batch process optimised by the numpy library instead of looping through every datapoint.我正在研究的一个方面是对一组数据点(输入/输出向量)执行矩阵运算作为由 numpy 库优化的批处理,而不是循环遍历每个数据点。

In my original algorithm I did the following:在我原来的算法中,我做了以下事情:

for datapoint in datapoints:
  A = ... (created out of datapoint info)
  B = ... (created out of datapoint info)

  C = np.dot(A,B.transpose())
____________________

A: (7,1) numpy array
B: (6,1) numpy array
C: (7,6) numpy array

I then expanded said matrices to tensors, where the first shape index would refer to the dataset.然后我将所述矩阵扩展为张量,其中第一个形状索引将引用数据集。 If I have 3 datasets (for simplicity purposes), the matrices would look like this:如果我有 3 个数据集(为简单起见),矩阵将如下所示:

A: (3,7,1) numpy array
B: (3,6,1) numpy array
C: (3,7,6) numpy array

Using ONLY np.tensordot or other numpy manipulations, how do I generate C?仅使用 np.tensordot 或其他 numpy 操作,如何生成 C?

I assume the answer would look something like this:我假设答案看起来像这样:

C = np.tensordot(A.[some manipulation], B.[some manipulation], axes = (...))

(This is a part of a much more complex application, and the way I'm structuring things is not flexible anymore. If I find no solution I will only loop through the datasets and perform the multiplication for each dataset) (这是一个更复杂的应用程序的一部分,我构建事物的方式不再灵活。如果我找不到解决方案,我只会遍历数据集并为每个数据集执行乘法)

We can use np.einsum -我们可以使用np.einsum -

c = np.einsum('ijk,ilm->ijl',a,b)

Since the last axes are singleton, you might be better off with sliced arrays -由于最后一个轴是 singleton,因此使用切片 arrays 可能会更好 -

c = np.einsum('ij,il->ijl',a[...,0],b[...,0])

With np.matmul/@-operator -使用np.matmul/@-operator -

c = a@b.swapaxes(1,2)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM