繁体   English   中英

数据集的Tensorflow估计器问题

[英]Tensorflow estimator issue with datasets

我在TF估算器上遇到了一个奇怪的问题,试图在我的输入函数中使用tf.Dataset。

首先,我的模型如下所示:

    model = tf.estimator.DNNClassifier(
        feature_columns=my_feature_column,
        hidden_units=[hidden_layers, hidden_layers],
        n_classes=n_classes)

而我的特色专栏就是这样

    my_feature_column = [tf.feature_column.numeric_column(key='image', shape=[32, 32, 3])]

现在,如果我像这样进行训练,那么一切都可以正常进行,并且训练只需几秒钟即可完成:

    model.train(
        input_fn=tf.estimator.inputs.numpy_input_fn(
            dict({'image':X_train}),
            y_train,
            shuffle=True),
        steps=nb_epoch)

但是,当我尝试在输入函数中添加tf.Datasets时,它将永远需要运行:

def input_fn(features, labels, batch_size):
    dataset = tf.data.Dataset.from_tensor_slices(({'image':features}, labels))
    return dataset.shuffle(1000).batch(batch_size).repeat()

model.train(
    input_fn=lambda:input_fn(X_train, y_train, batch_size),
    steps=nb_epoch)

任何人都可以看到我在做什么错吗? 应该完全一样吧?

谢谢保罗

您的数据集将无限重复,并且没有默认的最大迭代次数,因此tensorflow不知道何时停止。

return dataset.shuffle(1000).batch(batch_size).repeat() return dataset.shuffle(1000).batch(batch_size).repeat(10) ,这将训练10个历元,并且你会没事的。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM