
Convert a Keras Functional API model into a Keras subclassed model

I'm relatively new to Keras and TensorFlow and want to learn the basics. To do so, I am building a model that recognizes handwritten digits, using the MNIST dataset from Keras. I already built this model with the Keras Functional API, and it works fine. Now I want to build the exact same model as a Keras subclassed model, but running that version raises an error. This is the Functional API code, which works without problems:

import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import layers    
from tensorflow.keras.datasets import mnist
import numpy as np 

#Load MNIST-Dataset
(x_train_full, y_train_full), (x_test, y_test) = mnist.load_data()

#Create train- and validationdata
X_valid = x_train_full[:5000]/255.0         
X_train = x_train_full[5000:] / 255.0       
y_valid, y_train = y_train_full[:5000], y_train_full[5000:]   



#Create the model with the keras functional-API
inputs = keras.layers.Input(shape=(28, 28))
flatten = keras.layers.Flatten()(inputs)
hidden1 = keras.layers.Dense(256, activation="relu")(flatten)
hidden2 = keras.layers.Dense(128, activation='relu')(hidden1)
outputs = keras.layers.Dense(10, activation='softmax')(hidden2)

model = keras.Model(inputs=[inputs], outputs=[outputs])

model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
h = model.fit(X_train, y_train, epochs=5, validation_data=(X_valid, y_valid))

#Evaluate the model with test data (scaled the same way as the training data)
X_test = x_test / 255.0
test_loss, test_acc = model.evaluate(X_test, y_test, verbose=2)
print('\nTest accuracy: ', test_acc)
print('\nTest loss: ', test_loss)

#Create predictions:
myPrediction = model.predict(X_test)

#Prediction example of one testpicture
print(myPrediction[0])
print('Predicted Item: ', np.argmax(myPrediction[0]))
print('Actual Item: ', y_test[0])

And here is the (non-working) code of the Keras subclassed model, which should do exactly the same thing as the code above:

import tensorflow as tf 
from tensorflow import keras
from tensorflow.keras import layers    
from tensorflow.keras.datasets import mnist
import numpy as np 

#Load MNIST-Dataset
(x_train_full, y_train_full), (x_test, y_test) = mnist.load_data()

#Create train- and validationdata
X_valid = x_train_full[:5000]/255.0         
X_train = x_train_full[5000:] / 255.0       
y_valid, y_train = y_train_full[:5000], y_train_full[5000:] 

#Create a keras-subclassing-model:
class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        #Define layers
        self.input_ = keras.layers.Input(shape=(28, 28))
        self.flatten = keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = keras.layers.Dense(256, activation="relu")
        self.dense_2 = keras.layers.Dense(128, activation="relu")
        self.output_ = keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.input_(inputs)
        x = self.flatten(x)
        x = self.dense_1(x)
        x = self.dense_2(x)
        x = self.output_(x)
        return x

model = MyModel()

model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

h = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))

I get the same error every time I run this code. It is raised when the fit(...) method is called:

Traceback (most recent call last):
  File "c:/Users/MichaelM/Documents/PythonSkripte/MachineLearning/SubclassedModel.py", line 39, in <module>
    h = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 819, in fit
    use_multiprocessing=use_multiprocessing)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 235, in fit
    use_multiprocessing=use_multiprocessing)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 593, in _process_training_inputs
    use_multiprocessing=use_multiprocessing)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training_v2.py", line 646, in _process_inputs
    x, y, sample_weight=sample_weights)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 2346, in _standardize_user_data
    all_inputs, y_input, dict_inputs = self._build_model_with_inputs(x, y)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 2572, in _build_model_with_inputs
    self._set_inputs(cast_inputs)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 2659, in _set_inputs
    outputs = self(inputs, **kwargs)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\keras\engine\base_layer.py", line 773, in __call__
    outputs = call_fn(cast_inputs, *args, **kwargs)
  File "C:\Python37\lib\site-packages\tensorflow_core\python\autograph\impl\api.py", line 237, in wrapper
    raise e.ag_error_metadata.to_exception(e)
TypeError: in converted code:

    c:/Users/MichaelM/Documents/PythonSkripte/MachineLearning/SubclassedModel.py:28 call  *
        x = self.input_(inputs)
    C:\Python37\lib\site-packages\tensorflow_core\python\autograph\impl\api.py:447 converted_call
        f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    C:\Python37\lib\site-packages\tensorflow_core\python\autograph\impl\api.py:447 <genexpr>
        f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    C:\Python37\lib\site-packages\tensorflow_core\python\ops\math_ops.py:1351 tensor_equals
        return gen_math_ops.equal(self, other, incompatible_shape_error=False)
    C:\Python37\lib\site-packages\tensorflow_core\python\ops\gen_math_ops.py:3240 equal
        name=name)
    C:\Python37\lib\site-packages\tensorflow_core\python\framework\op_def_library.py:477 _apply_op_helper
        repr(values), type(values).__name__, err))

    TypeError: Expected float32 passed to parameter 'y' of op 'Equal', got 'collections' of type 'str' instead. Error: Expected float32, got 'collections' of type 'str' instead.

Could you please help me fix this problem and explain why it isn't working? I don't understand what this error actually means. Also, can I then call the evaluate(...) and predict(...) methods as in the Functional API code? I use the following configuration:

  • Visual Studio Code with Python-Extension as IDE
  • Python-Version: 3.7.6
  • TensorFlow-Version: 2.1.0
  • Keras-Version: 2.2.4-tf

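You don't need an Input in a subclassed model, because the data is passed directly to the model's call method. keras.layers.Input creates a symbolic tensor for building Functional API models; it is not a layer that can be applied to data inside call. I removed it from the code, and the model now works as expected. Please check below.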

#Create a keras-subclassing-model:
class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        #Define layers
        #self.input_ = keras.layers.Input(shape=(28, 28))
        self.flatten = keras.layers.Flatten(input_shape=(28, 28))
        self.dense_1 = keras.layers.Dense(256, activation="relu")
        self.dense_2 = keras.layers.Dense(128, activation="relu")
        self.output_ = keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        #x = self.input_(inputs)
        x = self.flatten(inputs)
        x = self.dense_1(x)
        x = self.dense_2(x)
        x = self.output_(x)
        return x

model = MyModel()

model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])

h = model.fit(X_train, y_train, epochs=10, validation_data=(X_valid, y_valid))

The output is as follows:

Epoch 1/10
1719/1719 [==============================] - 6s 3ms/step - loss: 0.6251 - accuracy: 0.8327 - val_loss: 0.3068 - val_accuracy: 0.9180
....
....
Epoch 10/10
1719/1719 [==============================] - 6s 3ms/step - loss: 0.1097 - accuracy: 0.9687 - val_loss: 0.1215 - val_accuracy: 0.9648
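
To see why the original version failed, it may help to inspect what keras.layers.Input actually returns. The short check below is a sketch, not part of the original answer: it shows that the returned object is a symbolic tensor rather than a layer, so `self.input_(inputs)` tries to call something that is not a layer. The traceback in the question suggests that AutoGraph in TensorFlow 2.1 then ran an equality comparison on that tensor (tensor_equals) while trying to convert the call, which would explain the confusing 'Equal' op message. The variable names `symbolic` and `flatten` are illustrative.

```python
from tensorflow import keras

# keras.layers.Input(...) returns a symbolic placeholder tensor
# used to build Functional API graphs. It is not a Layer instance.
symbolic = keras.layers.Input(shape=(28, 28))
print(type(symbolic).__name__)
print(isinstance(symbolic, keras.layers.Layer))  # False

# A real layer such as Flatten is a Layer and can be applied to data in call().
flatten = keras.layers.Flatten()
print(isinstance(flatten, keras.layers.Layer))   # True
```

The exact type name printed on the first line depends on the TensorFlow version, so no expected value is given for it.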

Full code is here .
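
Regarding the second part of the question: a subclassed model inherits from keras.Model, so evaluate(...) and predict(...) are available just as with the Functional API model. The following is a minimal sketch under stated assumptions: it uses small random arrays (`x_demo`, `y_demo`, both hypothetical stand-ins for the MNIST arrays) so that it runs quickly without downloading anything, and it trains for a single epoch only to get the model built.

```python
import numpy as np
from tensorflow import keras


class MyModel(keras.Model):
    def __init__(self):
        super().__init__()
        # Same layers as in the answer above; no Input is needed.
        self.flatten = keras.layers.Flatten()
        self.dense_1 = keras.layers.Dense(256, activation="relu")
        self.dense_2 = keras.layers.Dense(128, activation="relu")
        self.output_ = keras.layers.Dense(10, activation="softmax")

    def call(self, inputs):
        x = self.flatten(inputs)
        x = self.dense_1(x)
        x = self.dense_2(x)
        return self.output_(x)


# Assumption: random stand-in data with the same shape as MNIST images.
rng = np.random.default_rng(0)
x_demo = rng.random((64, 28, 28)).astype("float32")  # values already in [0, 1)
y_demo = rng.integers(0, 10, size=64)                # integer class labels 0..9

model = MyModel()
model.compile(loss="sparse_categorical_crossentropy", optimizer="sgd", metrics=["accuracy"])
model.fit(x_demo, y_demo, epochs=1, verbose=0)

# evaluate() returns the loss followed by the compiled metrics.
test_loss, test_acc = model.evaluate(x_demo, y_demo, verbose=0)
print("Test accuracy:", test_acc)
print("Test loss:", test_loss)

# predict() returns one row of 10 class probabilities per input image.
predictions = model.predict(x_demo, verbose=0)
print(predictions.shape)  # (64, 10)
print("Predicted Item:", np.argmax(predictions[0]))
print("Actual Item:", y_demo[0])
```

With the MNIST arrays from the question, the test images should be scaled the same way as the training images (for example `x_test / 255.0`) before they are passed to evaluate or predict; otherwise the model sees inputs in a different range than it was trained on.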
