簡體   English   中英

具有線程的TensorFlow摘要

[英]TensorFlow summaries with threading

我試圖將摘要添加到異步運行的TensorFlow圖中。 我已經在單線程情況下完成了所有工作,但是一旦進入多線程,摘要似乎消失了。 這是我正在嘗試做的一個玩具示例

import tensorflow as tf  # 1.1.0
import threading


class Worker:
    def __init__(self):
        self.x = tf.Variable([1, -2, 3], tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary_str = SESS.run(tf.summary.merge_all())
            WRITER.add_summary(summary_str, i)
            WRITER.flush()

COORD = tf.train.Coordinator()
SESS = tf.Session()
WRITER = tf.summary.FileWriter(SUMMARY_DIR, SESS.graph)

# Single Thread  case
w = Worker()
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))
w.work()

這很好。 但是,如果我使用多線程:

# Multi-thread case
workers = [Worker() for i in range(4)]
SESS.run(tf.global_variables_initializer())
print(tf.get_collection(tf.GraphKeys.SUMMARIES))

worker_threads = []
for worker in workers:
    job = lambda: worker.work()
    t = threading.Thread(target=job)
    t.start()
    worker_threads.append(t)
COORD.join(worker_threads)

每當tf.summary.merge_all()都會出現這樣的錯誤,因為它看不到任何摘要:

Exception in thread Thread-2:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/home/anjum/PycharmProjects/junk.py", line 43, in <lambda>
    job = lambda: worker.work()
  File "/home/anjum/PycharmProjects/junk.py", line 22, in work
    summary_str = SESS.run(tf.summary.merge_all())
  File "/usr/local/lib/python3.5/dist-
packages/tensorflow/python/client/session.py", line 778, in run
    run_metadata_ptr)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 969, in _run
fetch_handler = _FetchHandler(self._graph, fetches, feed_dict_string)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 408, in __init__
self._fetch_mapper = _FetchMapper.for_fetch(fetches)
  File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 227, in for_fetch
(fetch, type(fetch)))
TypeError: Fetch argument None has invalid type <class 'NoneType'>

如果我將print(tf.get_collection(tf.GraphKeys.SUMMARIES))放入work() ,則返回一個空列表。 因此,這意味着我的摘要正迷失在某個地方。

有人可以解釋一下如何在多線程中正確使用摘要嗎?

我想我已經知道了-匯總必須改為這樣合並。 我不確定100%為什么TensorFlow如此挑剔

class Worker:
    def __init__(self):
        self.x = tf.Variable([1, -2, 3], tf.float32, name='x')
        self.y = tf.Variable([-1, 2, -3], tf.float32, name='y')
        self.dot_product = tf.reduce_sum(tf.multiply(self.x, self.y))
        tf.summary.scalar("Dot_Product", self.dot_product)

        self.summarise = tf.summary.merge_all()

    def work(self):
        for i in range(10):
            SESS.run(self.dot_product)

            # Write summary
            summary = SESS.run(self.summarise)
            WRITER.add_summary(summary, i)
            WRITER.flush()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM