
Filtering “empty” values from TensorFlow

I wrote this code to keep only the values of a Dataset that are <= 6:

import tensorflow as tf
import tensorflow.contrib.data as ds

def make_graph():
    inits = []
    filter_value = tf.constant([6], dtype=tf.int64)
    source = ds.Dataset.range(10)
    batched = source.batch(3)
    batched_iter = batched.make_initializable_iterator()
    batched_next = batched_iter.get_next()
    inits.append(batched_iter.initializer)
    predicate = tf.less_equal(batched_next, filter_value, name="less_than_filter")
    true_coordinates = tf.where(predicate)
    reshaped = tf.reshape(true_coordinates, [-1])
    # need to turn bools into 1 and 0 elsewhere
    found = tf.gather(params=batched_next, indices=reshaped)

    return found, inits # prepend final tensor

def run_graph(final_tensor, initializers, rounds):
    with tf.Session() as sess:
        init_ops = (tf.local_variables_initializer(), tf.global_variables_initializer())
        sess.run(init_ops)
        summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
        while rounds > 0:
            for i in initializers:
                sess.run(i)
            try:
                while True:
                    final_result = sess.run(final_tensor)
                    print("Got result: {r}".format(r=final_result))
            except tf.errors.OutOfRangeError:
                print("Got out of range error")
            rounds -= 1

        summary_writer.flush()

def run():
    final_tensor, initializers = make_graph()
    run_graph(final_tensor=final_tensor,
              initializers=initializers,
              rounds=1)

if __name__ == "__main__":
    run()

However, the output is:

Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error
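The empty result comes from the order of operations. Dataset.range(10).batch(3) yields [0 1 2], [3 4 5], [6 7 8] and [9]. The where/gather step then filters each batch separately, so the third batch shrinks to [6] and the fourth batch loses its only element. The sketch below models this in plain Python without TensorFlow. The helper names batch and batch_then_filter are hypothetical and are not TensorFlow API.

```python
# Pure-Python model of the question's pipeline (no TensorFlow needed):
# batch first, then keep only the elements <= limit inside each batch.

def batch(values, size):
    """Split values into consecutive chunks of `size`; the last may be shorter."""
    return [values[i:i + size] for i in range(0, len(values), size)]

def batch_then_filter(values, size, limit):
    """Roughly mimics applying tf.where + tf.gather to each batch separately."""
    return [[v for v in b if v <= limit] for b in batch(values, size)]

result = batch_then_filter(list(range(10)), 3, 6)
print(result)  # [[0, 1, 2], [3, 4, 5], [6], []]
```

The batch [9] contains no value <= 6, so filtering it leaves an empty list. This is the model's counterpart of the empty tensor printed above.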

Is there a way to filter out this empty Tensor? I tried brainstorming ways to do this, maybe with a tf.while_loop, but I'm not sure whether I'm missing something. Alternatively, such an operation may simply be impossible in TensorFlow: an OpKernel "dropping" an input by producing no output, depending on the input's value.

Keep only the values <= 6 BEFORE batching:

dataset = ds.Dataset.range(10)
dataset = dataset.filter(lambda v: v <= 6)
dataset = dataset.batch(3)
batched_iter = dataset.make_initializable_iterator()

This produces batches that contain only the data you want. In general, it's better to filter out unwanted data before building the batches. The batching step then works on the already-filtered stream and only emits a batch when it has at least one element, so the iterator never produces empty tensors.
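The effect of filtering before batching can be checked with the same kind of pure-Python model. This is an illustration under the assumption that Dataset.filter keeps elements whose predicate is true and Dataset.batch groups consecutive elements, with a possibly shorter final batch. The helper names batch and filter_then_batch are hypothetical and are not TensorFlow API.

```python
# Pure-Python model of the answer's pipeline: filter the stream first, then batch.

def batch(values, size):
    """Split values into consecutive chunks of `size`; the last may be shorter."""
    return [values[i:i + size] for i in range(0, len(values), size)]

def filter_then_batch(values, size, limit):
    """Roughly mimics Dataset.filter(lambda v: v <= limit).batch(size)."""
    kept = [v for v in values if v <= limit]  # drop unwanted elements first
    return batch(kept, size)                  # chunks are built from kept values only

result = filter_then_batch(list(range(10)), 3, 6)
print(result)  # [[0, 1, 2], [3, 4, 5], [6]]
```

Batches are formed from the filtered stream, so every batch except possibly the last has exactly 3 elements. If nothing passes the filter, no batch is produced at all, rather than an empty one.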

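Notice: the technical posts on this site are licensed under CC BY-SA 4.0. If you republish them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.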

 
Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM