
Graph Disconnected when implementing custom LSTM

I have been trying to write my own LSTM so that I can customize it. However, an error occurs when I try to call my code with Keras. The error says the graph is disconnected at c_prev, but c_prev is used as the cell initializer of the LSTM. So I'm not sure whether the problem is in my code or in the way I call the model. Any help is appreciated.

My environment:

  • Python 3.7.6
  • TensorFlow 2.1.0 (installed via pip)
  • macOS Mojave
class EtienneLSTM(tf.keras.layers.Layer):
    def __init__(self, units, activation='tanh', recurrent_activation='sigmoid',
    kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', 
    use_bias=True, unit_forget_bias=True, 
    kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None,
    kernel_constraint=None, recurrent_constraint=None, bias_constraint=None,
    # dropout=0.0, recurrent_dropout=0.0,
    return_sequences=False, return_state=False, go_backwards=False, use_batchnorm=False):
        super(EtienneLSTM, self).__init__()
        self.units = units #

        self.activation = tf.keras.layers.Activation(activation) #
        self.recurrent_activation = tf.keras.layers.Activation(recurrent_activation) #

        self.use_bias = use_bias #

        self.kernel_initializer = kernel_initializer #
        self.recurrent_initializer =  recurrent_initializer #
        self.bias_initializer = bias_initializer #
        self.unit_forget_bias = unit_forget_bias #
        if self.unit_forget_bias:
            self.bias_initializer = 'zeros'

        self.kernel_regularizer = kernel_regularizer #
        self.recurrent_regularizer = recurrent_regularizer #
        self.bias_regularizer = bias_regularizer #
        self.activity_regularizer = activity_regularizer

        self.kernel_constraint = kernel_constraint #
        self.recurrent_constraint = recurrent_constraint #
        self.bias_constraint = bias_constraint #

        # self.dropout = dropout
        # self.recurrent_dropout = recurrent_dropout

        self.return_sequences = return_sequences #
        self.return_state = return_state #
        self.go_backwards = go_backwards #

        self.use_batchnorm = use_batchnorm
        if self.use_batchnorm:
            self.batchnorm_f = tf.keras.layers.BatchNormalization()
            self.batchnorm_i = tf.keras.layers.BatchNormalization()
            self.batchnorm_o = tf.keras.layers.BatchNormalization()
            self.batchnorm_c = tf.keras.layers.BatchNormalization()

    def build(self, input_shape):
        # forget gate
        self.Wf = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uf = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.unit_forget_bias:
            self.bf = self.add_weight(shape=(self.units,), initializer='ones', regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)
        else:
            self.bf = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, trainable=True)
        # input gate
        self.Wi = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Ui = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bi = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

        # output gate
        self.Wo = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uo = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bo = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

        # context
        self.Wc = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uc = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bc = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

    def _inp_gate(self, x, hidden):
        return self.recurrent_activation(tf.matmul(x, self.Wi) + tf.matmul(hidden, self.Ui) + self.bi)

    def _new_mem(self, x, hidden):
        return self.activation(tf.matmul(x, self.Wc) + tf.matmul(hidden, self.Uc) + self.bc)

    def _forget_gate(self, x, hidden):
        return self.recurrent_activation(tf.matmul(x, self.Wf) + tf.matmul(hidden, self.Uf) + self.bf)

    def _update_cell(self, c_prev, c_tilde, f_t, i_t):
        return (f_t * c_prev) + (i_t * c_tilde)

    def _out_gate(self, x, hidden, ct):
        ot = self.recurrent_activation(tf.matmul(x, self.Wo) + tf.matmul(hidden, self.Uo) + self.bo)
        return ot * self.activation(ct)

    def call(self, x, hidden, c_prev):
        if self.go_backwards: x = x[:,:,::-1]

        f_t = self._forget_gate(x, hidden)
        i_t = self._inp_gate(x, hidden)
        c_tilde = self._new_mem(x, hidden)
        c_t = self._update_cell(c_prev, c_tilde, f_t, i_t)
        h_t = self._out_gate(x, hidden, c_t)

        # if self.return_state:
        #     return h_t, c_t
        # if self.return_sequences:
        #     return h_t
        return h_t
tf.keras.backend.clear_session()

def get_LSTM():
    inp = tf.keras.layers.Input(shape=(200, 40))
    out = tf.keras.layers.LSTM(32)(inp)
    return tf.keras.Model(inp, out)

def get_EtienneLSTM():
    inp = tf.keras.layers.Input(shape=(200, 40))
    h0 = tf.keras.layers.Input(shape=(32,), name='h0')
    c0 = tf.keras.layers.Input(shape=(32,), name='c0')
    out = EtienneLSTM(32)(inp, h0, c0)
    return tf.keras.Model(inp, out)

model_tf = get_LSTM()
model_etienne = get_EtienneLSTM()

Here is my error message:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
 in 
     14 
     15 model_tf = get_LSTM()
---> 16 model_etienne = get_EtienneLSTM()

 in get_EtienneLSTM()
     11     c0 = tf.keras.layers.Input(shape=(32,), name='c0')
     12     out = EtienneLSTM(32)(inp, h0, c0)
---> 13     return tf.keras.Model(inp, out)
     14 
     15 model_tf = get_LSTM()

~/.env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/training.py in __init__(self, *args, **kwargs)
    144 
    145   def __init__(self, *args, **kwargs):
--> 146     super(Model, self).__init__(*args, **kwargs)
    147     _keras_api_gauge.get_cell('model').set(True)
    148     # initializing _distribution_strategy here since it is possible to call

~/.env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in __init__(self, *args, **kwargs)
    167         'inputs' in kwargs and 'outputs' in kwargs):
    168       # Graph network
--> 169       self._init_graph_network(*args, **kwargs)
    170     else:
    171       # Subclassed network

~/.env/lib/python3.7/site-packages/tensorflow_core/python/training/tracking/base.py in _method_wrapper(self, *args, **kwargs)
    455     self._self_setattr_tracking = False  # pylint: disable=protected-access
    456     try:
--> 457       result = method(self, *args, **kwargs)
    458     finally:
    459       self._self_setattr_tracking = previous_value  # pylint: disable=protected-access

~/.env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _init_graph_network(self, inputs, outputs, name, **kwargs)
    322     # Keep track of the network's nodes and layers.
    323     nodes, nodes_by_depth, layers, _ = _map_graph_network(
--> 324         self.inputs, self.outputs)
    325     self._network_nodes = nodes
    326     self._nodes_by_depth = nodes_by_depth

~/.env/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _map_graph_network(inputs, outputs)
   1674                              'The following previous layers '
   1675                              'were accessed without issue: ' +
-> 1676                              str(layers_with_complete_input))
   1677         for x in nest.flatten(node.output_tensors):
   1678           computable_tensors.add(id(x))

ValueError: Graph disconnected: cannot obtain value for tensor Tensor("c0:0", shape=(None, 32), dtype=float32) at layer "c0". The following previous layers were accessed without issue: ['input_2']
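Reading the error again, my understanding (not verified) is that out depends on the h0 and c0 Input tensors, but the Model is built with only inp as its input, so Keras cannot trace c0 back to a model input. A minimal sketch of what I think would at least get past the disconnection, assuming I keep the explicit state inputs:

def get_EtienneLSTM():
    inp = tf.keras.layers.Input(shape=(200, 40))
    h0 = tf.keras.layers.Input(shape=(32,), name='h0')
    c0 = tf.keras.layers.Input(shape=(32,), name='c0')
    out = EtienneLSTM(32)(inp, h0, c0)
    # List every Input tensor that out depends on as a model input, not just inp,
    # so the functional graph stays connected.
    return tf.keras.Model([inp, h0, c0], out)

Even with that, I am not sure the layer itself is correct, which is part of what I am asking.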

Thanks for your help.

Solved: it turns out I was implementing the LSTM the wrong way. The correct way to implement the LSTM is as follows:

class EtienneLSTM(tf.keras.layers.Layer):
    def __init__(self, units, activation='tanh', recurrent_activation='sigmoid',
    kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', 
    use_bias=True, unit_forget_bias=True, 
    kernel_regularizer=None, recurrent_regularizer=None, bias_regularizer=None, activity_regularizer=None,
    kernel_constraint=None, recurrent_constraint=None, bias_constraint=None,
    # dropout=0.0, recurrent_dropout=0.0,
    return_sequences=False, return_state=False, go_backwards=False, use_batchnorm=False):
        super(EtienneLSTM, self).__init__()
        self.units = units #

        self.activation = tf.keras.layers.Activation(activation) #
        self.recurrent_activation = tf.keras.layers.Activation(recurrent_activation) #

        self.use_bias = use_bias #

        self.kernel_initializer = kernel_initializer #
        self.recurrent_initializer =  recurrent_initializer #
        self.bias_initializer = bias_initializer #
        self.unit_forget_bias = unit_forget_bias #
        if self.unit_forget_bias:
            self.bias_initializer = 'zeros'

        self.kernel_regularizer = kernel_regularizer #
        self.recurrent_regularizer = recurrent_regularizer #
        self.bias_regularizer = bias_regularizer #
        self.activity_regularizer = activity_regularizer

        self.kernel_constraint = kernel_constraint #
        self.recurrent_constraint = recurrent_constraint #
        self.bias_constraint = bias_constraint #

        # self.dropout = dropout
        # self.recurrent_dropout = recurrent_dropout

        self.return_sequences = return_sequences #
        self.return_state = return_state #
        self.go_backwards = go_backwards #

        self.use_batchnorm = use_batchnorm
        if self.use_batchnorm:
            self.batchnorm_f = tf.keras.layers.BatchNormalization()
            self.batchnorm_i = tf.keras.layers.BatchNormalization()
            self.batchnorm_o = tf.keras.layers.BatchNormalization()
            self.batchnorm_c = tf.keras.layers.BatchNormalization()

    def build(self, input_shape):
        # forget gate
        self.Wf = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uf = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.unit_forget_bias:
            self.bf = self.add_weight(shape=(self.units,), initializer='ones', regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)
        else:
            self.bf = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, trainable=True)
        # input gate
        self.Wi = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Ui = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bi = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

        # output gate
        self.Wo = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uo = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bo = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

        # context
        self.Wc = self.add_weight(shape=(input_shape[-1], self.units), initializer=self.kernel_initializer, regularizer=self.kernel_regularizer, constraint=self.kernel_constraint, trainable=True)
        self.Uc = self.add_weight(shape=(self.units, self.units), initializer=self.recurrent_initializer, regularizer=self.recurrent_regularizer, constraint=self.recurrent_constraint, trainable=True)
        if self.use_bias:
            self.bc = self.add_weight(shape=(self.units,), initializer=self.bias_initializer, regularizer=self.bias_regularizer, constraint=self.bias_constraint, trainable=True)

    def _inp_gate(self, x, hidden):
        return self.recurrent_activation(tf.matmul(x, self.Wi) + tf.matmul(hidden, self.Ui) + self.bi)

    def _new_mem(self, x, hidden):
        return self.activation(tf.matmul(x, self.Wc) + tf.matmul(hidden, self.Uc) + self.bc)

    def _forget_gate(self, x, hidden):
        return self.recurrent_activation(tf.matmul(x, self.Wf) + tf.matmul(hidden, self.Uf) + self.bf)

    def _update_cell(self, c_prev, c_tilde, f_t, i_t):
        return (f_t * c_prev) + (i_t * c_tilde)

    def _out_gate(self, x, hidden, ct):
        ot = self.recurrent_activation(tf.matmul(x, self.Wo) + tf.matmul(hidden, self.Uo) + self.bo)
        return ot * self.activation(ct)

    def step_function(self, x_t, states):
        h_t, c_t = states
        f_t = self._forget_gate(x_t, h_t)
        i_t = self._inp_gate(x_t, h_t)
        c_tilde = self._new_mem(x_t, h_t)
        c_t = self._update_cell(c_t, c_tilde, f_t, i_t)
        h_t = self._out_gate(x_t, h_t, c_t)
        return h_t, [h_t, c_t]

    def call(self, x):
        if self.go_backwards: x = x[:, ::-1, :]  # reverse along the time axis, not the feature axis

        # zero initial hidden and cell states, one row per sample in the batch
        h_init = tf.zeros((tf.shape(x)[0], self.units))
        c_init = tf.zeros((tf.shape(x)[0], self.units))
        # tf.keras.backend.rnn returns (last_output, all_outputs, new_states)
        h, H, states = tf.keras.backend.rnn(self.step_function, x, (h_init, c_init))

        if self.return_state:
            return h, states[1]  # final hidden state and final cell state
        if self.return_sequences:
            return H
        return h

This was in reference to this question.

You need to use tf.keras.backend.rnn.
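
For completeness, a small usage sketch of the fixed layer (the 200x40 input shape is just the one from my example above, and the random array is only there for a shape check):

import numpy as np
import tensorflow as tf

inp = tf.keras.layers.Input(shape=(200, 40))
out = EtienneLSTM(32)(inp)  # zero initial states are created inside call()
model = tf.keras.Model(inp, out)

x = np.random.rand(4, 200, 40).astype('float32')
print(model(x).shape)  # (4, 32), matching tf.keras.layers.LSTM(32)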
