
What are c_state and m_state in Tensorflow LSTM?

TensorFlow r0.12's documentation for tf.nn.rnn_cell.LSTMCell gives this signature for calling the cell:

tf.nn.rnn_cell.LSTMCell.__call__(inputs, state, scope=None)

where state is as follows:

state: if state_is_tuple is False, this must be a state Tensor, 2-D, batch x state_size. If state_is_tuple is True, this must be a tuple of state Tensors, both 2-D, with column sizes c_state and m_state.

What are c_state and m_state, and how do they fit into LSTMs? I cannot find any reference to them anywhere in the documentation.

Here is a link to that page in the documentation.

I agree that the documentation is unclear. Looking at the source of tf.nn.rnn_cell.LSTMCell.__call__ clarifies it (I took the code from TensorFlow 1.0.0):

def __call__(self, inputs, state, scope=None):
    """Run one step of LSTM.

    Args:
      inputs: input Tensor, 2D, batch x num_units.
      state: if `state_is_tuple` is False, this must be a state Tensor,
        `2-D, batch x state_size`.  If `state_is_tuple` is True, this must be a
        tuple of state Tensors, both `2-D`, with column sizes `c_state` and
        `m_state`.
      scope: VariableScope for the created subgraph; defaults to "lstm_cell".

    Returns:
      A tuple containing:

      - A `2-D, [batch x output_dim]`, Tensor representing the output of the
        LSTM after reading `inputs` when previous state was `state`.
        Here output_dim is:
           num_proj if num_proj was set,
           num_units otherwise.
      - Tensor(s) representing the new state of LSTM after reading `inputs` when
        the previous state was `state`.  Same type and shape(s) as `state`.

    Raises:
      ValueError: If input size cannot be inferred from inputs via
        static shape inference.
    """
    num_proj = self._num_units if self._num_proj is None else self._num_proj

    if self._state_is_tuple:
      (c_prev, m_prev) = state
    else:
      c_prev = array_ops.slice(state, [0, 0], [-1, self._num_units])
      m_prev = array_ops.slice(state, [0, self._num_units], [-1, num_proj])

    dtype = inputs.dtype
    input_size = inputs.get_shape().with_rank(2)[1]
    if input_size.value is None:
      raise ValueError("Could not infer input size from inputs.get_shape()[-1]")
    with vs.variable_scope(scope or "lstm_cell",
                           initializer=self._initializer) as unit_scope:
      if self._num_unit_shards is not None:
        unit_scope.set_partitioner(
            partitioned_variables.fixed_size_partitioner(
                self._num_unit_shards))
      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      lstm_matrix = _linear([inputs, m_prev], 4 * self._num_units, bias=True,
                            scope=scope)
      i, j, f, o = array_ops.split(
          value=lstm_matrix, num_or_size_splits=4, axis=1)

      # Diagonal connections
      if self._use_peepholes:
        with vs.variable_scope(unit_scope) as projection_scope:
          if self._num_unit_shards is not None:
            projection_scope.set_partitioner(None)
          w_f_diag = vs.get_variable(
              "w_f_diag", shape=[self._num_units], dtype=dtype)
          w_i_diag = vs.get_variable(
              "w_i_diag", shape=[self._num_units], dtype=dtype)
          w_o_diag = vs.get_variable(
              "w_o_diag", shape=[self._num_units], dtype=dtype)

      if self._use_peepholes:
        c = (sigmoid(f + self._forget_bias + w_f_diag * c_prev) * c_prev +
             sigmoid(i + w_i_diag * c_prev) * self._activation(j))
      else:
        c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
             self._activation(j))

      if self._cell_clip is not None:
        # pylint: disable=invalid-unary-operand-type
        c = clip_ops.clip_by_value(c, -self._cell_clip, self._cell_clip)
        # pylint: enable=invalid-unary-operand-type

      if self._use_peepholes:
        m = sigmoid(o + w_o_diag * c) * self._activation(c)
      else:
        m = sigmoid(o) * self._activation(c)

      if self._num_proj is not None:
        with vs.variable_scope("projection") as proj_scope:
          if self._num_proj_shards is not None:
            proj_scope.set_partitioner(
                partitioned_variables.fixed_size_partitioner(
                    self._num_proj_shards))
          m = _linear(m, self._num_proj, bias=False, scope=scope)

        if self._proj_clip is not None:
          # pylint: disable=invalid-unary-operand-type
          m = clip_ops.clip_by_value(m, -self._proj_clip, self._proj_clip)
          # pylint: enable=invalid-unary-operand-type

    new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
                 array_ops.concat([c, m], 1))
    return m, new_state

The key lines are:

c = (sigmoid(f + self._forget_bias) * c_prev + sigmoid(i) *
         self._activation(j))

and

m = sigmoid(o) * self._activation(c)

and

new_state = (LSTMStateTuple(c, m) if self._state_is_tuple else
             array_ops.concat([c, m], 1))

If you compare the code that computes c and m with the LSTM equations (see below), you can see that they correspond to the cell state (typically denoted c) and the hidden state (typically denoted h), respectively:

[Image: LSTM equations]
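Since the image is not reproduced here, the following is a reconstruction from the code above rather than the original figure. It assumes no peepholes, no projection and no clipping; W and b stand for the weights and bias inside _linear, b_forget for self._forget_bias, and g for self._activation (tanh by default). In the usual textbook notation, m_t is written h_t.

```latex
\begin{aligned}
(i_t,\; j_t,\; f_t,\; o_t) &= W\,[x_t;\, m_{t-1}] + b \\
c_t &= \sigma(f_t + b_{\text{forget}}) \odot c_{t-1} + \sigma(i_t) \odot g(j_t) \\
m_t &= \sigma(o_t) \odot g(c_t)
\end{aligned}
```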

The LSTMStateTuple(c, m) call shows that the first element of the returned state tuple is c (the cell state, aka c_state) and the second element is m (the hidden state, aka m_state).
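To make the roles of c and m concrete, here is a minimal pure-Python sketch of one LSTM step for a single example. It is not TensorFlow's implementation: it assumes no peepholes, projection or clipping, and the names lstm_step, weights and bias (as well as the row-per-output weight layout) are hypothetical choices for illustration.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def lstm_step(x, state, weights, bias, forget_bias=1.0):
    """One LSTM step for a single example (hypothetical helper).

    x:       list of input floats
    state:   (c_prev, m_prev), each a list of num_units floats
    weights: 4 * num_units rows, each of length len(x) + num_units
    bias:    4 * num_units floats
    """
    c_prev, m_prev = state
    n = len(c_prev)
    concat = list(x) + list(m_prev)
    # One affine map yields all four gate pre-activations at once.
    z = [sum(w * v for w, v in zip(row, concat)) + b
         for row, b in zip(weights, bias)]
    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = z[:n], z[n:2 * n], z[2 * n:3 * n], z[3 * n:]
    # New cell state: keep part of the old one, add gated new input.
    c = [sigmoid(f[k] + forget_bias) * c_prev[k]
         + sigmoid(i[k]) * math.tanh(j[k]) for k in range(n)]
    # New hidden state: gated, squashed view of the cell state.
    m = [sigmoid(o[k]) * math.tanh(c[k]) for k in range(n)]
    # The output is m itself; the state carries both c and m.
    return m, (c, m)


x = [1.0, 2.0, 3.0]
num_units = 2
# All-zero parameters, so every gate is sigmoid(0) = 0.5.
weights = [[0.0] * (len(x) + num_units) for _ in range(4 * num_units)]
bias = [0.0] * (4 * num_units)
state = ([1.0, -1.0], [0.0, 0.0])

output, (c, m) = lstm_step(x, state, weights, bias, forget_bias=0.0)
print(c)            # [0.5, -0.5]
print(output is m)  # True
```

Note that the returned output and the second element of the state are the same values, which is why m_state is usually called the hidden state or output.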

I stumbled upon the same question; here's how I understand it. A minimal LSTM example:

import tensorflow as tf

sample_input = tf.constant([[1,2,3]],dtype=tf.float32)

LSTM_CELL_SIZE = 2

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

init_op = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init_op)
print(sess.run(output))

Notice that state_is_tuple=True, so the state passed to this cell must be a tuple. c_state and m_state are probably "Cell State" and "Memory State", though I honestly am NOT sure, as these terms are only mentioned in the docs. In the code and in papers about LSTMs, the letters h and c are commonly used to denote the "output value" and the "cell state": http://colah.github.io/posts/2015-08-Understanding-LSTMs/

These two tensors together represent the internal state of the cell and should be passed together. The old way was to simply concatenate them; the new way is to use a tuple.

OLD WAY:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=False)
state = tf.zeros([1,LSTM_CELL_SIZE*2])

output, state_new = lstm_cell(sample_input, state)

NEW WAY:

lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(LSTM_CELL_SIZE, state_is_tuple=True)
state = (tf.zeros([1,LSTM_CELL_SIZE]),)*2

output, state_new = lstm_cell(sample_input, state)

So basically, all we did was change the state from one tensor of shape [1, 4] into two tensors of shape [1, 2]. The content remains the same: [0,0,0,0] becomes ([0,0],[0,0]). (This is supposed to make it faster.)
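The relationship between the two layouts can be sketched with plain Python lists. This assumes no projection, so c and m both have num_units columns; split_flat_state and concat_tuple_state are hypothetical helper names, not TensorFlow functions.

```python
def split_flat_state(flat_state, num_units):
    """Split each row [c | m] of a concatenated state into (c, m)."""
    c = [row[:num_units] for row in flat_state]
    m = [row[num_units:] for row in flat_state]
    return c, m


def concat_tuple_state(c, m):
    """Rebuild the old concatenated layout from a (c, m) pair."""
    return [c_row + m_row for c_row, m_row in zip(c, m)]


# Batch of 1, num_units = 2; distinct values show which half is which.
flat = [[1.0, 2.0, 3.0, 4.0]]
c, m = split_flat_state(flat, num_units=2)
print(c)  # [[1.0, 2.0]]
print(m)  # [[3.0, 4.0]]
print(concat_tuple_state(c, m) == flat)  # True
```

The first half of each row is the cell state and the second half is the hidden state, matching the order used by the slicing in the state_is_tuple=False branch.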

Maybe this excerpt from the code will help:

def __call__(self, inputs, state, scope=None):
  """Long short-term memory cell (LSTM)."""
  with vs.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
    # Parameters of gates are concatenated into one multiply for efficiency.
    if self._state_is_tuple:
      c, h = state
    else:
      c, h = array_ops.split(1, 2, state)
    concat = _linear([inputs, h], 4 * self._num_units, True)

    # i = input_gate, j = new_input, f = forget_gate, o = output_gate
    i, j, f, o = array_ops.split(1, 4, concat)

    new_c = (c * sigmoid(f + self._forget_bias) + sigmoid(i) *
             self._activation(j))
    new_h = self._activation(new_c) * sigmoid(o)

    if self._state_is_tuple:
      new_state = LSTMStateTuple(new_c, new_h)
    else:
      new_state = array_ops.concat(1, [new_c, new_h])
    return new_h, new_state

https://github.com/tensorflow/tensorflow/blob/r1.2/tensorflow/python/ops/rnn_cell_impl.py

Lines 308-314:

class LSTMStateTuple(_LSTMStateTuple):
  """Tuple used by LSTM Cells for `state_size`, `zero_state`, and output state.

  Stores two elements: `(c, h)`, in that order.

  Only used when `state_is_tuple=True`.
  """
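The (c, h) ordering can be illustrated with a plain namedtuple. This is a stand-in that assumes the same field names as the docstring above, not an import of TensorFlow's class, and the values are made up for illustration.

```python
import collections

# Assumption: a local stand-in with the field order (c, h) from the docstring.
LSTMStateTuple = collections.namedtuple("LSTMStateTuple", ("c", "h"))

state = LSTMStateTuple(c=[0.0, 0.0], h=[0.1, 0.2])

# Fields can be read by name or by position; c always comes first.
c, h = state
print(state.c is c)   # True
print(state[1] is h)  # True
print(state._fields)  # ('c', 'h')
```

So in a state tuple, index 0 is the cell state (c_state) and index 1 is the hidden state (h, called m_state in the LSTMCell docs).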

