
TensorFlow multiplication along axis

I want to multiply only one index along a given axis by a scalar, like this:

a = tf.ones([2, 2, 3])
b = tf.constant(7)
c = ...  # multiply a[:, :, 1] by b

so that c[...,0] and c[...,2] have ones but c[...,1] has sevens:

print(c.shape)
> (2, 2, 3)

print(a[...,0])  # output is the same for a[...,1] and a[...,2]
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)

print(c[...,0])
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)

print(c[...,1])
> tf.Tensor(
[[7. 7.]
 [7. 7.]], shape=(2, 2), dtype=float32)

print(c[...,2])
> tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)
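A minimal sketch of the result described above, assuming TensorFlow 2.x in eager mode: instead of slicing, multiply `a` by a length-3 factor that is 1 everywhere except at index 1 of the last axis, and let broadcasting apply it to every (row, column) position. Note that `tf.constant(7)` is int32 while `tf.ones` is float32, and TensorFlow does not cast between them automatically, so the sketch assumes `b` is created as a float. `factor` is an illustrative name, not something from the original post.

```python
import tensorflow as tf

a = tf.ones([2, 2, 3])
# Assumption: b is float32 so it can be multiplied with the float32 tensor a
b = tf.constant(7.0)

# one_hot(1, depth=3) is [0., 1., 0.]; scaling by (b - 1) and adding 1
# gives [1., b, 1.], i.e. "multiply index 1 of the last axis by b"
factor = 1.0 + (b - 1.0) * tf.one_hot(1, depth=a.shape[-1])

# factor has shape (3,) and broadcasts against a's shape (2, 2, 3)
c = a * factor

print(c.shape)
print(c[..., 1])
```

Because this is a single elementwise multiplication, it also works inside `tf.function` and stays differentiable with respect to `a`.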

I'm not quite sure what result you're expecting, but if I understood you correctly, you could do something like this:

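import tensorflow as tf

# a has shape (4, 4, 3): a block of ones, a block of threes, then two blocks of zeros
a = tf.concat([tf.ones([1, 4, 3], dtype=tf.float32),
               tf.ones([1, 4, 3], dtype=tf.float32) * 3,
               tf.zeros([2, 4, 3], dtype=tf.float32)], axis=0)
b = tf.constant(7, dtype=tf.float32)

# Take the block at index 1 of axis 0 (shape (1, 4, 3)) and scale it by b
tensor = tf.slice(a,
                  begin=[1, 0, 0],
                  size=[1, 4, 3])
c = tensor * b
# Return a new tensor equal to a, with the block at index 1 of axis 0 replaced by c
result = tf.tensor_scatter_nd_update(a, [[1]], c)
print('a -->', a, '\n')
print('c -->', c, '\n')
print('result -->', result)

Output: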
a --> tf.Tensor(
[[[1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]
  [1. 1. 1.]]

 [[3. 3. 3.]
  [3. 3. 3.]
  [3. 3. 3.]
  [3. 3. 3.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]

 [[0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]
  [0. 0. 0.]]], shape=(4, 4, 3), dtype=float32) 

c --> tf.Tensor(
[[[21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]]], shape=(1, 4, 3), dtype=float32) 

result --> tf.Tensor(
[[[ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]
  [ 1.  1.  1.]]

 [[21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]
  [21. 21. 21.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]

 [[ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]]], shape=(4, 4, 3), dtype=float32)
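The code above updates a block along axis 0, because `tf.tensor_scatter_nd_update` indexes the leading axes of its input. A sketch of how the same slice-and-scatter idea could be applied to the last axis, as the question asks, assuming TensorFlow 2.x: transpose so that axis 2 comes first, scatter, then transpose back. The names `a_t`, `updated` and `result_t` are illustrative.

```python
import tensorflow as tf

a = tf.ones([2, 2, 3])
b = tf.constant(7.0)  # float32, to match a

# tensor_scatter_nd_update indexes the leading axis, so move the last axis first:
# (2, 2, 3) -> (3, 2, 2)
a_t = tf.transpose(a, perm=[2, 0, 1])

# Slice index 1 of the moved axis (shape (1, 2, 2)), scale it, write it back at index 1
updated = tf.slice(a_t, begin=[1, 0, 0], size=[1, 2, 2]) * b
result_t = tf.tensor_scatter_nd_update(a_t, [[1]], updated)

# Move that axis back to the end: (3, 2, 2) -> (2, 2, 3)
c = tf.transpose(result_t, perm=[1, 2, 0])

print(c[..., 1])
```

The two transposes copy data, so for a plain scale factor a broadcast multiplication may be simpler; the scatter form is more useful when the new values are not just a multiple of the old ones.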

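Update: the slice `a[..., 0]` is not what you might think it is. The ellipsis stands for all leading axes, so `a[..., 0]` is the same as `a[:, :, 0]`: it takes index 0 of the last axis from every position along the first two axes and returns a (2, 2) tensor. With `tf.ones` every such slice looks identical, which hides this; with a tensor whose blocks along axis 0 differ, it becomes visible: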

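import tensorflow as tf

# a has shape (2, 2, 3): a block of ones at index 0 and a block of zeros at index 1 of axis 0
a = tf.concat([tf.ones([1, 2, 3], dtype=tf.float32),
               tf.zeros([1, 2, 3], dtype=tf.float32)], axis=0)

# b has the same shape but is all ones, like the tensor in the question
b = tf.ones([2, 2, 3], dtype=tf.float32)

# [..., 0] takes index 0 of the last axis from every position of the first two axes
print(a.shape, a[...,0])
print(b.shape, b[...,0])

Output: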
(2, 2, 3) tf.Tensor(
[[1. 1.]
 [0. 0.]], shape=(2, 2), dtype=float32)
(2, 2, 3) tf.Tensor(
[[1. 1.]
 [1. 1.]], shape=(2, 2), dtype=float32)
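To make the point above concrete, here is a small sketch (assuming TensorFlow 2.x) comparing `a[..., 0]` with the explicit `a[:, :, 0]`, with `tf.gather` along axis 2, and with indexing the first axis instead. The variable names are illustrative.

```python
import tensorflow as tf

# Same tensor as above: ones at index 0 and zeros at index 1 of axis 0, shape (2, 2, 3)
a = tf.concat([tf.ones([1, 2, 3]), tf.zeros([1, 2, 3])], axis=0)

# The ellipsis expands to as many ':' as needed, so these select the same values
last_axis = a[..., 0]
explicit = a[:, :, 0]
# The same selection written as a gather along axis 2
gathered = tf.gather(a, 0, axis=2)

# Indexing the first axis instead returns one whole (2, 3) block
first_axis = a[0]

print(last_axis.shape, first_axis.shape)  # (2, 2) (2, 3)
print(last_axis)                          # rows [1. 1.] and [0. 0.]
```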
