
tf.data.Dataset: How do I assign a shape to a dataset (with undefined shape) that is guaranteed to output a certain shape?

I have a TF2 tf.data dataset that goes through several map operations followed by tf.image.resize, which always outputs shape (300, 300); that is, every record is guaranteed to have this shape after all map operations. However, this shape is not inferred statically, so the TensorSpec shows an <undefined>, <undefined> shape. Datasets with undefined shapes throw an error when they are passed to a model with a predefined input shape.

Some searching led me to the function tf.contrib.data.assert_element_shape and Issue #16052:

dataset = dataset.apply(tf.data.experimental.assert_element_shape(custom_shape))

But this function has been removed in TF2, and the docs do not recommend anything to use in place of assert_element_shape. What is its equivalent? Or how do I assign a shape to a dataset that is guaranteed to output a certain shape?

For some reason, adding set_shape within the map function where I added tf.image.resize does NOT work.

# does not work
def my_map_function(image, label):
    # some image operations here
    image = tf.image.resize(image, size=[300, 300])
    image.set_shape((300, 300, 3))
    return image, label

But when I moved the set_shape calls into a separate map function, it worked:

# works
def set_shapes(image, label):
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label

Perhaps I'll stick to this until a direct assert_element_shape or set_element_shape gets added as a separate function.
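A minimal, self-contained sketch of the separate-map approach above. The make_record helper is hypothetical: it uses tf.py_function only to reproduce an element whose static shape is unknown, standing in for the real image pipeline. The second variant uses tf.ensure_shape, which also sets the static shape but additionally validates it at run time; treating it as a stand-in for the removed assert_element_shape is an assumption here, not something stated in the question.

```python
import tensorflow as tf

def make_record(i):
    # Hypothetical stand-in for the real pipeline: tf.py_function returns a
    # tensor with no static shape, so the dataset spec comes out undefined.
    image = tf.py_function(
        lambda _: tf.random.uniform([300, 300, 3]), inp=[i], Tout=tf.float32)
    label = i % 2
    return image, label

def set_shapes(image, label):
    # Only annotates the static shape; no run-time check is performed.
    image.set_shape((300, 300, 3))
    label.set_shape([])
    return image, label

def ensure_shapes(image, label):
    # Sets the static shape and fails at run time if the data disagrees.
    image = tf.ensure_shape(image, (300, 300, 3))
    label = tf.ensure_shape(label, [])
    return image, label

dataset = tf.data.Dataset.range(4).map(make_record)
shaped = dataset.map(set_shapes)
checked = dataset.map(ensure_shapes)

print(dataset.element_spec[0].shape)  # shape before any annotation
print(shaped.element_spec[0].shape)   # shape after set_shape
print(checked.element_spec[0].shape)  # shape after tf.ensure_shape
```

set_shape trusts you, so a wrong shape only surfaces later in the model; tf.ensure_shape costs a small check per element but reports a mismatch where it happens.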

The first snippet of code in the accepted answer will actually work if you put the set_shape line before the resize line.
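A sketch of this reordering, assuming the input image arrives with no static shape at all (as can happen after tf.py_function or some decode ops). tf.image.resize needs at least the rank of its input to be known, so declaring it beforehand lets the output shape be inferred from the constant size argument. load_image is a hypothetical stand-in for the real loading step, and the three-channel assumption is taken from the (300, 300, 3) shape used in the question.

```python
import tensorflow as tf

def load_image(i):
    # Hypothetical loader: tf.py_function yields a tensor of unknown rank.
    image = tf.py_function(
        lambda _: tf.random.uniform([64, 48, 3]), inp=[i], Tout=tf.float32)
    return image, i % 2

def my_map_function(image, label):
    # Declare rank and channel count first; height and width stay unknown.
    image.set_shape([None, None, 3])
    # With a known rank and a constant size, resize can infer the output shape.
    image = tf.image.resize(image, size=[300, 300])
    return image, label

resized = tf.data.Dataset.range(4).map(load_image).map(my_map_function)
print(resized.element_spec[0].shape)
```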
