
How to use tf.data.Dataset.padded_batch with a nested shape?

I am building a dataset in which each element consists of two tensors, of shapes [batch, width, height, 3] and [batch, class]. For simplicity, let's say class = 5.

What value should I pass as the shapes argument of dataset.padded_batch(1000, shapes) so that the image is padded along its width and height axes?

I have tried the following:

tf.TensorShape([[None,None,None,3],[None,5]])
[tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5])]
[[None,None,None,3],[None,5]]
([None,None,None,3],[None,5])
(tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))

Each of these raises a TypeError.

The docs state:

padded_shapes: A nested structure of tf.TensorShape or tf.int64 vector tensor-like objects representing the shape to which the respective component of each input element should be padded prior to batching. Any unknown dimensions (eg tf.Dimension(None) in a tf.TensorShape or -1 in a tensor-like object) will be padded to the maximum size of that dimension in each batch.
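To make the quoted semantics concrete, here is a plain-NumPy sketch of what "padded to the maximum size of that dimension in each batch" means when each element is an (image, label) tuple. pad_and_batch is a hypothetical helper written only for illustration; it is not how TensorFlow implements padded_batch, and it only handles flat tuples of components.

```python
import numpy as np


def pad_and_batch(elements, padded_shapes, padding_value=0):
    """Hypothetical illustration of padded_batch semantics (not TensorFlow code).

    elements: list of tuples, one NumPy array per component.
    padded_shapes: tuple with one shape per component; None or -1 means
        "pad to the largest size of this dimension in the batch".
    """
    batched = []
    for c, shape in enumerate(padded_shapes):
        arrays = [element[c] for element in elements]
        # Resolve unknown dimensions to the maximum size found in this batch.
        target = [
            max(a.shape[axis] for a in arrays) if dim in (None, -1) else dim
            for axis, dim in enumerate(shape)
        ]
        padded = []
        for a in arrays:
            if a.ndim != len(target) or any(s > t for s, t in zip(a.shape, target)):
                raise ValueError(f"cannot pad shape {a.shape} to {target}")
            # Pad only at the end of each axis.
            pad_width = [(0, t - s) for s, t in zip(a.shape, target)]
            padded.append(np.pad(a, pad_width, mode="constant",
                                 constant_values=padding_value))
        batched.append(np.stack(padded))
    return tuple(batched)


# Two (image, label) elements whose images differ in height and width.
elements = [
    (np.ones((4, 6, 3)), np.eye(5)[0]),
    (np.ones((5, 3, 3)), np.eye(5)[1]),
]
images, labels = pad_and_batch(elements, ([None, None, 3], [5]))
print(images.shape)  # (2, 5, 6, 3)
print(labels.shape)  # (2, 5)
```

The key point is that padded_shapes mirrors the element structure: one shape per component, not one merged shape.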

The relevant code:

dataset = tf.data.Dataset.from_generator(generator,tf.float32)
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)

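Thanks to mrry for finding the solution. It turns out that the output_types argument of from_generator has to match the structure of each element the generator yields: here a tuple with one dtype per tensor, (tf.float32, tf.float32), rather than a single tf.float32.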

New code:

import tensorflow as tf

# output_types now has one dtype per tensor in each (image, label) element
dataset = tf.data.Dataset.from_generator(generator,(tf.float32,tf.float32))
shapes = (tf.TensorShape([None,None,None,3]),tf.TensorShape([None,5]))
batch = dataset.padded_batch(1,shapes)
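For reference, here is a self-contained sketch of the fixed pipeline, assuming TensorFlow 2.x (eager execution) and a hypothetical generator that yields one unbatched (image, label) pair per element, with images of shape [height, width, 3]. Because each element here is a single image rather than a pre-batched tensor, the padded shapes have rank 3 and 1 instead of the rank 4 and 2 used above.

```python
import numpy as np
import tensorflow as tf


def generator():
    """Hypothetical generator: (image, label) pairs with differently sized images."""
    sizes = [(4, 6), (5, 3), (2, 2)]
    for i, (h, w) in enumerate(sizes):
        image = np.ones((h, w, 3), dtype=np.float32)
        label = np.eye(5, dtype=np.float32)[i]  # one-hot over 5 classes
        yield image, label


# output_types is a tuple with one dtype per tensor in each element,
# mirroring the (image, label) structure the generator yields.
dataset = tf.data.Dataset.from_generator(
    generator,
    output_types=(tf.float32, tf.float32),
    output_shapes=(tf.TensorShape([None, None, 3]), tf.TensorShape([5])),
)

# padded_shapes has the same (image, label) structure; None dimensions are
# padded (with zeros by default) to the largest size within each batch.
padded_shapes = (tf.TensorShape([None, None, 3]), tf.TensorShape([5]))
batched = dataset.padded_batch(3, padded_shapes=padded_shapes)

images, labels = next(iter(batched))
print(images.shape)  # (3, 5, 6, 3)
print(labels.shape)  # (3, 5)
```

The heights (4, 5, 2) and widths (6, 3, 2) are padded to the batch maxima of 5 and 6; the labels already share shape [5] and are stacked unchanged.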

tf.TensorShape doesn't accept nested lists. Flat forms such as tf.TensorShape([None, None, None, 3, None, 5]) and tf.TensorShape(None) (note: no brackets; this means a shape of unknown rank) are legal.
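A quick check of the forms described above, assuming TensorFlow 2.x (for the .rank property). The nested-list call is the one that raises TypeError; for (image, label) elements, a tuple of TensorShapes, one per component, is the structure to pass to padded_batch.

```python
import tensorflow as tf

# A flat list of dimensions is a valid TensorShape (rank 6 here).
flat = tf.TensorShape([None, None, None, 3, None, 5])
print(flat.rank)  # 6

# TensorShape(None), with no brackets, means "unknown rank".
unknown = tf.TensorShape(None)
print(unknown.rank)  # None

# A nested list is rejected: each entry must be an integer-like value or None.
try:
    tf.TensorShape([[None, None, None, 3], [None, 5]])
    nested_ok = True
except TypeError as err:
    nested_ok = False
    print("TypeError:", err)

# For tuple elements, use a tuple of TensorShapes, one per component.
padded_shapes = (tf.TensorShape([None, None, None, 3]), tf.TensorShape([None, 5]))
print(len(padded_shapes))  # 2
```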

Combining these two tensors seems odd to me, though. I'm not sure what you're trying to accomplish, but I'd recommend trying to do it without combining tensors of different ranks.
