
Keras Tensorflow multiple errors

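I'm working on a CodeCademy project and I'm stuck. I can't find the answer, and the terminal shows something strange. The project is about classifying images of COVID-19, pneumonia, and normal lungs. I hope you can help me.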

Code:

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator

from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras import layers

import matplotlib.pyplot as plt
import app

training_generator = ImageDataGenerator(rescale = 1./255)
training_iterator = training_generator.flow_from_directory("augmented-data/train", class_mode='categorical',color_mode='grayscale', batch_size=5)

validation_generator = ImageDataGenerator(rescale = 1./255)
validation_iterator = validation_generator.flow_from_directory("augmented-data/test", class_mode='categorical',color_mode='grayscale', batch_size=5)

model = Sequential()
model.add(tf.keras.Input(shape=training_iterator.image_shape))
model.add(tf.keras.layers.Conv2D(8, 3, strides = 2, activation = "relu"))
model.add(tf.keras.layers.MaxPooling2D(pool_size = (2, 2), strides = (2, 2)))
model.add(tf.keras.layers.Conv2D(8, 3, strides = 2, activation = "relu"))
model.add(tf.keras.layers.MaxPooling2D(pool_size = (2, 2), strides = (2, 2)))
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(16, activation = "relu"))
model.add(tf.keras.layers.Dense(4, activation = "relu"))


model.compile(optimizer = tf.keras.optimizers.Adam(learning_rate = 0.01), loss = tf.keras.losses.CategoricalCrossentropy(), metrics = [tf.keras.metrics.CategoricalAccuracy(),tf.keras.metrics.AUC()])

model.fit(training_iterator, steps_per_epoch = training_iterator.samples / 5, epochs = 5, validation_data = validation_iterator, validation_steps = validation_iterator.samples / 5)

Error:

Traceback (most recent call last):
  File "script.py", line 31, in <module>
    model.fit(training_iterator, steps_per_epoch = training_iterator.samples / 5, epochs = 5, validation_data = validation_iterator, validation_steps = validation_iterator.samples / 5)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/def_function.py", line 644, in _call
    return self._stateless_fn(*args, **kwds)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 2420, in __call__
    return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1665, in _filtered_call
    self.captured_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 1746, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager))
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/function.py", line 598, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 60, in quick_execute
    inputs, attrs, num_outputs)
tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [5,3] vs. [5,4]
     [[node categorical_crossentropy/mul (defined at script.py:31) ]] [Op:__inference_train_function_1137]

Function call stack:
train_function

"The project is about classifying images of COVID-19, pneumonia, and normal lungs."

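As you said, you have 3 classes, but your output layer (the last Dense layer) has 4 neurons, which is incompatible with your 3-class one-hot labels. Using 'relu' as its activation is a second mistake: with CategoricalCrossentropy the output should be a probability distribution over the classes, which is what softmax produces.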

You should change the last Dense layer to:

model.add(tf.keras.layers.Dense(3, activation = tf.nn.softmax))
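A sketch of the question's model with only the output layer changed as suggested, assuming TensorFlow 2.x and flow_from_directory's default target_size of (256, 256), so that grayscale images give image_shape (256, 256, 1). The random batch below is a hypothetical stand-in for one batch of 5 from training_iterator, not real data; it only checks that the output now has one column per class and that each row sums to 1.

```python
import numpy as np
import tensorflow as tf

# Assumption: flow_from_directory defaults to target_size=(256, 256); grayscale adds 1 channel.
IMAGE_SHAPE = (256, 256, 1)
NUM_CLASSES = 3  # covid-19, pneumonia, normal: one subdirectory per class

model = tf.keras.Sequential([
    tf.keras.Input(shape=IMAGE_SHAPE),
    tf.keras.layers.Conv2D(8, 3, strides=2, activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Conv2D(8, 3, strides=2, activation="relu"),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2), strides=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16, activation="relu"),
    # One unit per class; softmax turns each row into a probability distribution.
    tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.01),
    loss=tf.keras.losses.CategoricalCrossentropy(),
    metrics=[tf.keras.metrics.CategoricalAccuracy(), tf.keras.metrics.AUC()],
)

# Hypothetical batch shaped like one batch from the iterator (batch_size=5).
x = np.random.rand(5, *IMAGE_SHAPE).astype("float32")
y = np.eye(NUM_CLASSES, dtype="float32")[np.random.randint(0, NUM_CLASSES, size=5)]

model.fit(x, y, epochs=1, verbose=0)  # no shape error: labels (5, 3) match outputs (5, 3)
print(model.predict(x, verbose=0).shape)  # (5, 3)
```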

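Your data does not match your model architecture. In the error, [5,3] is a batch of 5 one-hot labels for 3 classes, and [5,4] is the model's 4-unit output for the same batch: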

Incompatible shapes: [5,3] vs. [5,4]
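A simplified NumPy version of the loss shows where this comes from. This is an illustration, not Keras's actual implementation: categorical cross-entropy multiplies the one-hot labels elementwise by the log of the predictions (the failing node in the traceback is categorical_crossentropy/mul), so both arrays must have the same number of columns. The labels and predictions are hypothetical values.

```python
import numpy as np

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Simplified per-sample loss: -sum(y_true * log(y_pred)) over the class axis."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return -np.sum(y_true * np.log(y_pred), axis=-1)

batch_size = 5
labels = np.array([0, 1, 2, 1, 0])  # hypothetical class indices for one batch
y_true = np.eye(3)[labels]          # one-hot labels, shape (5, 3)

# A 4-unit output layer yields predictions of shape (5, 4): the multiply cannot broadcast.
y_pred_bad = np.full((batch_size, 4), 0.25)
mismatch_error = None
try:
    categorical_crossentropy(y_true, y_pred_bad)
except ValueError as err:
    mismatch_error = err
    print("shape mismatch:", err)

# A 3-unit softmax layer yields predictions of shape (5, 3): the loss is well defined.
y_pred_ok = np.full((batch_size, 3), 1.0 / 3.0)
loss_ok = categorical_crossentropy(y_true, y_pred_ok)
print(loss_ok)  # every entry is log(3), about 1.0986
```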

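To debug this kind of error, try passing run_eagerly=True to model.compile. The model then runs eagerly instead of inside a compiled tf.function, and the errors become more readable.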

https://www.tensorflow.org/api_docs/python/tf/keras/Model#compile
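A minimal sketch of that debugging setup, assuming TensorFlow 2.x. The tiny model and random data are hypothetical stand-ins that reproduce the same 3-vs-4 mismatch; they are not the asker's code. With run_eagerly=True, Keras does not wrap the training step in a tf.function, so the failure is raised from ordinary Python code. The exact exception type and message differ between TensorFlow/Keras versions, which is why the sketch catches Exception.

```python
import numpy as np
import tensorflow as tf

# Hypothetical stand-in model with the same bug: 4 output units, 3-class labels.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(8,)),
    tf.keras.layers.Dense(4, activation="softmax"),
])
# run_eagerly=True: each training step executes as plain Python, one op at a time.
model.compile(optimizer="adam", loss="categorical_crossentropy", run_eagerly=True)

x = np.random.rand(5, 8).astype("float32")
y = np.eye(3, dtype="float32")[[0, 1, 2, 1, 0]]  # one-hot labels, shape (5, 3)

caught = None
try:
    model.fit(x, y, epochs=1, verbose=0)
except Exception as err:  # exception type varies across TF/Keras versions
    caught = err
    print(type(err).__name__, err)
```

Once the model trains, remove run_eagerly=True again; eager execution is noticeably slower than the default graph mode.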

