I have a tensor of shape tf.shape(input)=(Batch_Size,Channels,N,N).
My goal is to calculate an output that contains all diagonal elements along axes 2 and 3, so that tf.shape(output)=(Batch_Size,Channels,N).
There is the function tf.diag_part(input),
but it doesn't let me select the axes I want to consider. How can I define a function that does this for me?
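To pin down the target: the desired output satisfies output[b, c, i] = input[b, c, i, i]. NumPy's np.diagonal takes the two axes explicitly, so a minimal reference sketch of that behaviour (NumPy only, not TensorFlow; `reference_diag` is a hypothetical name used for illustration) looks like this:

```python
import numpy as np

def reference_diag(x):
    """Reference semantics: out[b, c, i] = x[b, c, i, i] for x of shape (B, C, N, N)."""
    # np.diagonal removes axis1 and axis2 and appends the diagonal as the last axis.
    return np.diagonal(x, axis1=2, axis2=3)

x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)
out = reference_diag(x)
print(out.shape)  # (2, 3, 4)
print(out[0, 0])  # [ 0  5 10 15]
```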
Could the following code work?
Batches=[]
for batch in input:
    diagonalpart=tf.diag_part(batch)
    Batches.append(diagonalpart)
output=tf.stack(Batches)
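As written, that loop would fail: tf.diag_part expects an input of even rank 2k, but each `batch` slice has rank 3, (Channels, N, N). A minimal sketch of a loop that does work, assuming TensorFlow 2.x with eager execution (a Python `for` loop over a tensor is only possible eagerly) and applying tf.linalg.diag_part to each (N, N) matrix; `diag_part_loop` is a hypothetical name:

```python
import numpy as np
import tensorflow as tf

def diag_part_loop(x):
    """Per-(batch, channel) diagonal via explicit Python loops (eager mode only)."""
    batches = []
    for batch in x:                                       # batch: (Channels, N, N)
        channels = []
        for matrix in batch:                              # matrix: (N, N)
            channels.append(tf.linalg.diag_part(matrix))  # (N,)
        batches.append(tf.stack(channels))                # (Channels, N)
    return tf.stack(batches)                              # (Batch_Size, Channels, N)

x = tf.constant(np.arange(2 * 3 * 4 * 4, dtype=np.int32).reshape(2, 3, 4, 4))
loop_result = diag_part_loop(x)
print(loop_result.shape)  # (2, 3, 4)
```

The loop gives the same values as a single vectorized tf.linalg.diag_part(x) call, which is simpler and also works inside tf.function.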
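tf.linalg.diag_part does exactly what you want: it returns the diagonal of the innermost two dimensions and treats all leading dimensions as batch dimensions. For example: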
import tensorflow as tf
import numpy as np

# Input shape: (2, 2, 4, 4)
input = np.array([
    [[[1, 2, 3, 4],
      [5, 6, 7, 8],
      [9, 8, 7, 6],
      [9, 8, 7, 6]],
     [[5, 4, 3, 2],
      [1, 2, 3, 4],
      [5, 6, 7, 8],
      [1, 2, 3, 4]]],
    [[[1, 2, 3, 4],
      [5, 6, 7, 8],
      [9, 8, 7, 6],
      [1, 2, 3, 4]],
     [[5, 4, 3, 2],
      [1, 2, 3, 4],
      [5, 6, 7, 8],
      [9, 8, 7, 6]]]
], dtype=np.int32)  # explicit dtype so the result is int32 on every platform

# Diagonal of the last two axes; the first two axes are treated as batch axes
print(tf.linalg.diag_part(input))
which outputs:
tf.Tensor(
[[[1 6 7 6]
  [5 2 7 4]]

 [[1 6 7 4]
  [5 2 7 6]]], shape=(2, 2, 4), dtype=int32)
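tf.linalg.diag_part always works on the last two axes. If the two axes you need are somewhere else, one option (a sketch, not a built-in; `diag_part_along` is a hypothetical helper) is to move them to the end with tf.transpose first, assuming TensorFlow 2.x and a statically known rank:

```python
import numpy as np
import tensorflow as tf

def diag_part_along(x, axis1, axis2):
    """Hypothetical helper: diagonal of x over (axis1, axis2), returned as the last axis.

    The remaining axes keep their original order, matching np.diagonal's layout.
    """
    x = tf.convert_to_tensor(x)
    rank = len(x.shape)  # requires a statically known rank
    axis1, axis2 = axis1 % rank, axis2 % rank
    if axis1 == axis2:
        raise ValueError("axis1 and axis2 must be different")
    others = [a for a in range(rank) if a not in (axis1, axis2)]
    # Move the two chosen axes to the end, then take the batched diagonal.
    moved = tf.transpose(x, perm=others + [axis1, axis2])
    return tf.linalg.diag_part(moved)

# Diagonal over axes 1 and 3 of a (2, 4, 3, 4) tensor -> shape (2, 3, 4)
x = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4)
result = diag_part_along(x, 1, 3)
print(result.shape)  # (2, 3, 4)
```

For the layout in the question (axes 2 and 3 already last), diag_part_along(input, 2, 3) reduces to a plain tf.linalg.diag_part(input) call; the transpose only matters when the axes are elsewhere.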