簡體   English   中英

Tensorflow 的 Estimator 停止訓練

[英]Tensorflow's Estimator stops training

我正在使用 Tensorflow 的Estimator訓練模型,但在執行評估后 2600 步后它突然停止訓練。 是不是應該繼續訓練直到最后一個 epoch 結束?

def train():
    train_input_func = lambda: input_fn(mode='train')
    eval_input_func = lambda: input_fn(mode='eval')

    est_conf = tf.estimator.RunConfig(cfg.model_dir, save_checkpoints_secs=120)
    estimator = tf.estimator.Estimator(model_fn, cfg.model_dir, est_conf)


    Path(estimator.eval_dir()).mkdir(parents=True, exist_ok=True)
    train_spec = tf.estimator.TrainSpec(input_fn=train_input_func)
    eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_func, throttle_secs=120)
    tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

if __name__ == '__main__':
    train()

這是input_fn函數:

def input_fn(mode=None):
        data_generator = lambda: data_loader.data_generator(mode=mode)

        dataset = tf.data.Dataset.from_generator(data_generator,
                                                 output_types=(tf.int32, tf.int32),
                                                 output_shapes=([None], [None]))

        if mode is 'train':
            dataset.shuffle(cfg.shuffle_buffer).repeat(1000)

        dataset = dataset.padded_batch(cfg.batch_size, padded_shapes=([None],[None])).prefetch(1)

        return dataset

當使用tf.estimator.train_and_evaluate ,為了使max_steps工作,你不應該使用repeat(1000) ,請使用repeat() ,它會無限重復輸入,並且不會拋出OutOfRangeError

首先,您需要在 TrainSpec 定義中指定 max_stps,如下所示:

train_spec = tf.estimator.TrainSpec(input_fn=train_input_func, max_steps=num_steps_you_specify)

第二,當 input_fn 拋出“OutOfRangeError”時,訓練過程將停止,在這種情況下 max_step 將無法按照設計的方式工作。 因此,為了使訓練貫穿整個 epoch,您需要指定 input_fn,如下所示:

dataset = dataset.repeat()# don't specify any number in the repeat()

希望這會幫助你。

問題是我沒有分配dataset.shuffle(cfg.shuffle_buffer).repeat(1000) 這將解決問題:

dataset = dataset.shuffle(cfg.shuffle_buffer).repeat(1000)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM