
How do I visualize the histograms of the layers of an RNN in TensorBoard?

I have subclassed RNNCell to use as the building block of my RNN. I pass an instance of it to tf.nn.dynamic_rnn, and I define a prediction function in my Agent class:

class Agent():
    def __init__(self):
        ...

    def predictions(self):
        cell = RNNCell()
        output, last_state = tf.nn.dynamic_rnn(cell, inputs=...)
        return output

Everything works, but how do I now add histograms for the layers? I tried doing it inside the RNNCell, but it doesn't work:

class RNNCell(tf.nn.rnn_cell.RNNCell):
    def __init__(self):
        super(RNNCell, self).__init__()
        self._output_size = 15
        self._state_size = 15
        self._histogram1 = None

    def __call__(self, X, state):
        network = tflearn.layers.conv_2d(X, 5, [1, 3], activation='relu', weights_init=tflearn.initializations.variance_scaling(), padding="valid")
        self._histogram1 = tf.summary.histogram("layer1_hist_summary", network)
        ...

    @property
    def histogram1(self):
        return self._histogram1

and then

class Agent():
    def __init__(self):
        ...

    def predictions(self):
        cell = RNNCell()
        self.histogram1 = cell.histogram1
        output, last_state = tf.nn.dynamic_rnn(cell, inputs=...)
        return output

Later, when I run sess.run(agent.histogram1, feed_dict=...), I get this error:

TypeError: Fetch argument None has invalid type <class 'NoneType'>

I think the problem is that Agent's self.histogram1 never gets updated to reflect the summary that RNNCell assigns.

Your Agent.predictions() method sets Agent's histogram1 to None here:

cell = RNNCell()  # invokes __init__(), so RNNCell's histogram1 is now None
self.histogram1 = cell.histogram1  # copies that None; the cell has not been called yet

When RNNCell's __call__() method is later invoked by tf.nn.dynamic_rnn, it updates the cell's own value of histogram1:

self._histogram1 = tf.summary.histogram("layer1_hist_summary", network)

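But Agent's copy of histogram1 is not updated. The earlier assignment stored the value None, not a live link to the cell's attribute, and rebinding self._histogram1 inside the cell does not change it. Therefore, when this call is made: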

sess.run(agent.histogram1, feed_dict=...)

agent.histogram1 is still None.
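To see this binding behaviour without TensorFlow, here is a minimal pure-Python reproduction. FakeCell and FakeAgent are hypothetical stand-ins for the question's RNNCell and Agent, calling the cell directly stands in for tf.nn.dynamic_rnn invoking it, and the string only plays the role of the summary tensor.

```python
class FakeCell:
    """Hypothetical stand-in for the question's RNNCell (no TensorFlow needed)."""

    def __init__(self):
        self._histogram1 = None

    def __call__(self, x):
        # Mirrors the RNNCell: the "summary" only exists after the cell is called.
        self._histogram1 = f"histogram_of({x})"
        return x

    @property
    def histogram1(self):
        return self._histogram1


class FakeAgent:
    """Hypothetical stand-in for the question's Agent."""

    def predictions(self):
        cell = FakeCell()
        self.histogram1 = cell.histogram1  # copies the current value: None
        output = cell("inputs")            # stands in for tf.nn.dynamic_rnn(cell, ...)
        self.cell = cell                   # kept only so the demo can inspect it
        return output


agent = FakeAgent()
agent.predictions()
print(agent.histogram1)       # None
print(agent.cell.histogram1)  # histogram_of(inputs)
```

The cell's attribute was rebound to a new object, but the agent still holds the None it copied before the call.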

The posted code also doesn't show where the summaries are merged (for example with tf.summary.merge_all()) before training; presumably that step is in code that wasn't posted.
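A minimal sketch of one way to avoid the stale None, again using hypothetical TensorFlow-free stand-ins: keep a reference to the cell and read its attribute only after the cell has been called, here through a property that delegates to the cell. In the real code this corresponds to reading cell.histogram1 after tf.nn.dynamic_rnn(cell, ...) returns. One caveat worth verifying, stated as an assumption: tf.nn.dynamic_rnn runs the cell inside a tf.while_loop, so a summary op created in __call__ may still fail to fetch or merge from outside the loop. If that happens, a common workaround is to create the histogram after the call, on a tensor that tf.nn.dynamic_rnn returns.

```python
class FakeCell:
    """Hypothetical stand-in for the question's RNNCell (no TensorFlow needed)."""

    def __init__(self):
        self._histogram1 = None

    def __call__(self, x):
        # The "summary" only exists once the cell has actually been called.
        self._histogram1 = f"histogram_of({x})"
        return x

    @property
    def histogram1(self):
        return self._histogram1


class FixedAgent:
    """Hypothetical Agent that reads the cell's attribute lazily."""

    def predictions(self):
        self._cell = FakeCell()
        # Stands in for tf.nn.dynamic_rnn(self._cell, inputs=...), which calls the cell.
        output = self._cell("inputs")
        return output

    @property
    def histogram1(self):
        # Looked up on every access, so it reflects the value set inside __call__.
        return self._cell.histogram1


agent = FixedAgent()
agent.predictions()
print(agent.histogram1)  # histogram_of(inputs)
```

The simplest variant is to move self.histogram1 = cell.histogram1 below the tf.nn.dynamic_rnn(...) line; the property version avoids keeping a second copy that can go stale.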
