簡體   English   中英

Tensorflow 2D 矩陣乘法返回矩陣乘積列表

[英]Tensorflow 2D Matrices mutiplication returning list of matrix products

我在 keras/tensorflow 中遇到了這個問題。
我正在實現一個用戶定義的損失 function 並且我遇到了這個問題:我必須將 2 個矩陣相乘,以表格形式獲得矩陣乘積列表
[column_0_matrix_1 x row_0_matrix_2],[column_1_matrix_1 x row_1_matrix_2] ecc。

假設我有

A = [[1 1]
     [3 2]]
B = [[4 1]
     [1 3]]

然后我想要表格中的產品列表

C = |[1] x [4 1]|, |[1] x [1 3]|
    |[3]        |  |[2]        |


任何想法? 我自己嘗試過,但總是取回 2 個起始矩陣的乘積。 任何幫助將不勝感激。 謝謝

您可以拆分每個張量,然后在列表理解中使用tf.linalg.matmul來實現您想要的

import tensorflow as tf


a = tf.constant([[1, 1], [3, 2]])
b = tf.constant([[4, 1], [1, 3]])

a_split = tf.split(a, 2, 1)
b_split = tf.split(b, 2, 0)

[tf.linalg.matmul(x, y) for x, y in zip(a_split, b_split)]
# [<tf.Tensor: shape=(2, 2), dtype=int32, numpy=
#  array([[ 4,  1],
#         [12,  3]], dtype=int32)>,
# <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
#  array([[1, 3],
#         [2, 6]], dtype=int32)>]

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM