繁体   English   中英

Tensorflow在python for循环中太慢了

[英]Tensorflow is too slow in a python for loop

我想在Tensorflow中创建一个函数,对于给定数据X的每一行,仅对某些采样类应用softmax函数,比如说,K个类中的2个,并返回一个矩阵S,其中S.shape = (N,K) (N:给定数据的行数,K为总类别)。

矩阵S最终将包含零,并且在采样类为每一行定义的索引中包含非零值。

在简单的python中我使用高级索引 ,但在Tensorflow中我无法弄清楚如何制作它。 我最初的问题是,我提出了numpy代码

所以我试图在Tensorflow中找到一个解决方案,主要思想不是将S用作二维矩阵而是用作一维阵列。 代码看起来像这样:

num_samps = 2
S = tf.Variable(tf.zeros(shape=(N*K)))
W = tf.Variable(tf.random_uniform((K,D)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps])
ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True)
updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,))
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    inds_new = sampled_ind + line*K
    sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]})

S = tf.reshape(S,shape=(N,K))

这是有效的,结果是预期的。 但它运行得非常慢 为什么会这样? 我怎么能更快地完成这项工作?

在张量流编程时,学习定义操作和执行它们之间的区别至关重要。 大多数函数以tf.开头tf. ,当你在python中运行添加操作到计算图

例如,当你这样做时:

tf.scatter_update(S,inds_new,updates)

以及:

inds_new = sampled_ind + line*K

多次,您的计算图形增长超出了必要的范围,填补了所有内存并大大减慢了速度。

你应该做的是在循环之前定义一次计算:

init = tf.initialize_all_variables()
inds_new = sampled_ind + line*K
update_op = tf.scatter_update(S, inds_new, updates)
sess = tf.Session()
sess.run(init)
for line in range(N):
    sess.run(update_op, feed_dict={tfx: X[line:line+1]})

这样,您的计算图只包含inds_newupdate_op一个副本。 请注意,当您执行update_opinds_new也将被隐式执行,因为它是计算图中的父级。

您还应该知道update_op每次运行时可能会有不同的结果,并且它很好并且预期。

顺便说一下,调试此类问题的一种好方法是使用张量板可视化计算图。 在代码中添加:

summary_writer = tf.train.SummaryWriter('some_logdir', sess.graph_def)

然后在控制台中运行:

tensorboard --logdir=some_logdir

在服务的html页面上会有一张计算图的图片,你可以在那里检查你的张量。

请记住,tf.scatter_update将返回Tensor S,这意味着会话运行中的大内存副本,甚至是分布式环境中的网络副本。 解决方案是基于@ sygi的答案:

update_op = tf.scatter_update(S, inds_new, updates)
update_op_op = update_op.op

然后在会话运行中,您执行此操作

sess.run(update_op_op)

这将避免复制大Tensor S.

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM