[英]Tensorflow RNN how to create zero state with various batch size?
[英]Tensorflow: how to use RNN initial state in an estimator with different batch size for training and testing?
我正在使用RNN(GRUCell)進行Tensorflow估計器。 我使用zero_state初始化第一個狀態,它需要固定的大小。 我的問題是我希望能夠使用估計量對單個樣本(batchsize = 1)進行預測。 當加載序列化的估計量時,它抱怨我用於預測的批次大小與訓練的批次大小不匹配。
如果我以不同的批處理大小重建估算器,則無法加載已序列化的估算器。
有沒有一種優雅的方法可以在估算器中使用zero_state? 我看到一些使用變量存儲批次大小的解決方案,但使用feed_dict方法。 我沒有找到如何使它在估算器的上下文中起作用的方法。
這是估算器中我的簡單測試RNN的核心:
cells = [ tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]
multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )
Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)
有人對此有線索嗎?
編輯:
好的,我對此問題嘗試了各種方法。 現在,我嘗試過濾從檢查點加載的變量以刪除“ H”,該“ H”用作循環單元的內部狀態。 為了進行預測,我可以保留所有0值。
到目前為止,我已經做到了:首先定義一個鈎子:
class RestoreHook(tf.train.SessionRunHook):
def __init__(self, init_fn):
self.init_fn = init_fn
def after_create_session(self, session, coord=None):
print("--------------->After create session.")
self.init_fn(session)
然后在我的model_fn中:
if mode == tf.estimator.ModeKeys.PREDICT:
logits = tf.nn.softmax(logits)
# Do not restore H as it's batch size might be different.
vlist = tf.contrib.framework.get_variables_to_restore()
vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
spec = tf.estimator.EstimatorSpec(mode=mode,
predictions = {
'logits': logits,
},
export_outputs={
'prediction': tf.estimator.export.PredictOutput( logits )
},
prediction_hooks=[RestoreHook(init_fn)])
我從https://github.com/tensorflow/tensorflow/issues/14713獲取了這段代碼
但這還行不通。 看來它仍在嘗試從文件中加載H ...我檢查它不在vlist中。 我仍在尋找解決方案。
您可以從其他張量示例獲取批處理大小
decoder_initial_state = cell.zero_state(array_ops.shape(attention_states)[0], dtypes.float32).clone(cell_state=encoder_state)
我找到了解決方案:
它可以工作,因為這兩個變量都將通過估算器代碼進行序列化和恢復,因此不會出現問題。 缺點是在訓練時(創建兩個變量時)會知道查詢批處理大小(在我的情況下為1)。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.