Filtering “empty” values from a TensorFlow Dataset

I wrote this code to keep only the values from a Dataset that are <= 6.

import tensorflow as tf
import tensorflow.contrib.data as ds

def make_graph():
    inits = []
    filter_value = tf.constant([6], dtype=tf.int64)
    source = ds.Dataset.range(10)
    batched = source.batch(3)
    batched_iter = batched.make_initializable_iterator()
    batched_next = batched_iter.get_next()
    inits.append(batched_iter.initializer)
    predicate = tf.less_equal(batched_next, filter_value, name="less_than_filter")
    true_coordinates = tf.where(predicate)
    reshaped = tf.reshape(true_coordinates, [-1])
    # need to turn bools into 1 and 0 elsewhere
    found = tf.gather(params=batched_next, indices=reshaped)

    return found, inits  # final tensor first, then the iterator initializers

def run_graph(final_tensor, initializers, rounds):
    with tf.Session() as sess:
        init_ops = (tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_ops)
        summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
        while rounds > 0:
            for i in initializers:
                sess.run(i)
            try:
                while True:
                    final_result = sess.run(final_tensor)
                    print("Got result: {r}".format(r=final_result))
            except tf.errors.OutOfRangeError:
                print("Got out of range error")
            rounds -= 1

        summary_writer.flush()

def run():
    final_tensor, initializers = make_graph()
    run_graph(final_tensor=final_tensor,
              initializers=initializers,
              rounds=1)

if __name__ == "__main__":
    run()

However, the results are as follows:

Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error

Is there a way to filter out this empty tensor? I tried to brainstorm ways to do this, perhaps with a tf.while_loop, but I'm not sure whether I'm missing something or whether such an operation (i.e. an OpKernel "dropping" an input by producing no output based on its value) is simply not possible in TensorFlow.
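To see where the empty tensor comes from, here is a plain-Python model of the same pipeline. It is an illustration only, not TensorFlow code, and the helper names `batch`, `filter_within_batches` and `drop_empty` are hypothetical. The last raw batch is `[9]`, and nothing in it passes the `<= 6` test, so the in-batch filter leaves an empty list. Dropping empty batches afterwards is the step the question asks about; in `tf.data` the analogous step would presumably be to move the in-batch filtering into a `map` (for example with `tf.boolean_mask`) and then `filter` the batched dataset with a predicate such as `tf.size(batch) > 0`, though that has not been checked against the `contrib` version used above.

```python
# Plain-Python model (an assumption: this is NOT TensorFlow) of the question's
# pipeline: range(10) -> batch(3) -> keep values <= 6 inside each batch.

def batch(values, size):
    """Group values into consecutive lists of at most `size` elements,
    keeping a shorter final batch (like Dataset.batch without drop_remainder)."""
    return [values[i:i + size] for i in range(0, len(values), size)]


def filter_within_batches(batches, threshold):
    """Mimic tf.where + tf.gather: keep elements <= threshold in each batch."""
    return [[v for v in b if v <= threshold] for b in batches]


def drop_empty(batches):
    """Drop batches that ended up with no elements (hypothetical extra step)."""
    return [b for b in batches if len(b) > 0]


filtered = filter_within_batches(batch(list(range(10)), 3), 6)
print(filtered)              # [[0, 1, 2], [3, 4, 5], [6], []]
print(drop_empty(filtered))  # [[0, 1, 2], [3, 4, 5], [6]]
```

The printed lists mirror the four "Got result" lines above, including the trailing empty batch.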

Answer: keep only the values <= 6 BEFORE batching:

dataset = ds.Dataset.range(10)
dataset = dataset.filter(lambda v: v <= 6)  # keep only elements <= 6
dataset = dataset.batch(3)                  # then group the survivors into batches of 3
batched_iter = dataset.make_initializable_iterator()

This generates batches that contain only the data you want ([0 1 2], [3 4 5] and [6]). In general, it's better to filter out unwanted elements before building the batches: the filter then sees individual elements, so the iterator never produces an empty tensor.
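A plain-Python model (again an illustration, not TensorFlow; the helper names `batch`, `filter_then_batch`, `batch_then_filter` and `is_even` are hypothetical) shows why filtering first cannot produce an empty batch: batching only ever groups elements that already survived the predicate. It also shows one behavioral difference worth knowing: filtering first re-packs the survivors into full batches, whereas filtering inside batches keeps the original batch boundaries.

```python
# Plain-Python model (an assumption: this is NOT TensorFlow) comparing
# filter-then-batch (the answer) with batch-then-filter (the question).

def batch(values, size):
    """Group values into consecutive lists of at most `size` elements."""
    return [values[i:i + size] for i in range(0, len(values), size)]


def filter_then_batch(values, keep, size):
    """Apply the predicate element-wise first, then batch the survivors."""
    return batch([v for v in values if keep(v)], size)


def batch_then_filter(values, keep, size):
    """Batch first, then apply the predicate inside each batch."""
    return [[v for v in b if keep(v)] for b in batch(values, size)]


def is_even(v):
    """Hypothetical second predicate, used only to contrast the two orders."""
    return v % 2 == 0


values = list(range(10))
print(filter_then_batch(values, lambda v: v <= 6, 3))  # [[0, 1, 2], [3, 4, 5], [6]]
print(filter_then_batch(values, is_even, 3))           # [[0, 2, 4], [6, 8]]
print(batch_then_filter(values, is_even, 3))           # [[0, 2], [4], [6, 8], []]
```

With the `<= 6` predicate the two orders happen to yield the same non-empty batches; with `is_even` the batch-first order gives ragged batches plus an empty one, while the filter-first order gives dense batches.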
