[英]How expensive is to use sess.run() in tensorflow?
我使用tensorflow在python中編寫了機器學習算法。 下圖顯示了算法偽代碼。 在這種算法中,我在訓練循環中多次使用sess.run()。 我必須使用多個sess.run()的原因是,我必須在不同的輸入端評估相同的神經網絡以計算δ。 由於某些原因,我仍然不知道我的代碼非常慢(請參閱codereview , AI以查看代碼和相關問題)。
該圖取自Richard S. Sutton和Andrew G. Barto的《 強化學習入門 》一書。
我對此堆棧的疑問如下:
去做,
sess.run([op1],feed_dict={input:data})
sess.run([op2],feed_dict={input:data})
代替,
sess.run([op1,op2],feed_dict={input:data})
有什么區別嗎?
我目前正在計算δ,如下所示:
self.delta = self.time_step_info['r'] + (not self.time_step_info['d'])*self.gamma*sess.run(self.critic(),feed_dict={self.state_in:self.time_step_info['s1']}) - sess.run(self.critic(),feed_dict={self.state_in:self.time_step_info['s']})
對於您的第一個問題,我不確定。
但是對於第二個問題,您可能已經知道,輸入應該是一個矩陣。 矩陣可以包含多個X
然后NN將生成一個對應的結果矩陣Y
,該矩陣Y
每一行是X
中行的輸出。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.