Tensorflow: Iterating over dataset takes 1-2 orders of magnitude longer than expected

Situation:

I have a dataset (<class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>) which is a result of some neural network output. In order to get my final prediction, I am currently iterating over it as follows:

y_pred = []
for row in dataset:
    # each row is a pair of scalar tensors
    ap_distance, an_distance = row
    y_pred.append(int(ap_distance.numpy() > an_distance.numpy()))

The dataset has two columns, each holding a scalar wrapped in a tensor.
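
As a quick sanity check (a minimal sketch, assuming the dataset variable from the snippet above), the element structure can be inspected via element_spec; for two scalar columns it should show two scalar TensorSpec entries:

# Inspect the per-element structure of the dataset; for the case described
# above this should print something like
# (TensorSpec(shape=(), dtype=tf.float32, ...), TensorSpec(shape=(), dtype=tf.float32, ...))
print(dataset.element_spec)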

Problem:

The loop body is very simple; it takes < 1e-5 seconds to compute. Sadly, one iteration takes ca. 0.3 seconds, so fetching the next row from the dataset seems to take almost 0.3s. This behavior is very weird given my hardware (running on a rented GPU server with 16 AMD EPYC cores and 258GB RAM) and the fact that a colleague on his laptop can finish an iteration in 1-2 orders of magnitude less time than I can. The dataset has ca. 60k rows, hence it is unacceptable to wait for so long.
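
For reference, the two timings above can be separated roughly like this (a minimal sketch, assuming the dataset variable from the snippet above):

import time

# Time one fetch from the dataset separately from the loop body.
it = iter(dataset)

t0 = time.perf_counter()
ap_distance, an_distance = next(it)                     # fetch one row
t1 = time.perf_counter()
pred = int(ap_distance.numpy() > an_distance.numpy())   # loop body
t2 = time.perf_counter()

print(f"fetch: {t1 - t0:.4f}s, body: {t2 - t1:.4f}s")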

What I tried:

I tried mapping the above loop body onto every row of the dataset object, but sadly .numpy() is not available inside dataset.map()!
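
.numpy() is unavailable there because the function passed to dataset.map() is traced into a TensorFlow graph, where tensors are symbolic. A minimal sketch of a graph-compatible mapping (staying entirely in TF ops) would look like the following; note that iterating the result row by row would still be slow:

import tensorflow as tf

# Inside map() the function runs in graph mode, so the comparison has to be
# expressed with TensorFlow ops instead of .numpy() / Python ints.
pred_dataset = dataset.map(
    lambda ap_distance, an_distance: tf.cast(ap_distance > an_distance, tf.int32)
)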

Questions:

  1. Why does it take so long to get a new row from the dataset?
  2. How can I fix this performance degradation?

The solution to the performance degradation is to vectorize your operations, i.e.:

import tensorflow as tf

# Batch the whole dataset into a single element, then do the comparison once.
dataset = dataset.batch(len(dataset))
vec_ap_distance, vec_an_distance = next(iter(dataset))
y_pred = tf.cast(vec_ap_distance > vec_an_distance, tf.int32).numpy()

The major reason for the long running time in your code is that there are too many unnecessary numpy() calls.

If there are 60k rows, there will be 120k numpy() calls, while in my code there is only one numpy() call.

In summary:

  1. Vectorize as much as possible (a variation on the code above is sketched after this list).
  2. Avoid numpy() calls.
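
A variation on the code above, in case len(dataset) fails because the dataset's cardinality is unknown (the batch size below is an assumption, chosen only to exceed the ~60k rows):

import tensorflow as tf

# Same idea without len(dataset): batch with a size known to exceed the
# number of rows (~60k here), so a single batch still covers the whole dataset.
dataset = dataset.batch(100_000)
vec_ap_distance, vec_an_distance = next(iter(dataset))
y_pred = tf.cast(vec_ap_distance > vec_an_distance, tf.int32).numpy()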
