Tensorflow: Iterating over dataset takes 1-2 orders of magnitude longer than expected
Situation:
I have a dataset (`<class 'tensorflow.python.data.ops.dataset_ops.MapDataset'>`) that is the result of some neural network output. To get my final prediction, I am currently iterating over it as follows:
```python
y_pred = []
for row in dataset:
    ap_distance, an_distance = row
    y_pred.append(int(ap_distance.numpy() > an_distance.numpy()))
```
The dataset has two columns, each holding a scalar wrapped in a tensor.
Problem:
The loop body is very simple and takes < 1e-5 seconds to compute. Sadly, one iteration takes about 0.3 seconds, so fetching the next row from the dataset seems to take almost 0.3 s. This behavior is very odd given my hardware (a rented GPU server with 16 AMD EPYC cores and 258 GB RAM) and the fact that a colleague can finish an iteration on his laptop in 1-2 orders of magnitude less time than I can. The dataset has about 60k rows, so at 0.3 s per row a full pass takes roughly 5 hours (60,000 × 0.3 s = 18,000 s), which is unacceptable.
What I tried:
I tried mapping the above loop body onto every row of the dataset object, but sadly `.numpy()` is not available inside `dataset.map()`! (`Dataset.map` traces the function into a graph, where tensors are symbolic and have no concrete value to return.)
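Although `.numpy()` is not available inside `Dataset.map`, the comparison itself does not need it: it can be written with TensorFlow ops that work on symbolic tensors. The following is a minimal sketch. It assumes each element is an `(ap_distance, an_distance)` pair of scalar tensors. The `ap`/`an` values are made-up stand-ins for the network output.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in for the network output: pairs of scalar distances.
ap = tf.constant([0.2, 0.9, 0.5, 0.1])
an = tf.constant([0.4, 0.3, 0.5, 0.0])
dataset = tf.data.Dataset.from_tensor_slices((ap, an))

# The comparison runs as a graph op inside map(); no .numpy() call is needed.
pred_ds = dataset.map(lambda ap_d, an_d: tf.cast(ap_d > an_d, tf.int32))

# Pull the predictions out in large batches instead of one element at a time.
y_pred = np.concatenate([b.numpy() for b in pred_ds.batch(1024)])
print(y_pred)  # [0 1 0 1]
```

Because `map` receives a tuple element, TensorFlow unpacks it into the two lambda arguments. Only the final batched read converts the results to NumPy.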
Answer:
The solution to the performance degradation is to vectorize your operations, i.e.:
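```python
import tensorflow as tf

dataset = dataset.batch(len(dataset))  # one batch holding every row
vec_ap_distance, vec_an_distance = next(iter(dataset))  # two 1-D tensors
y_pred = tf.cast(vec_ap_distance > vec_an_distance, tf.int32).numpy()  # single .numpy() call
```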
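Batching with `len(dataset)` pulls the entire dataset into a single batch. That is fine for 60k scalar pairs, but it requires the dataset's cardinality to be known: `len()` raises a `TypeError` when the length is unknown or infinite. If either point is a concern, the same vectorized comparison can be applied to fixed-size batches and the results concatenated. The following is a sketch under those assumptions. The helper name `predict_in_batches`, the batch size of 4096, and the synthetic data are all arbitrary choices for illustration.

```python
import numpy as np
import tensorflow as tf

def predict_in_batches(dataset, batch_size=4096):
    """Hypothetical helper: vectorized ap > an comparison, batch by batch.

    Assumes each dataset element is an (ap_distance, an_distance) pair of scalars.
    """
    chunks = []
    for ap_batch, an_batch in dataset.batch(batch_size):
        # One comparison and one .numpy() call per batch, not per row.
        chunks.append(tf.cast(ap_batch > an_batch, tf.int32).numpy())
    return np.concatenate(chunks)

# Synthetic stand-in data (hypothetical): 10,000 random distance pairs.
rng = np.random.default_rng(0)
ap = rng.random(10_000).astype(np.float32)
an = rng.random(10_000).astype(np.float32)
dataset = tf.data.Dataset.from_tensor_slices((ap, an))

y_pred = predict_in_batches(dataset)
print(y_pred.shape)  # (10000,)
```

With these numbers the loop runs three times (4096 + 4096 + 1808 rows), so it makes three `.numpy()` calls instead of 20,000.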
The main reason for the long running time of your code is the large number of unnecessary `.numpy()` calls. With 60k rows, your loop makes 120k `.numpy()` calls (two per row), while the code above makes only one.
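The two versions can be compared on synthetic data by timing each one and checking that they agree. This is a rough sketch. The dataset size, the identity `map` (used only to produce a `MapDataset` like the one in the question), and `time.perf_counter` timing are assumptions for illustration. Absolute numbers will depend on hardware and TensorFlow version. The timing also bundles per-element iteration overhead with the `.numpy()` calls, so it compares the two approaches as a whole rather than isolating `.numpy()` alone.

```python
import time

import numpy as np
import tensorflow as tf

rng = np.random.default_rng(42)
n = 5_000  # hypothetical size; the question's dataset has ~60k rows
ap = rng.random(n).astype(np.float32)
an = rng.random(n).astype(np.float32)
# Identity map so the dataset is a MapDataset, as in the question.
dataset = tf.data.Dataset.from_tensor_slices((ap, an)).map(lambda a, b: (a, b))

# Row-by-row loop from the question: two .numpy() calls per row.
start = time.perf_counter()
y_loop = []
for ap_distance, an_distance in dataset:
    y_loop.append(int(ap_distance.numpy() > an_distance.numpy()))
loop_seconds = time.perf_counter() - start

# Vectorized version from the answer: a single .numpy() call.
start = time.perf_counter()
vec_ap, vec_an = next(iter(dataset.batch(len(dataset))))
y_vec = tf.cast(vec_ap > vec_an, tf.int32).numpy()
vec_seconds = time.perf_counter() - start

print(f"loop: {loop_seconds:.3f}s, vectorized: {vec_seconds:.3f}s")
```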
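Disclaimer: technical posts on this site are published under the CC BY-SA 4.0 license; if you repost, please credit this site's URL or the original source. For any questions, contact: yoyou2525@163.com.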