[英]using tf.nn.dynamic_rnn to make LSTM RNN of multiple hidden layers
[英]Tensorflow: What's the difference between the use of tf.mat_fn() or tf.nn.dynamic_rnn() to apply layers before an LSTM?
這個問題是關於使用Tensorflow的編碼策略。 我想創建一個小分類器網絡:
在tensorflow中,要使用類tf.nn.dynamic_rnn(),我們需要向網絡發送一批序列。 到目前為止,它的工作完美(我喜歡這個庫)。
但是因為我想在我的序列的每個特征上應用一個簡單的層(在我的描述中的第二層),我想知道:
要么
通常,它應該給我相同的結果? 如果是這樣的話,我該怎么用?
感謝您的時間 !
我最近遇到了類似的情況,我想鏈接循環和非循環圖層。
我是否在LSTM圖層之前加上這個簡單圖層並將它們傳遞給tf.nn.dynamic_rnn()操作...
這不行。 函數dynamic_rnn
期望一個單元格作為其第一個參數。 單元格是繼承自tf.nn.rnn_cell.RNNCell
的類。 此外, dynamic_rnn
的第二個輸入參數應該是具有至少3個維度的張量,其中前兩個維度是批次和時間( time_major=False
)或時間和批次( time_major=True
)。
我是否使用函數tf.map_fn()兩次(一個用於解包批處理,一個用於解包序列),如果理解得好,則能夠解壓縮序列並在每個要素線上應用圖層。
這可能有用,但在我看來並不是一個有效和干凈的解決方案。 首先,沒有必要“解包批次”,因為您可能希望對批次的特征和時間步驟執行某些操作,其中批次中的每個觀察都獨立於其他觀察。
我對這個特殊問題的解決方案是創建一個tf.nn.rnn_cell.RNNCell
的子類。 在我的例子中,我想要一個簡單的前饋層,它將迭代所有時間步驟,並且可以在dynamic_rnn
:
import tensorflow as tf
class FeedforwardCell(tf.nn.rnn_cell.RNNCell):
"""A stateless feedforward cell that can be used with MultiRNNCell
"""
def __init__(self, num_units, activation=tf.tanh, dtype=tf.float32):
self._num_units = num_units
self._activation = activation
# Store a dummy state to make dynamic_rnn happy.
self.dummy = tf.constant([[0.0]], dtype=dtype)
@property
def state_size(self):
return 1
@property
def output_size(self):
return self._num_units
def zero_state(self, batch_size, dtype):
return self.dummy
def __call__(self, inputs, state, scope=None):
"""Basic feedforward: output = activation(W * input)."""
with tf.variable_scope(scope or type(self).__name__): # "FeedforwardCell"
output = self._activation(tf.nn.rnn_cell._linear(
[inputs], self._num_units, True))
return output, self.dummy
可以在具有“正常”RNN單元的列表中將該類的實例傳遞給tf.nn.rnn_cell.MultiRNNCell
初始化器。 生成的對象實例可以作為cell
輸入參數傳遞給dynamic_rnn
。
需要注意的重要事項: dynamic_rnn
期望循環單元在調用時返回一個狀態。 因此,我在FeedforwardCell
使用dummy
作為偽狀態變量。
我的解決方案可能不是將復發和非復發層鏈接在一起的最流暢或最好的方法。 我有興趣聽取其他Tensorflow用戶的建議。
編輯如果選擇使用dynamic_rnn
的sequence_length
輸入參數,則state_size
應為self._num_units
, dummy
狀態應為shape [batch_size, self.state_size]
。 換句話說,國家不能成為標量。 請注意, bidirectional_dynamic_rnn
要求sequence_length
參數不是None
,而dynamic_rnn
則沒有此要求。 (這在TF文檔中很少記錄。)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.