简体   繁体   中英

How to remove the last layer of a Keras subclassed model but keep the weights?

I am training a feature extractor based on densenet, which looks like the following:

# Import the Sequential model and layers
from keras.models import Sequential
import keras
import tensorflow as tf
from keras.layers import Conv2D, MaxPooling2D, Lambda, Dropout, Concatenate
from keras.layers import Activation, Dropout, Flatten, Dense
import pandas as pd
from sklearn import preprocessing
import ast
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint

# Side length of the square RGB inputs; used below as input_shape=(size, size, 3).
size = 256

class DenseNetBase(tf.keras.Model):
    """DenseNet201 backbone with an optional single-unit sigmoid head.

    The model is called with a two-element input ``[image, metafeatures]``.
    Only the image is used by ``call``; the second element is accepted so the
    call signature matches the two inputs created in ``build_graph``.
    """

    def __init__(self, size, include_top=True):
        """Create the backbone and the prediction head.

        Args:
            size: Side length of the square RGB input; the backbone is built
                with ``input_shape=(size, size, 3)``.
            include_top: When true, ``call`` applies the sigmoid ``Dense(1)``
                head to the backbone output; otherwise the backbone output is
                returned unchanged.
        """
        super().__init__()

        self.include_top = include_top

        # Backbone: ImageNet-initialised DenseNet201 without its classifier
        # and with pooling='avg' applied to the final feature map.
        self.base = tf.keras.applications.DenseNet201(
            weights='imagenet',
            include_top=False,
            pooling='avg',
            input_shape=(size, size, 3),
        )

        # Final layer. Taken from tf.keras so that every layer of this model
        # comes from the same Keras implementation as tf.keras.Model itself.
        self.dense = tf.keras.layers.Dense(1, activation='sigmoid', name='predictions')

    def call(self, input_tensor):
        """Run the backbone (and the head, if enabled) on the image input.

        Args:
            input_tensor: Indexable pair ``[image, metafeatures]``; only
                element 0 is read.

        Returns:
            The head output when ``include_top`` is true, otherwise the
            backbone output.
        """
        x = self.base(input_tensor[0])

        if self.include_top:
            x = self.dense(x)

        return x

    def build_graph(self):
        """Return a functional ``tf.keras.Model`` that wraps ``call``.

        The inputs are the backbone's own symbolic input and a ``(3,)``-shaped
        placeholder standing in for the metafeatures.
        """
        x = self.base.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x, y], outputs=self.call([x, y]))

I want to then take the DenseNetBase, keep the trained weights, but remove the final dense layer to use for extracting features. Simplified DenseClassifier looks like this:

class DenseClassifier(tf.keras.Model):
    """Binary classifier: a reused feature extractor followed by a sigmoid head.

    The model is called with a two-element input ``[image, metafeatures]``;
    only the image is used by ``call``.
    """

    def __init__(self, size, feature_extractor):
        """Wrap ``feature_extractor`` and add a fresh prediction head.

        Args:
            size: Side length of the square RGB input; the wrapper model is
                built on ``tf.keras.Input(shape=(size, size, 3))``.
            feature_extractor: Either a model that exposes its backbone as a
                ``base`` attribute, or a Keras model that can be called on a
                single image tensor. Its layers are reused as-is, so its
                current weights are kept.
        """
        super().__init__()

        # If the extractor wraps its backbone in `.base`, use the backbone so
        # that any prediction head on the extractor is left out.
        backbone = getattr(feature_extractor, 'base', feature_extractor)

        # Call the backbone on the new Input instead of pairing the new Input
        # with an output tensor from another graph: only a tensor derived from
        # `image_input` is connected to it, and calling the model reuses the
        # same layers (and weights).
        image_input = tf.keras.Input(shape=(size, size, 3))
        self.feature_extractor = tf.keras.Model(
            inputs=image_input,
            outputs=backbone(image_input),
        )

        # Final layer: newly initialised, not taken from the extractor.
        self.dense = tf.keras.layers.Dense(1, activation='sigmoid', name='prediction')

    def call(self, input_tensor):
        """Extract features from the image input and apply the head.

        Args:
            input_tensor: Indexable pair ``[image, metafeatures]``; only
                element 0 is read.

        Returns:
            Output of the sigmoid ``Dense(1)`` head.
        """
        x = self.feature_extractor(input_tensor[0])

        return self.dense(x)

    def build_graph(self):
        """Return a functional ``tf.keras.Model`` that wraps ``call``.

        The inputs are the feature extractor's symbolic input and a
        ``(3,)``-shaped placeholder standing in for the metafeatures.
        """
        # This class has no `base` attribute; the image input lives on the
        # wrapped feature extractor.
        x = self.feature_extractor.input
        y = tf.keras.Input(shape=(3,))
        return tf.keras.Model(inputs=[x, y], outputs=self.call([x, y]))

Tying it together:

# Build the DenseNet feature extractor we have trained and restore its weights.
denseBase = DenseNetBase(256, include_top=True)
denseBase.build([(None, 256, 256, 3), (None, 3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')

# This is the call reported as failing in the question. The instance gets its
# own name so that the DenseClassifier class itself is not overwritten and can
# still be instantiated afterwards.
dense_classifier = DenseClassifier(size=256, feature_extractor=denseBase)

In the above example, I get an error about the input, and I am not sure why. The expected behaviour is that I could build the latter model and compile it, and that the existing weights of DenseNetBase would be used for feature extraction.

I have tried to replace the input section with inputs = feature_extractor.layers[-2].input which does compile, but does not seem to evaluate to the same accuracy as denseBase even though it is using the same weights (in the simple example above with no extra layers).

My goal/question:

  • How can I load the weights from the pre-trained denseBase but remove the last dense layer, so that the output is (None, 1920), as from DenseNet without the top but with my weights?
  • How can I then load this model, without the dense layer, into another subclassed model as above to extract features?

Thanks!

To answer my own question, I did some testing, looking at the values of the initialised weights using logic from here: [image not preserved; presumably the weight histograms produced by the code below]

This is what is expected. DenseBaseClassifier (using denseBase) and DenseBaseClassifier (using imagenet weights) both have similar prediction-layer weight distributions. This is because both of these layers are randomly initialised and not trained, while the prediction layer in denseBase has been optimised and hence is different.

For the denseNet section, DenseBaseClassifier (using denseBase) == denseBase (some noise due to only saving weights), whereas original imagenet weights are different.

Using denseBase_featureextractor = tf.keras.Model(inputs = denseBase.layers[-2].input, outputs = denseBase.layers[-2].output) does indeed preserve the weights.

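Not sure why self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output) doesn't work though. (The most likely reason: the freshly created tf.keras.Input is not connected to the graph that produced feature_extractor.layers[-2].output, so Keras cannot trace a path from the new input to that output. Calling the backbone on the new input, e.g. outputs = backbone(new_input), creates a connected graph that reuses the same weights.)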

# Trained model: rebuild it and restore the saved weights. Every shape below
# is derived from `size` so the models cannot drift apart if it changes.
denseBase = DenseNetBase(size, include_top=True)
denseBase.build([(None, size, size, 3), (None, 3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')

# Classifier on top of the trained backbone (second-to-last layer of denseBase).
denseBase_featureextractor = tf.keras.Model(inputs=denseBase.layers[-2].input, outputs=denseBase.layers[-2].output)
DenseClassifier_denseBase = DenseClassifier(size=size, feature_extractor=denseBase_featureextractor)
DenseClassifier_denseBase.build([(None, size, size, 3), (None, 3)])

# Reference classifier on top of a DenseNet201 with the stock ImageNet weights.
denseBase_imagenet = tf.keras.applications.DenseNet201(weights='imagenet', include_top=False, pooling='avg', input_shape=(size, size, 3))
DenseClassifier_imagenet = DenseClassifier(size=size, feature_extractor=denseBase_imagenet)
DenseClassifier_imagenet.build([(None, size, size, 3), (None, 3)])

def get_weights_print_stats(layer):
    """Return the weights of ``layer``.

    Despite the name, nothing is printed: the result of
    ``layer.get_weights()`` is handed back unchanged.
    """
    return layer.get_weights()

def hist_weights(weights, title, bins=500):
    """Overlay histograms of the first five weight arrays on the current axes.

    Args:
        weights: Sequence of array-likes, e.g. the list returned by a layer's
            ``get_weights()``. Only the first five entries are drawn.
        title: Title given to the current axes.
        bins: Number of histogram bins used for every array.

    Relies on the module-level names ``plt`` and ``np`` (presumably
    ``matplotlib.pyplot`` and ``numpy``); the surrounding snippet does not
    show their imports.
    """
    for weight in weights[:5]:
        # np.ravel accepts any array-like, whereas the unbound
        # np.ndarray.flatten only accepts an actual ndarray.
        plt.hist(np.ravel(weight), bins=bins)

    # The title does not depend on the loop, so set it once; this also titles
    # the axes when `weights` is empty.
    plt.title(title)

# 2x3 grid of weight histograms: one column per model, one row per layer.
fig = plt.figure(figsize=(15, 10))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

# Columns: (model, panel title).
panel_models = [
    (denseBase, "denseBase"),
    (DenseClassifier_denseBase, "DenseBaseClassifier (using denseBase weights)"),
    (DenseClassifier_imagenet, "DenseBaseClassifier (using imagenet weights)"),
]

# Rows: (index into model.layers, y-axis label shown on the first column).
# Index 1 is the prediction layer and index 0 the DenseNet base.
panel_rows = [
    (1, "Final prediction layer weights"),
    (0, "DenseNet base first 5 weights"),
]

n_rows = len(panel_rows)
n_cols = len(panel_models)

for row, (layer_index, row_label) in enumerate(panel_rows):
    for col, (model, title) in enumerate(panel_models):
        W = get_weights_print_stats(model.layers[layer_index])
        # Subplot positions are 1-based and numbered row by row.
        plt.subplot(n_rows, n_cols, row * n_cols + col + 1)
        hist_weights(W, title)
        if col == 0:
            # Label each row once, on its left-most panel.
            y = plt.ylabel(row_label)

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM