簡體   English   中英

Tensorflow在python for循環中太慢了

[英]Tensorflow is too slow in a python for loop

我想在Tensorflow中創建一個函數,對於給定數據X的每一行,僅對某些采樣類應用softmax函數,比如說,K個類中的2個,並返回一個矩陣S,其中S.shape = (N,K) (N:給定數據的行數,K為總類別)。

矩陣S最終將包含零,並且在采樣類為每一行定義的索引中包含非零值。

在簡單的python中我使用高級索引 ,但在Tensorflow中我無法弄清楚如何制作它。 我最初的問題是,我提出了numpy代碼

所以我試圖在Tensorflow中找到一個解決方案,主要思想不是將S用作二維矩陣而是用作一維陣列。 代碼看起來像這樣:

num_samps = 2
S = tf.Variable(tf.zeros(shape=(N*K)))
W = tf.Variable(tf.random_uniform((K,D)))
tfx = tf.placeholder(tf.float32,shape=(None,D))
sampled_ind = tf.random_uniform(dtype=tf.int32, minval=0, maxval=K-1, shape=[num_samps])
ar_to_sof = tf.matmul(tfx,tf.gather(W,sampled_ind),transpose_b=True)
updates = tf.reshape(tf.nn.softmax(ar_to_sof),shape=(num_samps,))
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
for line in range(N):
    inds_new = sampled_ind + line*K
    sess.run(tf.scatter_update(S,inds_new,updates), feed_dict={tfx: X[line:line+1]})

S = tf.reshape(S,shape=(N,K))

這是有效的,結果是預期的。 但它運行得非常慢 為什么會這樣? 我怎么能更快地完成這項工作?

在張量流編程時,學習定義操作和執行它們之間的區別至關重要。 大多數函數以tf.開頭tf. ,當你在python中運行添加操作到計算圖

例如,當你這樣做時:

tf.scatter_update(S,inds_new,updates)

以及:

inds_new = sampled_ind + line*K

多次,您的計算圖形增長超出了必要的范圍,填補了所有內存並大大減慢了速度。

你應該做的是在循環之前定義一次計算:

init = tf.initialize_all_variables()
inds_new = sampled_ind + line*K
update_op = tf.scatter_update(S, inds_new, updates)
sess = tf.Session()
sess.run(init)
for line in range(N):
    sess.run(update_op, feed_dict={tfx: X[line:line+1]})

這樣,您的計算圖只包含inds_newupdate_op一個副本。 請注意,當您執行update_opinds_new也將被隱式執行,因為它是計算圖中的父級。

您還應該知道update_op每次運行時可能會有不同的結果,並且它很好並且預期。

順便說一下,調試此類問題的一種好方法是使用張量板可視化計算圖。 在代碼中添加:

summary_writer = tf.train.SummaryWriter('some_logdir', sess.graph_def)

然后在控制台中運行:

tensorboard --logdir=some_logdir

在服務的html頁面上會有一張計算圖的圖片,你可以在那里檢查你的張量。

請記住,tf.scatter_update將返回Tensor S,這意味着會話運行中的大內存副本,甚至是分布式環境中的網絡副本。 解決方案是基於@ sygi的答案:

update_op = tf.scatter_update(S, inds_new, updates)
update_op_op = update_op.op

然后在會話運行中,您執行此操作

sess.run(update_op_op)

這將避免復制大Tensor S.

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM