繁体   English   中英

Tensorflow 批量调整尺寸

[英]Tensorflow batch resize dimension

批处理后我得到了这个维度的数据

tf.Tensor(
[[[  2 436 381 ... 416 333   3]]

 [[  2 651 374 ... 654 370   3]]

 [[  2 743 357 ... 771 358   3]]

 ...

 [[  2 594 432 ... 552 425   3]]

 [[  2 820 409 ... 886 438   3]]

 [[  2 734 397 ... 825 330   3]]], shape=(64, 1, 34), dtype=int64) tf.Tensor(
[[[  2 335 395 ... 281 405   3]]

 [[  2 542 379 ... 512 370   3]]

 [[  2 676 356 ... 696 354   3]]

 ...

 [[  2 733 411 ... 718 403   3]]

 [[  2 828 389 ... 883 407   3]]

 [[  2 774 376 ... 850 316   3]]], shape=(64, 1, 34), dtype=int64)

但是,我希望批次的形状像(64,34)。 我在批处理后尝试了重塑,但它不起作用。 这就是批处理的创建方式。

BATCH_SIZE = 64
def prepare(ds):
  src, trg = tf.split(ds, num_or_size_splits = 2, axis=1)
  return srcs, trgs

def make_batches(ds):
   return (
      ds
      .cache()
      .shuffle(BUFFER_SIZE)
      .batch(BATCH_SIZE,num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .map(prepare,num_parallel_calls=tf.data.experimental.AUTOTUNE)
      .prefetch(tf.data.experimental.AUTOTUNE))

train_batches = make_batches(train_examples)

将您的prepare方法更改为:

def prepare(x):
  srcs, trgs = tf.split(x, num_or_size_splits = 2, axis=1)
  return tf.squeeze(srcs, axis=1), tf.squeeze(trgs, axis=1)

你应该有你想要的 output。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM