
Is it possible to use the TensorFlow Keras functional API within a subclassed Model?

I'm trying to create a Keras Model that needs particular preprocessing of its inputs before they are passed through the model itself. I'm subclassing because the model is only one part of a complex network, so that I can concatenate outputs, access the model's behaviour directly from other parts of my code, etc.

I've designed it by using the Keras functional API in the constructor to chain the layers. This works fine if I DON'T define the call method: when I call the instance, it behaves exactly as if I had used the functional API normally.

My problem is when I DO want to define the call method: I'm not sure which function to call to get the default behaviour of the model built in the constructor:

import tensorflow as tf
import numpy as np

class MyModel(tf.keras.Model):

  def __init__(self, inputshape):
    inputs = tf.keras.Input(shape=inputshape)
    x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
    outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)

    super(MyModel, self).__init__(inputs=inputs, outputs=outputs)


  def call(self, inputs, training=False):
    reduced_input = tf.expand_dims(inputs['b'], axis=0)

    # Want to call my compiled self MODEL with 'reduced_input' as the input argument but not sure how...
    return something(reduced_input)


myModelInstance = MyModel(inputshape=(3,))
myInput = {'a': [1, 2], 'b': np.array([3, 4, 5]), 'c': 6}

# Example preprocessing that I want to implement from within the model when called. Won't be this simple
reduced_input  = tf.expand_dims(myInput['b'], axis=0)

print(myModelInstance(reduced_input ))

In this snippet I've simplified both the constructor and the input preprocessing (here it extracts only the 'b' element from the input and adds a batch dimension); my actual implementation is more complex.
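To make the simplified preprocessing concrete, here is a minimal standalone sketch of what it does to the example input. The function name preprocess is hypothetical, and the cast to float32 is an added assumption so the result matches the default dtype of tf.keras.Input; it is not part of the original snippet.

```python
import numpy as np
import tensorflow as tf


def preprocess(inputs):
    """Hypothetical helper: keep only inputs['b'] and add a batch axis.

    Mirrors the simplified preprocessing described in the question;
    a real pipeline would do more work here.
    """
    b = tf.convert_to_tensor(inputs['b'])
    # Assumption: cast to float32 so the tensor matches the default
    # dtype used by tf.keras.Input and Dense layers.
    b = tf.cast(b, tf.float32)
    # Shape (3,) -> (1, 3): a batch containing a single example.
    return tf.expand_dims(b, axis=0)


my_input = {'a': [1, 2], 'b': np.array([3, 4, 5]), 'c': 6}
reduced = preprocess(my_input)
print(reduced.shape)  # (1, 3)
```

Only the 'b' entry is touched; 'a' and 'c' are ignored, which is exactly the part the question wants to move inside the model's call method.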

I'd prefer a) to avoid preprocessing the data before calling the instance, and b) to subclass Model rather than store the model as an attribute.

Is there a way to combine Model subclassing with the functional API as I'm trying to do?

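You should do it this way: call the plain super().__init__(), build the functional model in the constructor, store it as an attribute, and call that attribute from call: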

import numpy as np
import tensorflow as tf

class MyModel(tf.keras.Model):

  def __init__(self, inputshape):
    super(MyModel, self).__init__()
    # Build the inner model with the functional API and keep it as an attribute
    inputs = tf.keras.Input(shape=inputshape)
    x = tf.keras.layers.Dense(4, activation=tf.nn.relu)(inputs)
    outputs = tf.keras.layers.Dense(5, activation=tf.nn.softmax)(x)
    self.model = tf.keras.Model(inputs, outputs)

  def call(self, inputs, training=False):
    # Preprocessing happens inside the model: keep only 'b' and add a batch dimension
    reduced_input = tf.expand_dims(tf.cast(inputs['b'], tf.float32), axis=0)

    # Delegate to the inner functional model, forwarding the training flag
    return self.model(reduced_input, training=training)


myModelInstance = MyModel(inputshape=(3,))
myInput = {'a': [1, 2], 'b': np.array([3, 4, 5]), 'c': 6}

# The raw dict is passed directly; no preprocessing happens outside the model
print(myModelInstance(myInput))
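As a quick sanity check of this pattern, here is a sketch assuming TensorFlow 2.x with eager execution. It shows that the wrapper accepts the raw dict, applies the preprocessing inside call, and still exposes the inner functional model (through the attribute name model used above) so other parts of a network can call it on already-preprocessed data. The class name Wrapper is hypothetical, and every dict value is a NumPy array here as a precaution, since some Keras versions may try to read a shape from each input value.

```python
import numpy as np
import tensorflow as tf


class Wrapper(tf.keras.Model):
    def __init__(self, inputshape):
        super().__init__()
        # Inner model built with the functional API, stored as an attribute.
        inputs = tf.keras.Input(shape=inputshape)
        x = tf.keras.layers.Dense(4, activation='relu')(inputs)
        outputs = tf.keras.layers.Dense(5, activation='softmax')(x)
        self.model = tf.keras.Model(inputs, outputs)

    def call(self, inputs, training=False):
        # Preprocessing lives inside the model: pick 'b', add a batch axis.
        reduced = tf.expand_dims(tf.cast(inputs['b'], tf.float32), axis=0)
        return self.model(reduced, training=training)


wrapper = Wrapper(inputshape=(3,))
raw = {'a': np.array([1, 2]), 'b': np.array([3, 4, 5]), 'c': np.array([6])}

# Raw dict in; the preprocessing runs inside call().
probs = wrapper(raw)
print(probs.shape)  # (1, 5)

# The inner functional model can still be called directly on
# already-preprocessed data, e.g. from another part of the network.
direct = wrapper.model(np.array([[3.0, 4.0, 5.0]], dtype=np.float32))
print(np.allclose(probs.numpy(), direct.numpy()))  # True
```

Because both calls share the same layers and weights, the two outputs match, and each row of probs sums to 1 because of the final softmax.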
