繁体   English   中英

如何解决 TPU 推理的数据获取瓶颈?

[英]How to solve data fetch bottle neck for TPU inference?

这就是我的推理设置的样子

autotune = tf.data.experimental.AUTOTUNE

with strategy.scope():
    model = LoadModel()
    raw_dataset = tf.data.TFRecordDataset(tfRecordAddress)
    train_dataset = raw_dataset.map(_parse_example, num_parallel_calls=autotune)
    train_dataset = train_dataset.padded_batch(batch_size, padding_values=(1, 1, b'-'), padded_shapes=(512, 512, 1))
    # train_dataset = train_dataset.repeat()
    train_dataset = train_dataset.prefetch(autotune)
    train_dataset = strategy.experimental_distribute_dataset(train_dataset)

def per_core_inference_fn(inputIds,attnIds ):
    return model.inference((inputIds, attnIds))

@tf.function
def inference_fn(inputIds, attnIds):
    return strategy.run(per_core_inference_fn, args=(inputIds,attnIds))

results = []
for x in train_dataset:
    t0 = time.time()
    results.append(inference_fn(x[0], x[1]))
    t1 = time.time()
    print('time is :', t1-t0)

使用巨大的 batch_sizes,推理速度非常快,大约 0.0003 秒。 然而,下一批的获取需要很长时间, for x in train_dataset: ,大约 60-80 秒。

据我所知,我的推理是正确的,但不知何故,TPU 的 CPU 在批量检索方面遇到了巨大的瓶颈。

我在训练过程中没有看到这个瓶颈。 所以看起来model.fit正在做我没有做的事情。

我有一种感觉,这个瓶颈是由于for x in train_dataset而发生的。 批加载之间的这 60-80 秒对我来说意味着预取没有按预期工作。 在自定义训练循环 (CTL) 代码中,我通常会看到整个循环都包含在tf.function中,例如此处

你能类似地修改你的代码吗? 您还可以尝试捕获 TPU 配置文件 ( https://cloud.google.com/tpu/docs/cloud-tpu-tools#capture_profile ) 而不是使用time.time()进行基准测试。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM