[英]TensorFlow - shape does not match the shape stored in checkpoint
我是TensorFlow的新手,我正在嘗試對數據進行簡單神經網絡。 我有19列的.csv數據,最后一列是目標列。 它是0或1。
我從這里https://www.tensorflow.org/get_started/estimator開始,並嘗試進行修改以適合我的數據。 我制作了這個。
...
# Data sets
IRIS_TRAINING = "Training.csv"
IRIS_TEST = "Test.csv"
def main():
# Load datasets.
training_set = tf.contrib.learn.datasets.base.load_csv_without_header(
filename=IRIS_TRAINING,
target_dtype=np.int,
features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_without_header(
filename=IRIS_TEST,
target_dtype=np.int,
features_dtype=np.float32)
# Specify that all features have real-value data
feature_columns = [tf.feature_column.numeric_column("x", shape=[18])]
**# SOMETHING WRONG HERE**
classifier = tf.estimator.DNNClassifier(feature_columns=feature_columns,
hidden_units=[18],
n_classes=2,
model_dir="/tmp/iris_model")
# Define the training inputs
train_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(training_set.data)},
y=np.array(training_set.target),
num_epochs=None,
shuffle=True)
# Train model.
classifier.train(input_fn=train_input_fn, steps=2000)
# Define the test inputs
test_input_fn = tf.estimator.inputs.numpy_input_fn(
x={"x": np.array(test_set.data)},
y=np.array(test_set.target),
num_epochs=1,
shuffle=False)
# Evaluate accuracy.
accuracy_score = classifier.evaluate(input_fn=test_input_fn)["accuracy"]
print("\nTest Accuracy: {0:f}\n".format(accuracy_score))
if __name__ == "__main__":
main()
我剛剛將隱藏的單位更改為1層,並將形狀更改為18,因為我有18個要素。 但是,我收到此錯誤。
InvalidArgumentError (see above for traceback): tensor_name = dnn/hiddenlayer_0/bias/t_0/Adagrad; shape in shape_and_slice spec [18] does not match the shape stored in checkpoint: [10]
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
我相信您的問題model_dir="/tmp/iris_model"
在tf.estimator.DNNClassifier()
中的model_dir="/tmp/iris_model"
中。 實際上,這是在您第一次使用Tensorflow示例數據運行它時,加載並對其保存到該目錄的模型進行重新訓練。 只需取出model_dir="/tmp/iris_model"
部分,該錯誤就會消失。
來源: https : //www.tensorflow.org/api_docs/python/tf/estimator/DNNClassifier
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.