簡體   English   中英

Tensorflow:如何在具有不同批次大小的估計量中使用RNN初始狀態進行訓練和測試?

[英]Tensorflow: how to use RNN initial state in an estimator with different batch size for training and testing?

我正在使用RNN(GRUCell)進行Tensorflow估計器。 我使用zero_state初始化第一個狀態,它需要固定的大小。 我的問題是我希望能夠使用估計量對單個樣本(batchsize = 1)進行預測。 當加載序列化的估計量時,它抱怨我用於預測的批次大小與訓練的批次大小不匹配。

如果我以不同的批處理大小重建估算器,則無法加載已序列化的估算器。

有沒有一種優雅的方法可以在估算器中使用zero_state? 我看到一些使用變量存儲批次大小的解決方案,但使用feed_dict方法。 我沒有找到如何使它在估算器的上下文中起作用的方法。

這是估算器中我的簡單測試RNN的核心:

cells = [  tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]


multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )

Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)

有人對此有線索嗎?

編輯:

好的,我對此問題嘗試了各種方法。 現在,我嘗試過濾從檢查點加載的變量以刪除“ H”,該“ H”用作循環單元的內部狀態。 為了進行預測,我可以保留所有0值。

到目前為止,我已經做到了:首先定義一個鈎子:

class RestoreHook(tf.train.SessionRunHook):
    def __init__(self, init_fn):
        self.init_fn = init_fn

    def after_create_session(self, session, coord=None):
        print("--------------->After create session.")
        self.init_fn(session)

然后在我的model_fn中:

if mode == tf.estimator.ModeKeys.PREDICT:
        logits = tf.nn.softmax(logits)

        # Do not restore H as it's batch size might be different.
        vlist = tf.contrib.framework.get_variables_to_restore()
        vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
        spec = tf.estimator.EstimatorSpec(mode=mode,
                                          predictions = {
                                              'logits': logits,
                                          },
                                          export_outputs={
                                              'prediction': tf.estimator.export.PredictOutput( logits )
                                          },
                                          prediction_hooks=[RestoreHook(init_fn)])

我從https://github.com/tensorflow/tensorflow/issues/14713獲取了這段代碼

但這還行不通。 看來它仍在嘗試從文件中加載H ...我檢查它不在vlist中。 我仍在尋找解決方案。

您可以從其他張量示例獲取批處理大小

decoder_initial_state = cell.zero_state(array_ops.shape(attention_states)[0], dtypes.float32).clone(cell_state=encoder_state)

我找到了解決方案:

  • 我為batchsize = 64和batchsize = 1的初始狀態創建變量。
  • 在訓練中,我使用第一個初始化RNN。
  • 在預測時間,我使用第二個。

它可以工作,因為這兩個變量都將通過估算器代碼進行序列化和恢復,因此不會出現問題。 缺點是在訓練時(創建兩個變量時)會知道查詢批處理大小(在我的情況下為1)。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM