
Tensorflow is too slow in a python for loop

I would like to create in Tensorflow a function that, for every line of given data X, applies the softmax function only to some sampled classes, let's say 2 out of K total classes, and returns a matrix S, where S.shape = (N,K) (N: the number of lines of the given data, K: the total number of classes).

The matrix S would finally contain zeros, and non-zero values at the indexes defined for every line by the sampled classes.

In plain python I use advanced indexing, but in Tensorflow I cannot figure out how to do it. My initial question was this, where I present the numpy code.
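To make the intent concrete, here is a minimal numpy sketch of that idea (with made-up shapes and random weights, not the code from the linked question): advanced indexing writes the softmax over the sampled classes into just those columns of each row.

import numpy as np

N, K, D, num_samps = 4, 10, 3, 2                                # made-up sizes for illustration
X = np.random.randn(N, D)
W = np.random.randn(K, D)

S = np.zeros((N, K))
for n in range(N):
    cols = np.random.choice(K, size=num_samps, replace=False)  # sampled classes for this row
    logits = W[cols].dot(X[n])                                  # scores for the sampled classes only
    e = np.exp(logits - logits.max())
    S[n, cols] = e / e.sum()                                    # advanced indexing fills only these columns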

So I tried to find a solution in Tensorflow, and the main idea was not to use S as a 2-d matrix but as a 1-d array. The code looks like this:

import tensorflow as tf

# N, K, D and the data matrix X are assumed to be defined already
num_samps = 2
S = tf.Variable(tf.zeros(shape=(N*K)))                  # flat version of the (N, K) output matrix
W = tf.Variable(tf.random_uniform((K, D)))
tfx = tf.placeholder(tf.float32, shape=(None, D))
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps])
ar_to_sof = tf.matmul(tfx, tf.gather(W, sampled_ind), transpose_b=True)
updates = tf.reshape(tf.nn.softmax(ar_to_sof), shape=(num_samps,))
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    inds_new = sampled_ind + line*K                     # flat indices of the sampled classes in row `line`
    sess.run(tf.scatter_update(S, inds_new, updates), feed_dict={tfx: X[line:line+1]})

S = tf.reshape(S,shape=(N,K))

This is working, and the result is as expected. But it is running extremely slowly. Why is that happening? How could I make it work faster?

While programming in tensorflow, it is crucial to learn the distinction between defining operations and executing them. Most of the functions starting with tf., when you run them in python, only add operations to the computation graph.
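As a tiny illustration of this distinction (a self-contained sketch, independent of the code above):

import tensorflow as tf

a = tf.constant(1)
b = tf.constant(2)
c = tf.add(a, b)      # only adds an Add node to the graph; nothing is computed yet
print(c)              # prints a Tensor description, not 3

sess = tf.Session()
print(sess.run(c))    # executing the graph is what actually produces 3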

In particular, when you do:

tf.scatter_update(S,inds_new,updates)

as well as:

inds_new = sampled_ind + line*K

multiple times, your computation graph grows beyond what is necessary, filling all the memory and slowing things down enormously.
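One way to see this for yourself (a rough sketch, assuming the session and tensors from the question are in scope and the default graph is in use) is to count the graph's nodes before and after the loop:

graph = tf.get_default_graph()
print(len(graph.get_operations()))                       # node count before the loop

for line in range(N):
    inds_new = sampled_ind + line * K                    # creates new add/multiply nodes each iteration
    sess.run(tf.scatter_update(S, inds_new, updates),    # creates a new ScatterUpdate node each iteration
             feed_dict={tfx: X[line:line+1]})

print(len(graph.get_operations()))                       # much larger: the graph grew inside the loop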

What you should do instead is to define the computation once, before the loop:

line_ph = tf.placeholder(tf.int32, shape=())   # the row index `line` is not defined outside the loop,
                                               # so it is fed through a placeholder here
inds_new = sampled_ind + line_ph * K
update_op = tf.scatter_update(S, inds_new, updates)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    sess.run(update_op, feed_dict={tfx: X[line:line+1], line_ph: line})

This way your computation graph contains only one copy of inds_new and update_op. Note that when you execute update_op, inds_new will be implicitly executed too, as it is its parent in the computation graph.

You should also know that update_op will probably have different results each time it is run, and that is fine and expected.
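This is because sampled_ind comes from tf.random_uniform, so it is re-evaluated on every call to sess.run; a quick check (reusing the session defined above):

print(sess.run(sampled_ind))   # e.g. [3 7]
print(sess.run(sampled_ind))   # most likely different, e.g. [0 5]

Since update_op depends on sampled_ind, each run of it therefore writes to a freshly sampled set of indices.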

By the way, a good way to debug this kind of problem is to visualize the computation graph using tensorboard. In your code you add:

summary_writer = tf.train.SummaryWriter('some_logdir', sess.graph_def)

and then run in the console:

tensorboard --logdir=some_logdir

On the served html page there will be a picture of the computation graph, where you can examine your tensors.

Keep in mind that tf.scatter_update will return the Tensor S, which means a large memory copy during the session run, or even a network copy in a distributed environment. The solution, building on @sygi's answer, is:

update_op = tf.scatter_update(S, inds_new, updates)
update_op_op = update_op.op

Then in the session run, you do this:

sess.run(update_op_op)

This will avoid copying the large Tensor S.
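To make the difference concrete, here is a small sketch of what each fetch returns (standard Session.run semantics, using the placeholder-based code from the answer above):

# fetching the Tensor applies the update and copies the full updated value of S (N*K floats) back to python
updated_S = sess.run(update_op, feed_dict={tfx: X[0:1], line_ph: 0})
# fetching the Operation applies the same update but returns None, so no large copy is made
nothing = sess.run(update_op_op, feed_dict={tfx: X[0:1], line_ph: 0})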


 