
How to filter dataset by tensor shape in Tensorflow

I have loaded a dataset with tfds.load and want to throw away certain images that interfere with proper training or are of no use to me (for example, images that are too small).

There seems to be no information on this specific problem anywhere, so I went with what looked like the best fit: .filter(predicate) on the dataset. Unfortunately, the input to the predicate has an indeterminate shape, (None, None, 3), and, as expected, the comparison raises an error saying that 'int' cannot be compared with 'NoneType'.

Is it even possible to solve this problem in TensorFlow, or should I not waste my time?

Pseudo code

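import tensorflow_datasets as tfds

ds_train = tfds.load('name', split='train')  # 'name' stands in for the actual dataset
ds_train = ds_train.map(lambda ds: ds['image'])
# Fails: inside filter() the static shape is (None, None, 3), so image.shape[0] is None
ds_train = ds_train.filter(lambda image: image.shape[0] >= 256)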

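When writing code with tf.data.Dataset, you should use tf.shape(tensor) rather than tensor.shape, because the functions you pass to Dataset transformations such as map() and filter() are traced and run in graph mode. In graph mode the static shape only records what is known at trace time, and here the height and width are unknown (None).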

Quoting the documentation of tf.shape:

tf.shape and Tensor.shape should be identical in eager mode. Within tf.function or within a compat.v1 context, not all dimensions may be known until execution time. Hence when defining custom layers and models for graph mode, prefer the dynamic tf.shape(x) over the static x.shape.
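To see the difference inside a tf.data pipeline, here is a minimal sketch. It assumes a hypothetical generator, make_images, in place of tfds.load, producing elements with the same (None, None, 3) shape. The Python list static_heights only records what image.shape[0] looked like while map() traced the function, whereas tf.shape(image)[0] is evaluated for every element at run time.

```python
import numpy as np
import tensorflow as tf

def make_images():
    # Hypothetical stand-in for the images from tfds.load: two RGB images.
    yield np.zeros((300, 200, 3), dtype=np.uint8)
    yield np.zeros((100, 400, 3), dtype=np.uint8)

ds = tf.data.Dataset.from_generator(
    make_images,
    output_signature=tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
)

static_heights = []  # filled by a Python side effect while map() traces the function

def record_height(image):
    static_heights.append(image.shape[0])  # static height: unknown (None) at trace time
    return tf.shape(image)[0]              # dynamic height: an int32 tensor per element

heights = [int(h) for h in ds.map(record_height)]
print(static_heights)  # only None entries
print(heights)         # [300, 100]
```

The same tracing happens for filter(), which is why comparing image.shape[0] with an int fails there.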

import tensorflow as tf
ds_train = ds_train.filter(lambda image: tf.shape(image)[0] >= 256)  # tf.shape(image)[0] is the runtime height
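Since the goal is to discard images that are too small, a sketch that checks both height and width may be useful. It again assumes a hypothetical generator instead of a real TFDS dataset, and MIN_SIZE = 256 simply mirrors the threshold in the question. filter() keeps an element only when the predicate returns a scalar boolean tensor that evaluates to True.

```python
import numpy as np
import tensorflow as tf

MIN_SIZE = 256  # threshold from the question, applied here to both height and width

def make_images():
    # Hypothetical stand-in for the images from tfds.load, with varied sizes.
    for height, width in [(300, 300), (100, 400), (512, 256), (256, 128)]:
        yield np.zeros((height, width, 3), dtype=np.uint8)

ds = tf.data.Dataset.from_generator(
    make_images,
    output_signature=tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
)

def is_large_enough(image):
    shape = tf.shape(image)  # dynamic shape [height, width, channels] as an int32 tensor
    return tf.logical_and(shape[0] >= MIN_SIZE, shape[1] >= MIN_SIZE)

kept = [image.numpy().shape[:2] for image in ds.filter(is_large_enough)]
print(kept)  # [(300, 300), (512, 256)]
```

With a real TFDS pipeline, is_large_enough can be passed to ds_train.filter(...) after the map that extracts the image, or applied to the feature dict directly with lambda ex: is_large_enough(ex['image']).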
