
Error using model.fit on MNIST deep neural network training TensorFlow 2.0

I'm trying to train a deep neural network with the MNIST dataset. Here is some of the code from my Jupyter notebook:
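Note: the earlier cell that creates train_data, validation_data, test_data, num_validation_samples and num_test_samples is not shown in the question. A minimal sketch of how they might have been prepared, assuming the standard tensorflow-datasets MNIST pipeline with a 10% validation split (the names and preprocessing below are assumptions, not the asker's actual code):

import tensorflow as tf
import tensorflow_datasets as tfds

# Load MNIST as (image, label) pairs together with its metadata
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = mnist_dataset['train'], mnist_dataset['test']

# Assumed split: hold out 10% of the training examples for validation
num_validation_samples = tf.cast(0.1 * mnist_info.splits['train'].num_examples, tf.int64)
num_test_samples = tf.cast(mnist_info.splits['test'].num_examples, tf.int64)

def scale(image, label):
    # Rescale pixel values from [0, 255] to [0, 1]
    return tf.cast(image, tf.float32) / 255., label

scaled_train_and_validation_data = mnist_train.map(scale)
test_data = mnist_test.map(scale)

# Shuffle, then carve out the validation set
shuffled_data = scaled_train_and_validation_data.shuffle(10000)
validation_data = shuffled_data.take(num_validation_samples)
train_data = shuffled_data.skip(num_validation_samples)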

The first block works fine:

# Select the hyperparameter batch size
BATCH_SIZE = 100
# Batch the train, validation and test data
train_data = train_data.batch(BATCH_SIZE)
validation_data = validation_data.batch(num_validation_samples)
test_data = test_data.batch(num_test_samples)
# Transform the validation data into tuples for the inputs and targets
validation_inputs, validation_targets = next(iter(validation_data))
# Defining model hyperparameters
input_size = 784
output_size = 10
hidden_layer_size = 50
# Defining the model
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
    tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
    tf.keras.layers.Dense(hidden_layer_size, activation='relu'),
    tf.keras.layers.Dense(output_size, activation='softmax')
])
# Select the optimizer algorithm and the loss function
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

The last block is where I get the error:

# Defining hyperparameter number of epochs
NUM_EPOCHS = 5
# Training the model
model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs, validation_targets), verbose=2)

Epoch 1/5
---------------------------------------------------------------------------
DataLossError                             Traceback (most recent call last)
 in 
      2 NUM_EPOCHS = 5
      3 # Training the model
----> 4 model.fit(train_data, epochs=NUM_EPOCHS, validation_data=(validation_inputs, validation_targets), validation_steps=1, verbose=2)

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\keras\engine\training.py in _method_wrapper(self, *args, **kwargs)
     64   def _method_wrapper(self, *args, **kwargs):
     65     if not self._in_multi_worker_mode():  # pylint: disable=protected-access
---> 66       return method(self, *args, **kwargs)
     67 
     68     # Running inside `run_distribute_coordinator` already.

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\keras\engine\training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_batch_size, validation_freq, max_queue_size, workers, use_multiprocessing)
    846                 batch_size=batch_size):
    847               callbacks.on_train_batch_begin(step)
--> 848               tmp_logs = train_function(iterator)
    849               # Catch OutOfRangeError for Datasets of unknown size.
    850               # This blocks until the batch has finished executing.

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    578         xla_context.Exit()
    579     else:
--> 580       result = self._call(*args, **kwds)
    581 
    582     if tracing_count == self._get_tracing_count():

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    609       # In this case we have created variables on the first call, so we run the
    610       # defunned version which is guaranteed to never create variables.
--> 611       return self._stateless_fn(*args, **kwds)  # pylint: disable=not-callable
    612     elif self._stateful_fn is not None:
    613       # Release the lock early so that multiple threads can perform the call

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2418     with self._lock:
   2419       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2420     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2421 
   2422   @property

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in _filtered_call(self, args, kwargs)
   1663          if isinstance(t, (ops.Tensor,
   1664                            resource_variable_ops.BaseResourceVariable))),
-> 1665         self.captured_inputs)
   1666 
   1667   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1744       # No tape is watching; skip to running the function.
   1745       return self._build_call_outputs(self._inference_function.call(
-> 1746           ctx, args, cancellation_manager=cancellation_manager))
   1747     forward_backward = self._select_forward_and_backward_functions(
   1748         args,

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    596               inputs=args,
    597               attrs=attrs,
--> 598               ctx=ctx)
    599         else:
    600           outputs = execute.execute_with_cancellation(

~\Anaconda3\envs\PY3-TF2.0\lib\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     58     ctx.ensure_initialized()
     59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:
     62     if name is not None:

DataLossError:  truncated record at 15713022
     [[node IteratorGetNext (defined at :4) ]] [Op:__inference_train_function_957]

Function call stack:
train_function

I don't know where exactly the error comes from. I have already tried changing some of the arguments of the model.fit method, such as the validation_steps argument, but it didn't work. Thanks in advance for your help.

I managed to solve the problem by changing the version of the tensorflow-datasets module: I was using version 3, and after going back to version 2.1 the script works correctly.
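For context, the "DataLossError: truncated record" message comes from the TFRecord reader and usually means the dataset files cached on disk are incomplete, so re-downloading the data after switching versions is likely what actually cleared the error. A minimal sketch of the fix described above, assuming the dataset was loaded with tfds.load (the pip command and the explicit download=True flag are assumptions, not part of the original answer):

# At the command line (assumed command; version as stated in the answer):
#   pip install tensorflow-datasets==2.1.0

import tensorflow_datasets as tfds

print(tfds.__version__)  # should report 2.1.0 after the downgrade

# Reload MNIST; download=True makes tfds fetch the files again in case the
# previously cached records were truncated
mnist_dataset, mnist_info = tfds.load(name='mnist', with_info=True,
                                      as_supervised=True, download=True)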

