[英]filtering “empty” values from Tensorflow
我編寫此代碼來過濾Dataset
<= 6的值。
import tensorflow as tf
import tensorflow.contrib.data as ds
def make_graph():
inits = []
filter_value = tf.constant([6], dtype=tf.int64)
source = ds.Dataset.range(10)
batched = source.batch(3)
batched_iter = batched.make_initializable_iterator()
batched_next = batched_iter.get_next()
inits.append(batched_iter.initializer)
predicate = tf.less_equal(batched_next, filter_value, name="less_than_filter")
true_coordinates = tf.where(predicate)
reshaped = tf.reshape(true_coordinates, [-1])
# need to turn bools into 1 and 0 elsewhere
found = tf.gather(params=batched_next, indices=reshaped)
return found, inits # prepend final tensor
def run_graph(final_tensor, initializers, rounds):
with tf.Session() as sess:
init_ops = (tf.local_variables_initializer(), tf.global_variables_initializer())
sess.run(init_ops)
summary_writer = tf.summary.FileWriter(graph=sess.graph, logdir=".")
while rounds > 0:
for i in initializers:
sess.run(i)
try:
while True:
final_result = sess.run(final_tensor)
p```pythrint("Got result: {r}".format(r=final_result))
except tf.errors.OutOfRangeError:
print("Got out of range error")
rounds -=1
summary_writer.flush()
def run():
final_tensor, initializers = make_graph()
run_graph(final_tensor=final_tensor,
initializers=initializers,
rounds=1)
if __name__ == "__main__":
run()
但是,結果如下:
Got result: [0 1 2]
Got result: [3 4 5]
Got result: [6]
Got result: []
Got out of range error
有沒有辦法過濾這個空的Tensor? 我嘗試集體討論這樣做的方法,也許是使用tf.while
循環,但是我不確定我是否遺漏了某些東西或者這樣的操作(即OpKernel通過不根據其值生成輸出而“丟棄”輸入)在Tensorflow中是不可能的。
在批處理之前僅保留<= 6的值 :
dataset = ds.Dataset.range(10)
dataset = dataset.filter( lambda v : v <= 6 )
dataset = dataset.batch(3)
batched_iter = dataset.make_initializable_iterator()
這將生成僅包含所需數據的批次。 請注意,通常最好在構建批次之前過濾掉不需要的數據。 這樣,迭代器不會生成空張量。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.