
Replace nan values in tensorflow tensor

I'm working on a convolutional neural network in TensorFlow and have run into a problem: the input images I read from TFRecords contain a number of NaN values. The images are depth maps that contain some infinite values, and in the process of encoding them into the TFRecords and then decoding them to feed the network, these infinite values become NaN.

Since replacing the infinite values in the original images before encoding them into the TFRecords is not an option in my situation, is there any way to replace the NaN values in the image tensor as an operation applied before I feed it to the network?

A combination of tf.where and tf.is_nan should work:

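import tensorflow as tf
with tf.Session():  # TF 1.x graph mode: .eval() needs a default session
    has_nans = tf.constant([float('NaN'), 1.])
    # Select 0 where the element is NaN, otherwise keep the original element
    print(tf.where(tf.is_nan(has_nans), tf.zeros_like(has_nans), has_nans).eval())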

Prints (using TensorFlow 0.12.1):

[ 0.  1.]

If you are looking for a TensorFlow 2.0 solution, here is Allen Lavoie's code adapted:

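import tensorflow as tf
tf.compat.v1.disable_eager_execution()  # TF 2.x is eager by default; Session/.eval() need graph mode
with tf.compat.v1.Session():
    has_nans = tf.constant([float('NaN'), 1.])
    print(tf.where(tf.math.is_nan(has_nans), tf.zeros_like(has_nans), has_nans).eval())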
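In TensorFlow 2.x, eager execution is on by default, so a session is not needed at all. A minimal eager sketch of the same tf.where / tf.math.is_nan idea, assuming TF 2.x, reads the result with .numpy() instead of .eval():

```python
import tensorflow as tf

# Eager mode (TF 2.x default): no Session and no .eval() needed.
has_nans = tf.constant([float('nan'), 1.])

# Pick 0 wherever the element is NaN, otherwise keep the original element.
no_nans = tf.where(tf.math.is_nan(has_nans), tf.zeros_like(has_nans), has_nans)
print(no_nans.numpy())
```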

A much easier approach, compatible with TF 2.0, is to simply use tf.clip_by_value, which mirrors np.clip and removes NaNs (see here):

no_nans = tf.clip_by_value(has_nans, -1e12, 1e12)

Some caveats: (1) this also replaces infs with the clip bounds; (2) depending on your application, you may need to set the clip value very high to avoid losing information.
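Another answer on this question reports that tf.clip_by_value turned NaN into infinity, so a more explicit sketch (an alternative, not this answerer's code) handles NaNs with tf.where first and uses tf.clip_by_value only to bring the infinities into a finite range. The bound 1e12 is the arbitrary value from the snippet above, and the input tensor is made up for illustration:

```python
import tensorflow as tf

has_nans = tf.constant([float('nan'), float('inf'), -float('inf'), 3.])

# Step 1: replace NaNs explicitly instead of relying on clip_by_value for them.
no_nans = tf.where(tf.math.is_nan(has_nans), tf.zeros_like(has_nans), has_nans)

# Step 2: clip +inf / -inf (and any other very large values) to the chosen bounds.
finite = tf.clip_by_value(no_nans, -1e12, 1e12)
print(finite.numpy())
```

The two-step form makes the NaN replacement value explicit (0 here) rather than depending on how the clipping ops treat NaN.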

In my case, tf.clip_by_value turned NaN into infinity, and tf.where was overkill for a single variable. I used this to convert a single value to 0 if it is NaN:

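import tensorflow as tf
value = tf.constant(float('nan'))  # example input
value_not_nan = tf.dtypes.cast(tf.math.logical_not(tf.math.is_nan(value)), dtype=tf.float32)
tf.math.multiply_no_nan(value, value_not_nan)  # 0.0 if value is NaN, else value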
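The mask trick works element-wise, so it is not limited to a single value. A small sketch, assuming TF 2.x eager mode and a made-up input tensor; it relies on tf.math.multiply_no_nan(x, y) returning 0 wherever y is 0, even if x is NaN there:

```python
import tensorflow as tf

value = tf.constant([float('nan'), 2., -1.5])

# 1.0 where the element is a real number, 0.0 where it is NaN.
value_not_nan = tf.cast(tf.math.logical_not(tf.math.is_nan(value)), dtype=tf.float32)

# Where the mask is 0 the product is forced to 0, so the NaN does not propagate.
result = tf.math.multiply_no_nan(value, value_not_nan)
print(result.numpy())
```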

In TensorFlow 2.0 you can do it with tf.math.is_nan and tf.tensor_scatter_nd_update:

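import numpy as np
import tensorflow as tf

tensor_with_nan = tf.convert_to_tensor([[np.nan, 1.], [0., np.nan]])
new_value = 9.

# Coordinates of every NaN element, shape [num_nans, rank]
indices = tf.where(tf.math.is_nan(tensor_with_nan))
# Write new_value at each of those coordinates (one update per NaN)
tensor_without_nan = tf.tensor_scatter_nd_update(
    tensor_with_nan,
    indices,
    tf.fill(tf.shape(indices)[:1], new_value)
)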
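Since the question is about cleaning images before they reach the network, any of these replacements can run inside the input pipeline. A minimal sketch, assuming TF 2.x and tf.data; the in-memory `depth_maps` array and the `replace_nans` helper are hypothetical stand-ins for the tensors decoded from the TFRecords:

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for decoded depth maps: two 2x2 "images" containing NaNs.
depth_maps = np.array(
    [[[np.nan, 1.0], [2.0, 3.0]],
     [[4.0, np.nan], [np.nan, 5.0]]],
    dtype=np.float32,
)

def replace_nans(image, fill_value=0.0):
    """Hypothetical helper: swap every NaN in `image` for `fill_value`."""
    fill = tf.fill(tf.shape(image), tf.cast(fill_value, image.dtype))
    return tf.where(tf.math.is_nan(image), fill, image)

# In a real pipeline, this map would come right after the TFRecord parsing/decoding map.
dataset = tf.data.Dataset.from_tensor_slices(depth_maps).map(replace_nans)

cleaned = [image.numpy() for image in dataset]
print(cleaned[1])
```

Doing the replacement in `Dataset.map` keeps the stored TFRecords untouched, which matches the constraint in the question; for depth maps, a fill value other than 0 (for example a known maximum range) may be more meaningful.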
