繁体   English   中英

Tensorflow:如何在具有不同批次大小的估计量中使用RNN初始状态进行训练和测试?

[英]Tensorflow: how to use RNN initial state in an estimator with different batch size for training and testing?

我正在使用RNN(GRUCell)进行Tensorflow估计器。 我使用zero_state初始化第一个状态,它需要固定的大小。 我的问题是我希望能够使用估计量对单个样本(batchsize = 1)进行预测。 当加载序列化的估计量时,它抱怨我用于预测的批次大小与训练的批次大小不匹配。

如果我以不同的批处理大小重建估算器,则无法加载已序列化的估算器。

有没有一种优雅的方法可以在估算器中使用zero_state? 我看到一些使用变量存储批次大小的解决方案,但使用feed_dict方法。 我没有找到如何使它在估算器的上下文中起作用的方法。

这是估算器中我的简单测试RNN的核心:

cells = [  tf.nn.rnn_cell.GRUCell(self.getNSize()) for _ in range(self.getNLayers())]


multicell = tf.nn.rnn_cell.MultiRNNCell(cells, state_is_tuple=False)
H_init = tf.Variable( multicell.zero_state( batchsize, dtype=tf.float32 ), trainable=False)
H = tf.Variable( H_init )

Yr, state = tf.nn.dynamic_rnn(multicell, Xo, dtype=tf.float32, initial_state=H)

有人对此有线索吗?

编辑:

好的,我对此问题尝试了各种方法。 现在,我尝试过滤从检查点加载的变量以删除“ H”,该“ H”用作循环单元的内部状态。 为了进行预测,我可以保留所有0值。

到目前为止,我已经做到了:首先定义一个钩子:

class RestoreHook(tf.train.SessionRunHook):
    def __init__(self, init_fn):
        self.init_fn = init_fn

    def after_create_session(self, session, coord=None):
        print("--------------->After create session.")
        self.init_fn(session)

然后在我的model_fn中:

if mode == tf.estimator.ModeKeys.PREDICT:
        logits = tf.nn.softmax(logits)

        # Do not restore H as it's batch size might be different.
        vlist = tf.contrib.framework.get_variables_to_restore()
        vlist = [ x for x in vlist if x.name.split(':')[0] != 'architecture/H']
        init_fn = tf.contrib.framework.assign_from_checkpoint_fn(tf.train.latest_checkpoint(self.modelDir), vlist, ignore_missing_vars=True)
        spec = tf.estimator.EstimatorSpec(mode=mode,
                                          predictions = {
                                              'logits': logits,
                                          },
                                          export_outputs={
                                              'prediction': tf.estimator.export.PredictOutput( logits )
                                          },
                                          prediction_hooks=[RestoreHook(init_fn)])

我从https://github.com/tensorflow/tensorflow/issues/14713获取了这段代码

但这还行不通。 看来它仍在尝试从文件中加载H ...我检查它不在vlist中。 我仍在寻找解决方案。

您可以从其他张量示例获取批处理大小

decoder_initial_state = cell.zero_state(array_ops.shape(attention_states)[0], dtypes.float32).clone(cell_state=encoder_state)

我找到了解决方案:

  • 我为batchsize = 64和batchsize = 1的初始状态创建变量。
  • 在训练中,我使用第一个初始化RNN。
  • 在预测时间,我使用第二个。

它可以工作,因为这两个变量都将通过估算器代码进行序列化和恢复,因此不会出现问题。 缺点是在训练时(创建两个变量时)会知道查询批处理大小(在我的情况下为1)。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM