
Batch tf.data.Dataset.filter()?

I know one can filter a dataset with tf.data.Dataset.filter():

import tensorflow as tf

d = tf.data.Dataset.from_tensor_slices([1, 2, 3])

d = d.filter(lambda x: x < 3)  # ==> [1, 2]

# `tf.math.equal(x, y)` is required for equality comparison
def filter_fn(x):
  return tf.math.equal(x, 1)

d = d.filter(filter_fn)  # ==> [1]

What if I want to do a "batch filter"? By this I mean: given a batch of strings ['str1', 'str2', 'str3', 'str4'], how do I build a dataset that returns the batch of values corresponding to those strings: [val1_respects_str1, val2_respects_str2, val3_respects_str3, val4_respects_str4]?

What you want is not a filter but a map. A filter only drops elements, whereas a map transforms each element (or each batch, if it is applied after batch()) into a new value. For example:

d = tf.data.Dataset.from_tensor_slices(list(range(100)))

def map_fn(x):
  return x * 2  # x is a whole batch here; for arbitrary Python logic, wrap it in tf.py_function

d = d.shuffle(100).batch(10).map(map_fn)
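For the string case in the question, one possible approach is to put the string-to-value mapping in a lookup table and call it inside map. The following is only a sketch under assumptions: the keys and the integer values (10, 20, 30, 40) are made up for illustration, since the question does not say where the values come from, and tf.lookup.StaticHashTable is just one way to hold such a mapping.

```python
import tensorflow as tf

# Hypothetical mapping; the real keys and values are not given in the question.
keys = tf.constant(["str1", "str2", "str3", "str4"])
values = tf.constant([10, 20, 30, 40], dtype=tf.int64)

# Static lookup table; keys that are not in the table map to -1.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(keys, values),
    default_value=-1,
)

d = tf.data.Dataset.from_tensor_slices(["str1", "str2", "str3", "str4"])

# Batch first, then map: the lookup runs once per batch of strings.
d = d.batch(4).map(lambda batch: table.lookup(batch))

result = [b.numpy().tolist() for b in d]
print(result)
```

Because the lookup is a TensorFlow op, it runs inside the dataset pipeline without tf.py_function. If the values can only be computed by Python code, wrapping that code in tf.py_function, as mentioned above, is the fallback.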


 