
How do I get the diagonal of a tensor of rank higher than 2 along selected axes in TensorFlow?

I have a tensor with shape tf.shape(input) = (Batch_Size, Channels, N, N). My goal is to compute an output that contains all diagonal elements along axes 2 and 3, so that tf.shape(output) = (Batch_Size, Channels, N).

There is the function tf.diag_part(input), but it doesn't let me select the axes I want to consider. How can I define a function that does this for me?

Could the following code work?

Batches=[]
for batch in input:
    diagonalpart=tf.diag_part(batch)
    Batches.append(diagonalpart)
output=tf.stack(Batches)

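tf.linalg.diag_part should do exactly what you want: it takes the diagonal of the two innermost dimensions and treats all leading dimensions as batch dimensions, e.g.: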

import tensorflow as tf
import numpy as np

# Input shape: (2, 2, 4, 4)
input = np.array([
                [ [[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6],
                   [9, 8, 7, 6]],
                  [[5, 4, 3, 2],
                   [1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [1, 2, 3, 4]] ], 
                [ [[1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6],
                   [1, 2, 3, 4]],
                  [[5, 4, 3, 2],
                   [1, 2, 3, 4],
                   [5, 6, 7, 8],
                   [9, 8, 7, 6]] ] 
                                ])

print(tf.linalg.diag_part(input))

This outputs:

tf.Tensor(
[[[1 6 7 6]
  [5 2 7 4]]

 [[1 6 7 4]
  [5 2 7 6]]], shape=(2, 2, 4), dtype=int32)
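
As a cross-check outside TensorFlow, the same "diagonal along two chosen axes" operation can be sketched in plain NumPy, where np.diagonal takes the axis pair explicitly. The helper name diag_along_axes and its default axes are made up for this illustration and are not part of either library; the equal-length check is also an assumption added for the square (N, N) case from the question.

```python
import numpy as np

def diag_along_axes(x, axis1=2, axis2=3):
    """Diagonals of the slices spanned by axis1 and axis2 (hypothetical helper).

    Following NumPy's convention, the diagonal ends up as the last axis
    of the result.
    """
    x = np.asarray(x)
    # Extra restriction for this sketch: only square slices, as in the question.
    if x.shape[axis1] != x.shape[axis2]:
        raise ValueError("the two selected axes must have the same length")
    return np.diagonal(x, axis1=axis1, axis2=axis2)

# Small example with shape (Batch_Size=2, Channels=3, N=4, N=4)
x = np.arange(2 * 3 * 4 * 4).reshape(2, 3, 4, 4)
out = diag_along_axes(x)            # shape (2, 3, 4)
same = np.einsum("bcii->bci", x)    # the same diagonal written as an einsum
print(out.shape)
print(np.array_equal(out, same))
```

If the two axes of interest are not the last two, one option on the TensorFlow side is to move them to the end with tf.transpose and then call tf.linalg.diag_part on the result.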

