
Slicing tensors in tensorflow using argmax

I want to make a dynamic loss function in TensorFlow. I want to calculate the energy of a signal's FFT, more specifically only of a window of size 3 around the most dominant peak. I am unable to implement this in TF: it throws a lot of errors about strides, such as InvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,64], [1,64], and [1] instead.

My code is this:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:,1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

Since I use a batch size of 64 and each sample also has 64 points, the shape of self.signal is (64,64). I only want the AC components of the FFT, and since the signal is real-valued, half the spectrum is enough. Hence, the shape of self.spec_mag is (64,32).
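The reason half the spectrum is enough is conjugate symmetry: for a real signal, X[N-k] = conj(X[k]). A quick NumPy check, assuming a random (64, 64) real batch (NumPy is only used as a reference here; it is not part of the TF graph):

```python
import numpy as np

rng = np.random.default_rng(0)
signal = rng.random((64, 64))          # batch of 64 real signals, 64 samples each

spec = np.fft.fft(signal, axis=1)      # full complex spectrum per row, shape (64, 64)
N = signal.shape[1]

# Conjugate symmetry of a real signal's DFT: X[N-k] == conj(X[k]) for k = 1..N-1
k = np.arange(1, N)
symmetric = np.allclose(spec[:, N - k], np.conj(spec[:, k]))

# Dropping the DC bin 0 and keeping bins 1..32 retains all AC information
spec_mag = np.abs(spec[:, 1:33])
print(symmetric, spec_mag.shape)       # True (64, 32)
```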

The maximum of each row of this FFT is located at self.argm, which has shape (64,): one index per row, since tf.argmax removes the reduced axis.

Now I want to calculate the energy of 3 elements around the max peak via: self.spec_mag[self.argm-1:self.argm+2] .
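For reference, here is what self.frac is meant to compute, sketched in NumPy, where advanced (integer-array) indexing can pull a different 3-wide window from every row in one step. It assumes no peak sits in the first or last column, and the function name window_sum is illustrative only:

```python
import numpy as np

def window_sum(spec_mag, half_width=1):
    """Sum of spec_mag[r, argm[r]-1 : argm[r]+2] for every row r.

    Assumes every peak is at least half_width columns away from both edges.
    """
    argm = np.argmax(spec_mag, axis=1)                 # shape (batch,), not (batch, 1)
    offsets = np.arange(-half_width, half_width + 1)   # [-1, 0, 1]
    rows = np.arange(spec_mag.shape[0])[:, None]       # shape (batch, 1)
    cols = argm[:, None] + offsets                     # shape (batch, 3)
    return spec_mag[rows, cols].sum(axis=1)            # shape (batch,)

spec_mag = np.array([[0., 1., 5., 2., 0.],
                     [3., 9., 4., 0., 1.]])
print(window_sum(spec_mag).tolist())   # [8.0, 16.0]
```

The per-row tensor indexing in rows and cols is exactly what TF's __getitem__ does not support, which is why the TF answers below take a different route.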

However, when I run the code and try to obtain the value of self.frac, I get multiple errors.

It seems you were missing an index when accessing argm. Here is a fixed version for an input of shape (1, 64):

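import tensorflow as tf  # TF 1.x graph mode (tf.fft, Tensor.eval)
import numpy as np

# .eval() below needs a default session; InteractiveSession installs one
sess = tf.InteractiveSession()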

x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:,1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())
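This version slices a single row, so it is worth checking the edges. With NumPy-style slicing (which tf.Tensor.__getitem__ is documented to mimic), a peak in column 0 gives the bounds -1:2, and -1 wraps to the last column, so the slice is empty and frac silently becomes 0 instead of raising an error. A NumPy illustration with made-up magnitudes:

```python
import numpy as np

row = np.array([7., 2., 1., 0., 3.])   # made-up magnitudes; the peak is in column 0
a = int(np.argmax(row))                # a == 0

window = row[a - 1:a + 2]              # bounds -1:2, and -1 wraps to the last index (4)
print(window.tolist(), float(window.sum()))   # [] 0.0
```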

And here is the expanded version for an input of shape (batch, m, n); like the first version, it only evaluates the first element, [0][0]:

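import tensorflow as tf  # TF 1.x graph mode (tf.fft, Tensor.eval)
import numpy as np

# .eval() below needs a default session; InteractiveSession installs one
sess = tf.InteractiveSession()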

x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

You may need to adjust some function names, since I wrote this code against a newer version of TensorFlow (for example, it uses tf.abs where your code uses tf.complex_abs).

TensorFlow indexing uses tf.Tensor.__getitem__, whose documentation says:

This operation extracts the specified region from the tensor. The notation is similar to NumPy with the restriction that currently only support basic indexing. That means that using a tensor as input is not currently allowed

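So a per-row slice driven by a tensor of indices is not possible this way, and tf.slice and tf.strided_slice do not help either, since they take a single begin and size for the whole tensor rather than one window per row.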

The tool for this is tf.gather_nd. Whereas in tf.gather the indices define slices into the first dimension of the tensor, in tf.gather_nd the indices define slices into the first N dimensions of the tensor, where N = indices.shape[-1].
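To see what the index tensor in this answer looks like, here is a NumPy emulation of the tf.gather_nd semantics just described. The helper gather_nd is hypothetical (it is not a NumPy function): it treats the last axis of indices as coordinates into the first N dimensions of params. The index construction mirrors tf.transpose(tf.stack([tf.range(64), argm+i])), scaled down to 3 rows:

```python
import numpy as np

def gather_nd(params, indices):
    """Hypothetical NumPy stand-in for tf.gather_nd: the last axis of
    `indices` holds coordinates into the first indices.shape[-1] dims of params."""
    indices = np.asarray(indices)
    return params[tuple(np.moveaxis(indices, -1, 0))]

spec_mag = np.array([[0., 4., 1., 2.],
                     [5., 3., 0., 1.],
                     [1., 2., 6., 7.]])
argm = np.argmax(spec_mag, axis=1)              # [1, 0, 3]

# Same construction as tf.transpose(tf.stack([tf.range(n), argm + i])) with i = 0:
# one (row, column) pair per row, shape (n, 2), so N = 2
idx = np.stack([np.arange(3), argm + 0]).T
print(idx.tolist())                             # [[0, 1], [1, 0], [2, 3]]
print(gather_nd(spec_mag, idx).tolist())        # [4.0, 5.0, 7.0]
```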

Since you wanted the 3 values around the max, I extract the previous, peak and next element separately using a list comprehension, followed by a tf.stack:

import tensorflow as tf

signal = tf.placeholder(shape=(64, 64), dtype=tf.complex64)
spec = tf.fft(signal)
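spec_mag = tf.abs(spec[:,1:33])                         # (64, 32): magnitudes of bins 1..32
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)  # (64,): peak column in each row

# For each offset i, build (row, column) pairs [[0, argm[0]+i], [1, argm[1]+i], ...]
# and gather one value per row. Gather from spec_mag, not spec: argm indexes the
# columns of spec_mag, which are shifted by one bin relative to spec.
frac = tf.stack([tf.gather_nd(spec_mag, tf.transpose(tf.stack(
             [tf.range(64), argm+i]))) for i in [-1, 0, 1]])  # (3, 64)

frac = tf.reduce_sum(frac, 0)  # add the 3 neighbours: one value per row, shape (64,)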

This will fail for the corner case where argm is the first or last element in the row, because argm-1 or argm+1 then falls outside the 32 columns of spec_mag, but it should be easy to resolve.
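One way to resolve it, sketched in NumPy under the assumption that a full 3-bin window is preferred over a truncated one: clamp the window start to [0, width - 3] so a peak at an edge uses the three outermost bins. In the TF graph the same idea would clamp with tf.clip_by_value(argm - 1, 0, 29) and then gather start+i for i in [0, 1, 2]. The function name clamped_window_sum is illustrative only:

```python
import numpy as np

def clamped_window_sum(spec_mag, width=3):
    """Sum a `width`-wide window per row, centred on the row's peak where
    possible and shifted inwards when the peak sits at the first or last column."""
    n_rows, n_cols = spec_mag.shape
    argm = np.argmax(spec_mag, axis=1)
    start = np.clip(argm - width // 2, 0, n_cols - width)   # keep the window in range
    cols = start[:, None] + np.arange(width)                # shape (n_rows, width)
    return spec_mag[np.arange(n_rows)[:, None], cols].sum(axis=1)

spec_mag = np.array([[9., 1., 2., 0., 0.],    # peak in the first column
                     [0., 1., 5., 2., 0.],    # peak in the middle
                     [0., 0., 1., 2., 8.]])   # peak in the last column
print(clamped_window_sum(spec_mag).tolist())  # [12.0, 8.0, 11.0]
```

An alternative design is to truncate the window at the edges (sum only the in-range neighbours); clamping keeps every row's sum over exactly three bins, which may matter if the loss compares rows.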
