Error using model.fit on MNIST deep neural network training TensorFlow 2.0
I am trying to train a deep neural network on the MNIST dataset. Here is some of the code from my Jupyter notebook.
The first block runs fine:
# Select the hyperparameter batch size
BATCH_SIZE = 100
# Batch the train, validation and test data
train_data = train_data.batch(BATCH_SIZE)
validation_data = validation_data.batch(num_validation_samples)
test_data = test_data.batch(num_test_samples)
# Transform the validation data into tuples for the inputs and targets
validation_inputs, validation_targets = next(iter(validation_data))
# Defining model hyperparameters
input_size = 784
output_size = 10
hidden_layer_size = 50
# Defining the model
model = tf.keras.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
tf.keras.layers.Dense(output_size, activation='softmax')
])
# Select the optimizer algorithm and the loss function
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
The last block is where I get the error:
# Defining hyperparameter number of epochs
NUM_EPOCHS = 5
# Training the model
model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs, validation_targets), verbose=2)
Epoch 1/5
---------------------------------------------------------------------------
DataLossError Traceback (most recent call last)
in
2 NUM_EPOCHS = 5
3 # Training the model
----> 4 model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs, validation_targets), validation_steps=1, verbose=2)
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
64 def _method_wrapper(self, *args, **kwargs):
65 if not self._in_multi_worker_mode(): # pylint: disable=protected-access
---> 66 return method(self, *args, **kwargs)
67
68 # Running inside `run_distribute_coordinator` already.
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
846 batch_size=batch_size):
847 callbacks.on_train_batch_begin(step)
--> 848 tmp_logs = train_function(iterator)
849 # Catch OutOfRangeError for Datasets of unknown size.
850 # This blocks until the batch has finished executing.
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
578 xla_context.Exit()
579 else:
--> 580 result = self._call(*args, **kwds)
581
582 if tracing_count == self._get_tracing_count():
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
609 # In this case we have created variables on the first call, so we run the
610 # defunned version which is guaranteed to never create variables.
--> 611 return self._stateless_fn(*args, **kwds) # pylint: disable=not-callable
612 elif self._stateful_fn is not None:
613 # Release the lock early so that multiple threads can perform the call
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
2418 with self._lock:
2419 graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420 return graph_function._filtered_call(args, kwargs) # pylint: disable=protected-access
2421
2422 @property
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs)
1663 if isinstance(t, (ops.Tensor,
1664 resource_variable_ops.BaseResourceVariable))),
-> 1665 self.captured_inputs)
1666
1667 def _call_flat(self, args, captured_inputs, cancellation_manager=None):
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
1744 # No tape is watching; skip to running the function.
1745 return self._build_call_outputs(self._inference_function.call(
-> 1746 ctx, args, cancellation_manager=cancellation_manager))
1747 forward_backward = self._select_forward_and_backward_functions(
1748 args,
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
596 inputs=args,
597 attrs=attrs,
--> 598 ctx=ctx)
599 else:
600 outputs = execute.execute_with_cancellation(
~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
58 ctx.ensure_initialized()
59 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60 inputs, attrs, num_outputs)
61 except core._NotOkStatusException as e:
62 if name is not None:
DataLossError: truncated record at 15713022
[[node IteratorGetNext (defined at :4) ]] [Op:__inference_train_function_957]
Function call stack:
train_function
I don't know where the error actually is. I have already tried changing some of the arguments of the model.fit method, such as adding the validation_steps argument, but it didn't help. Thanks in advance for your help.
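The `DataLossError: truncated record at 15713022` above is raised while reading the dataset, not by the model itself: it means a TFRecord file on disk ends in the middle of a record, typically because the download was interrupted. As a stdlib-only sketch (no TensorFlow required), the check below walks a TFRecord file's framing, which per the TensorFlow file-format documentation is a little-endian uint64 payload length, a 4-byte masked CRC of that length, the payload, then a 4-byte masked CRC of the payload, and reports the offset of the first incomplete record. CRC values are not verified here (crc32c is not in the standard library); only byte completeness is checked.

```python
import struct

def scan_tfrecord(path):
    """Return the byte offset of the first truncated record in a
    TFRecord file, or None if every record's bytes are fully present.

    TFRecord framing: uint64 length (little-endian), uint32 masked
    crc32c of the length, `length` payload bytes, uint32 masked
    crc32c of the payload. CRCs are not checked here.
    """
    with open(path, "rb") as f:
        while True:
            offset = f.tell()
            header = f.read(12)           # 8-byte length + 4-byte length CRC
            if not header:
                return None               # clean end of file
            if len(header) < 12:
                return offset             # file ends inside a header
            (length,) = struct.unpack("<Q", header[:8])
            body = f.read(length + 4)     # payload + 4-byte payload CRC
            if len(body) < length + 4:
                return offset             # file ends inside a record
```

Running this over the cached MNIST TFRecord files (their exact paths depend on where tensorflow-datasets stored them) would show which file is damaged; re-downloading that file is then the fix.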
I managed to solve this by changing the version of the tensorflow-datasets module: I was using version 3, and after going back to version 2.1 the script worked fine.
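If pinning the tensorflow-datasets version is not an option, the same symptom often disappears after deleting the partially downloaded dataset so the next `tfds.load()` fetches a fresh copy. The sketch below assumes the library's default cache location, `~/tensorflow_datasets`; pass a different `data_dir` if you configured one. The function name is illustrative, not part of any API.

```python
import shutil
from pathlib import Path

def clear_dataset_cache(name, data_dir=None):
    """Delete a cached tensorflow-datasets dataset directory so it
    is re-downloaded on the next load.

    data_dir defaults to ~/tensorflow_datasets, the library's
    default cache location. Returns True if something was removed.
    """
    root = Path(data_dir) if data_dir else Path.home() / "tensorflow_datasets"
    target = root / name
    if target.exists():
        shutil.rmtree(target)
        return True
    return False
```

After clearing, e.g., `clear_dataset_cache("mnist")`, re-running the loading cell downloads the dataset again with intact TFRecord files.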