
Understanding the implementation of ConvLSTM in tensorflow

In the tensorflow implementation of the ConvLSTM cell, the following lines of code appear:

    # convolve the (split) input with the kernel slice for each gate
    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    # convolve the previous hidden state h_{t-1} with the recurrent kernel slices
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    # input gate, forget gate, new cell state, output gate, new hidden state
    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)

The corresponding equations described in the paper are:

    i_t = σ(W_xi * X_t + W_hi * H_{t-1} + W_ci ∘ C_{t-1} + b_i)
    f_t = σ(W_xf * X_t + W_hf * H_{t-1} + W_cf ∘ C_{t-1} + b_f)
    C_t = f_t ∘ C_{t-1} + i_t ∘ tanh(W_xc * X_t + W_hc * H_{t-1} + b_c)
    o_t = σ(W_xo * X_t + W_ho * H_{t-1} + W_co ∘ C_t + b_o)
    H_t = o_t ∘ tanh(C_t)

(where * denotes convolution and ∘ the Hadamard product)

I cannot see where W_ci, W_cf, W_co and C_{t-1}, C_t are used in the input, forget and output gates. Where are they used in the computation of the four gates?

Of course you won't find them in the implementation of the ConvLSTM cell, because it does not use peephole connections:

Peephole connections allow the gates to utilize the previous internal state as well as the previous hidden state (which is what LSTMCell is limited to).
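
You can verify this with a minimal sketch (the filter count, kernel size and input shape below are arbitrary, chosen only for illustration): building a ConvLSTM2D layer and listing its variables shows only a kernel, a recurrent_kernel and a bias, with no peephole weights at all:

    import tensorflow as tf

    # Minimal sketch: inspect the variables of the layer that wraps ConvLSTM2DCell.
    layer = tf.keras.layers.ConvLSTM2D(filters=8, kernel_size=3, padding="same")
    x = tf.random.normal((2, 10, 32, 32, 1))   # (batch, time, rows, cols, channels)
    _ = layer(x)                               # builds the layer's variables

    for w in layer.weights:
        print(w.name, w.shape)
    # Prints only kernel, recurrent_kernel and bias, each stacked along the last
    # axis for the i, f, c, o gates (4 * filters); there is no W_ci / W_cf / W_co.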

tf.keras.experimental.PeepholeLSTMCell, on the other hand, does follow the equations you posted above, as you can see in its source code:

    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
    i = self.recurrent_activation(
        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
        self.input_gate_peephole_weights * c_tm1)
    f = self.recurrent_activation(
        x_f + K.dot(h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
        self.forget_gate_peephole_weights * c_tm1)
    c = f * c_tm1 + i * self.activation(
        x_c + K.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
    o = self.recurrent_activation(
        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
        self.output_gate_peephole_weights * c)
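
If you want to experiment with the peephole variant, a minimal usage sketch looks like this (assuming a TF 2.x version that still ships tf.keras.experimental.PeepholeLSTMCell; it was deprecated in later releases, and all sizes here are arbitrary):

    import tensorflow as tf

    # Minimal sketch, not production code: wrap the peephole cell in an RNN layer.
    cell = tf.keras.experimental.PeepholeLSTMCell(units=16)
    rnn = tf.keras.layers.RNN(cell)

    x = tf.random.normal((2, 5, 8))   # (batch, time, features)
    h = rnn(x)                        # final hidden state, shape (2, 16)

    # Besides kernel, recurrent_kernel and bias, the cell also holds the per-unit
    # peephole weights used in the gate equations above.
    print([w.name for w in cell.weights])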

Or, even more clearly, if you look at the tf.compat.v1.nn.rnn_cell.LSTMCell source code:

    # with peepholes, the input and forget gates also see the previous cell state c_prev
    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))
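
In the v1 API the peephole terms are switched on with the use_peepholes argument. A minimal sketch (graph mode, arbitrary sizes):

    import tensorflow as tf

    # Minimal sketch of the v1 cell with peepholes enabled.
    tf.compat.v1.disable_eager_execution()

    cell = tf.compat.v1.nn.rnn_cell.LSTMCell(num_units=16, use_peepholes=True)
    inputs = tf.compat.v1.placeholder(tf.float32, [None, 5, 8])  # (batch, time, features)
    outputs, state = tf.compat.v1.nn.dynamic_rnn(cell, inputs, dtype=tf.float32)
    # use_peepholes=True is what creates the _w_i_diag / _w_f_diag / _w_o_diag
    # variables used in the branch shown above.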
