tf.data.Dataset.padded_batch: pad each feature differently

I have a tf.data.Dataset instance that holds three different features:

  • label, which is a scalar
  • sequence_feature, which is a sequence of scalars
  • seq_of_seqs_feature, which is a sequence of sequences

I am trying to use tf.data.Dataset.padded_batch() to generate padded data as input to my model, and I want to pad each feature differently.

Example batch:

[{'label': 24,
  'sequence_feature': [1, 2],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66]]}]

Expected output:

[{'label': 24,
  'sequence_feature': [1, 2, 0],
  'seq_of_seqs_feature': [[11.1, 22.2],
                          [33.3, 44.4]]},
 {'label': 32,
  'sequence_feature': [3, 4, 5],
  'seq_of_seqs_feature': [[55.55, 66.66],
                          [0.0, 0.0]]}]

As you can see, the label feature should not be padded, while sequence_feature and seq_of_seqs_feature should each be padded, in every dimension, to the size of the longest corresponding entry in the given batch.
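
The padding rule can be made concrete with a plain-Python sketch (no TensorFlow involved) that reproduces the expected output above. The helper names shape_of, filled, pad_to and pad_batch are hypothetical, and the code only mimics the described behaviour; it is not how padded_batch is implemented. It assumes every feature value in the batch is non-empty and that all values of one feature have the same nesting depth.

```python
def shape_of(value):
    """Return the shape of a nested list; scalars have shape ()."""
    if not isinstance(value, list):
        return ()
    inner = [shape_of(v) for v in value]
    # Element-wise maximum over the inner shapes (assumes equal depth).
    inner_max = tuple(max(dims) for dims in zip(*inner))
    return (len(value),) + inner_max


def filled(shape, pad):
    """Build a fresh nested list of the given shape, filled with `pad`."""
    if not shape:
        return pad
    return [filled(shape[1:], pad) for _ in range(shape[0])]


def pad_to(value, shape, pad):
    """Right-pad `value` with `pad` in every dimension up to `shape`."""
    if not shape:  # Scalar: nothing to pad.
        return value
    padded = [pad_to(v, shape[1:], pad) for v in value]
    padded.extend(filled(shape[1:], pad) for _ in range(shape[0] - len(value)))
    return padded


def pad_batch(batch, pad_values):
    """Pad each feature to the longest entry in the batch, per dimension."""
    out = [dict(example) for example in batch]  # Leave the input untouched.
    for key in batch[0]:
        shapes = [shape_of(example[key]) for example in batch]
        target = tuple(max(dims) for dims in zip(*shapes))  # () for scalars
        pad = pad_values.get(key, 0)
        for example in out:
            example[key] = pad_to(example[key], target, pad)
    return out


batch = [
    {'label': 24,
     'sequence_feature': [1, 2],
     'seq_of_seqs_feature': [[11.1, 22.2], [33.3, 44.4]]},
    {'label': 32,
     'sequence_feature': [3, 4, 5],
     'seq_of_seqs_feature': [[55.55, 66.66]]},
]
padded = pad_batch(batch, {'sequence_feature': 0, 'seq_of_seqs_feature': 0.0})
print(padded[1]['seq_of_seqs_feature'])  # [[55.55, 66.66], [0.0, 0.0]]
```

Because the target shape is computed separately for each key, scalars such as label end up with an empty target shape and pass through unchanged, which is exactly the "do not pad" behaviour wanted for that feature.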

The tf.data.Dataset.padded_batch() method lets you specify padded_shapes for each component (feature) of the resulting batch. For example, if your input dataset is called ds:

padded_ds = ds.padded_batch(
    BATCH_SIZE,
    padded_shapes={
        'label': [],                          # Scalar elements, no padding.
        'sequence_feature': [None],           # Vector elements, padded to longest.
        'seq_of_seqs_feature': [None, None],  # Matrix elements, padded to longest in each dimension.
    })

Notice that the padded_shapes argument has the same structure as your input dataset's elements, so in this case it takes a dictionary with keys that match your feature names.
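
A self-contained version of the snippet above follows, written as a sketch that assumes TensorFlow 2.4 or later (needed for the output_signature argument of from_generator). It builds ds from the question's two example elements and also passes padding_values, which takes the same dictionary structure and lets you choose a different fill value per feature. The -1 for sequence_feature is an arbitrary illustrative choice. Each padding value should be a scalar of the same dtype as its feature, which is why the integer features are declared tf.int32 here: that is the dtype a plain Python int converts to.

```python
import tensorflow as tf

# The two elements from the question's example batch.
EXAMPLES = [
    {'label': 24,
     'sequence_feature': [1, 2],
     'seq_of_seqs_feature': [[11.1, 22.2], [33.3, 44.4]]},
    {'label': 32,
     'sequence_feature': [3, 4, 5],
     'seq_of_seqs_feature': [[55.55, 66.66]]},
]

# Elements have different lengths, so build the dataset from a generator
# and mark every variable-length dimension with None.
ds = tf.data.Dataset.from_generator(
    lambda: iter(EXAMPLES),
    output_signature={
        'label': tf.TensorSpec(shape=[], dtype=tf.int32),
        'sequence_feature': tf.TensorSpec(shape=[None], dtype=tf.int32),
        'seq_of_seqs_feature': tf.TensorSpec(shape=[None, None], dtype=tf.float32),
    })

BATCH_SIZE = 2
padded_ds = ds.padded_batch(
    BATCH_SIZE,
    padded_shapes={
        'label': [],                          # Scalar elements, no padding.
        'sequence_feature': [None],           # Padded to the longest vector.
        'seq_of_seqs_feature': [None, None],  # Padded in both dimensions.
    },
    # Same structure again: one scalar fill value per feature, matching that
    # feature's dtype. Scalars are never padded, so the label value is unused.
    padding_values={
        'label': 0,
        'sequence_feature': -1,
        'seq_of_seqs_feature': 0.0,
    })

for batch in padded_ds:
    print(batch['sequence_feature'].numpy())
    print(batch['seq_of_seqs_feature'].numpy())
```

With these settings the batch's sequence_feature becomes [[1, 2, -1], [3, 4, 5]], and the second element's seq_of_seqs_feature gains a row of 0.0 values. In recent TensorFlow 2.x releases padded_shapes is optional and defaults to padding every dimension to the longest entry in the batch, so for this dataset it could be left out while padding_values is still used to pick the per-feature fill values.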