
How to get current global_step in data pipeline

I am trying to create a filter that depends on the current global_step of the training, but I am failing to do so properly.

First, I cannot use tf.train.get_or_create_global_step() in the code below, because it will throw

ValueError: Variable global_step already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:

This is why I tried fetching the scope with tf.get_default_graph().get_name_scope(), and within that context I was able to "get" the global step:

def filter_examples(example):
    scope = tf.get_default_graph().get_name_scope()

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        current_step = tf.train.get_or_create_global_step()

    subtokens_by_step = tf.floor(current_step / curriculum_step_update)
    max_subtokens = min_subtokens + curriculum_step_size * tf.cast(subtokens_by_step, dtype=tf.int32)

    return tf.size(example['targets']) <= max_subtokens


dataset = dataset.filter(filter_examples)

The problem with this is that it does not seem to work as I expected. From what I am observing, the current_step in the code above seems to be 0 all the time (I cannot verify that directly; I am assuming it based on my observations).

The only thing that seems to make a difference, odd as it sounds, is restarting the training. Also based on observations, in that case current_step will be the actual current step of the training at that point. But the value itself won't update as the training continues.

Is there a way to get the actual value of the current step and use it in my filter like above?


Environment

TensorFlow 1.12.1

As we discussed in the comments, keeping and updating your own counter might be an alternative to using the global_step variable. The counter variable could be updated as follows:

op = tf.assign_add(counter, 1)
with tf.control_dependencies([op]):  # control_dependencies expects a list of ops
    # Some operation here before which the counter should be updated
    ...

Using tf.control_dependencies allows you to "attach" the update of counter to a path within the computational graph. You can then use the counter variable wherever you need it.
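As a small sketch of this idea (written against `tf.compat.v1` so it runs in TF1-style graph mode; the variable name `counter` and the `step_value` read op are illustrative choices, not from the question):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode

# A non-trainable counter that plays the role of global_step
counter = tf.compat.v1.get_variable(
    "counter", shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False)

increment = tf.compat.v1.assign_add(counter, 1)

# Any read created inside this block runs only after the increment,
# so fetching `step_value` bumps the counter and returns the new value.
with tf.control_dependencies([increment]):
    step_value = counter.read_value()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    first = sess.run(step_value)
    second = sess.run(step_value)
    print(first, second)  # 1 2
```

In a real training loop you would attach the control dependency to the train op instead of a bare read, so the counter advances exactly once per training step.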

If you use variables inside datasets, you need to reinitialize the iterators in TF 1.x:

iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
init = iterator.initializer
tensors = iterator.get_next()

with tf.compat.v1.Session() as sess:
    for epoch in range(num_epochs):
        sess.run(init)  # re-initialize the iterator at the start of each epoch
        for example in range(num_examples):
            tensor_vals = sess.run(tensors)
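Putting the two ideas together, here is a minimal runnable sketch: a `step_counter` variable stands in for the training step, a toy `Dataset.range(5)` stands in for the real data, and the filter keeps elements no larger than the current counter value (these names and the toy predicate are assumptions for illustration, not the asker's curriculum filter):

```python
import tensorflow as tf

tf.compat.v1.disable_eager_execution()  # TF1-style graph mode

# Stand-in for the training step; bumped once per "epoch" below
step_counter = tf.compat.v1.get_variable(
    "step_counter", shape=[], dtype=tf.int64,
    initializer=tf.zeros_initializer(), trainable=False)

# Toy pipeline: keep elements no larger than the current counter value
dataset = tf.data.Dataset.range(5)
dataset = dataset.filter(lambda x: x <= step_counter)

iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
next_elem = iterator.get_next()
bump = tf.compat.v1.assign_add(step_counter, 1)

results = []
with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    for epoch in range(3):
        sess.run(iterator.initializer)  # filter sees the updated counter
        kept = []
        try:
            while True:
                kept.append(int(sess.run(next_elem)))
        except tf.errors.OutOfRangeError:
            pass
        results.append(kept)
        sess.run(bump)

print(results)
```

Each epoch the iterator is re-initialized after the counter was incremented, so the filter admits one more element per epoch, which is exactly the curriculum-style behavior the question is after.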
