
Solve InvalidArgumentError in TensorFlow

There's not much I can say other than that I'm getting an error I can't seem to solve from the following piece of code.

import tensorflow as tf
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

inpTensor = tf.keras.layers.Input(x_train.shape[1:],)

hidden1Out = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)(inpTensor)
hidden2Out = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)(hidden1Out)  
finalOut = tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)(hidden2Out)

model = tf.keras.Model(inputs=inpTensor, outputs=finalOut)

model.summary()

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train,y_train, epochs = 4)

I've tried changing the loss function to 'categorical_crossentropy', but that didn't seem to work either. I am running Python 3.7 and would really appreciate some help. I am kind of new to this as well.

Thanks in advance.

The problem is in the way you manage dimensionality in your network. Each input image is 3D (height, width, channels), and a Dense layer only transforms the last axis, so the output is never reduced to a 2D (batch, classes) tensor of class probabilities. This can be fixed simply by adding a Flatten layer or a global pooling operation. sparse_categorical_crossentropy is the correct loss in your case. Here is an example:

import tensorflow as tf
import matplotlib.pyplot as plt

(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

inpTensor = tf.keras.layers.Input(x_train.shape[1:],)

hidden1Out = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)(inpTensor)
hidden2Out = tf.keras.layers.Dense(units=128, activation=tf.nn.relu)(hidden1Out)
pooling = tf.keras.layers.GlobalMaxPool2D()(hidden2Out)  # collapse the spatial axes; GlobalAvgPool2D or Flatten also work
finalOut = tf.keras.layers.Dense(units=10, activation=tf.nn.softmax)(pooling)

model = tf.keras.Model(inputs=inpTensor, outputs=finalOut)

model.summary()

model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(x_train,y_train, epochs = 4)
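
To see why the pooling layer matters, it can help to compare output shapes without training anything. The following is only a minimal sketch: it hard-codes the CIFAR-10 image shape (32, 32, 3) instead of loading the dataset, and the variable names (per_pixel, per_image, flat, shapes) are made up for illustration.

```python
import tensorflow as tf

# CIFAR-10 images are 32x32 RGB, so each sample has shape (32, 32, 3).
# The shape is hard-coded here so nothing has to be downloaded.
inputs = tf.keras.layers.Input(shape=(32, 32, 3))

# Dense only transforms the last axis; the two spatial axes pass through.
hidden = tf.keras.layers.Dense(128, activation="relu")(inputs)
hidden = tf.keras.layers.Dense(128, activation="relu")(hidden)

# Model from the question: the final Dense yields one softmax per pixel.
per_pixel = tf.keras.layers.Dense(10, activation="softmax")(hidden)

# Model from the answer: collapse the spatial axes first, then classify.
pooled = tf.keras.layers.GlobalMaxPool2D()(hidden)
per_image = tf.keras.layers.Dense(10, activation="softmax")(pooled)

# Flatten also removes the spatial axes, but keeps every value.
flat = tf.keras.layers.Flatten()(hidden)

shapes = {
    "question model output": tuple(per_pixel.shape),
    "after GlobalMaxPool2D": tuple(pooled.shape),
    "answer model output": tuple(per_image.shape),
    "after Flatten": tuple(flat.shape),
}
for name, shape in shapes.items():
    print(f"{name}: {shape}")
# question model output: (None, 32, 32, 10)
# after GlobalMaxPool2D: (None, 128)
# answer model output: (None, 10)
# after Flatten: (None, 131072)
```

The loss expects one row of class probabilities per image, i.e. (None, 10), to match the one label per image in y_train. The question's model produces (None, 32, 32, 10) instead, which is the mismatch the pooling layer removes. Note that Flatten in the same position feeds 131072 features into the last Dense layer, so global pooling is the much lighter choice here.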

