[英]Understanding the implementation of ConvLSTM in tensorflow
在convLSTM 單元的 tensorflow 實現中,以下代碼行編寫為:
x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)
i = self.recurrent_activation(x_i + h_i)
f = self.recurrent_activation(x_f + h_f)
c = f * c_tm1 + i * self.activation(x_c + h_c)
o = self.recurrent_activation(x_o + h_o)
h = o * self.activation(c)
論文中描述的相應方程為:
我無法看到 W_ci、W_cf、W_co C_{t-1}、C_t 是如何在輸入、忘記和 output 門中使用的。 它在哪里用於計算 4 個門?
當然,你在 ConvLSTM 單元的實現中找不到那些,因為它沒有使用窺視孔:
窺孔連接允許門利用之前的內部 state 以及之前隱藏的 state(這是 LSTMCell 的限制)
tf.keras.experimental.PeepholeLSTMCell遵循您在上面發布的方程式,正如您在源代碼中看到的那樣:
x_i, x_f, x_c, x_o = x
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
i = self.recurrent_activation(
x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
self.input_gate_peephole_weights * c_tm1)
f = self.recurrent_activation(x_f + K.dot(
h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
self.forget_gate_peephole_weights * c_tm1)
c = f * c_tm1 + i * self.activation(x_c + K.dot(
h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
o = self.recurrent_activation(
x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
self.output_gate_peephole_weights * c)
或者更清楚,如果您查看tf.compat.v1.nn.rnn_cell.LSTMCell的源代碼:
if self._use_peepholes:
c = (
sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
else:
c = (
sigmoid(f + self._forget_bias) * c_prev +
sigmoid(i) * self._activation(j))
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.