[英]tensorflow element-wise multiplication broadcasting?
tensorflow 是否為最后一個維度上的元素乘法廣播提供任何功能?
這是我正在嘗試做什么以及什么不起作用的示例:
import tensorflow as tf
x = tf.constant(5, shape=(1, 200, 175, 6), dtype=tf.float32)
y = tf.constant(1, shape=(1, 200, 175), dtype=tf.float32)
tf.math.multiply(x, y)
本質上,我希望對x
沿最后一個維度的每個切片,與y
逐元素矩陣乘法。
我發現這個問題提出了類似的操作: Efficient element-wise multiplication of a matrix and a vector in TensorFlow
不幸的是,建議的方法(使用tf.multiply()
)現在不再有效。 相應的tf.math.multiply
也不起作用,因為上面的代碼給了我以下錯誤:
Traceback (most recent call last):
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1864, in _create_c_op
c_op = c_api.TF_FinishOperation(op_desc)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/dispatch.py", line 180, in wrapper
return target(*args, **kwargs)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py", line 322, in multiply
return gen_math_ops.mul(x, y, name)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py", line 6490, in mul
"Mul", x=x, y=y, name=name)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py", line 788, in _apply_op_helper
op_def=op_def)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 3616, in create_op
op_def=op_def)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 2027, in __init__
control_input_ops)
File "/home/yuqiong/miniconda3/envs/deep/lib/python3.7/site-packages/tensorflow/python/framework/ops.py", line 1867, in _create_c_op
raise ValueError(str(e))
ValueError: Dimensions must be equal, but are 175 and 200 for 'Mul' (op: 'Mul') with input shapes: [1,200,175,6], [1,200,175].
我可以想到一種可行的方法:將y
復制 6 次,使其具有與x
完全相同的形狀,然后進行元素乘法。
但是在張量流中是否有更快且內存高效的方法來做到這一點?
這應該可以實現您想要的:
x = np.array([[[1,2,3],[4,5,6],[7,8,9],[10,11,12]]])
# [[[ 1 2 3]
# [ 4 5 6]
# [ 7 8 9]
# [10 11 12]]]
y = np.array([[1,2,3,4]])
# [[1 2 3 4]]
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
# [[[ 1 2 3]
# [ 8 10 12]
# [21 24 27]
# [40 44 48]]]
最后,使用您需要的形狀:
x = np.random.rand(1, 200, 175, 6)
y = np.random.rand(1, 200, 175)
y = tf.expand_dims(y, axis=-1)
mul = tf.multiply(x, y)
with tf.Session() as sess:
print(sess.run(mul).shape)
# (1, 200, 175, 6)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.