How to remove the last layer of a Keras subclassed model but keep its weights?
I am training a feature extractor based on DenseNet, which looks like the following: 我正在训练一个基于 DenseNet 的特征提取器,如下所示:
# Import the Sequential model and layers
from keras.models import Sequential
import keras
import tensorflow as tf
from keras.layers import Conv2D, MaxPooling2D, Lambda, Dropout, Concatenate
from keras.layers import Activation, Dropout, Flatten, Dense
import pandas as pd
from sklearn import preprocessing
import ast
from keras.callbacks import EarlyStopping
from keras.callbacks import ModelCheckpoint
size: int = 256  # side length of the square (size, size, 3) input images
class DenseNetBase(tf.keras.Model):
    """DenseNet201 backbone with an optional single-unit sigmoid head.

    Args:
        size: Side length of the square ``(size, size, 3)`` input images.
        include_top: If True, ``call`` applies the ``predictions`` Dense
            layer; if False, it returns the pooled DenseNet201 features.
    """

    def __init__(self, size, include_top=True):
        super().__init__()
        self.include_top = include_top
        # Backbone: ImageNet-initialised DenseNet201 without its own
        # classifier, global-average-pooled to one feature vector per image.
        self.base = tf.keras.applications.DenseNet201(
            weights='imagenet',
            include_top=False,
            pooling='avg',
            input_shape=(size, size, 3),
        )
        # Final layer: one sigmoid unit on top of the pooled features.
        self.dense = Dense(1, activation='sigmoid', name='predictions')

    def call(self, input_tensor, training=None):
        """Run the model on ``[image_batch, metafeatures]``.

        Only ``input_tensor[0]`` (the image batch) is used; the second entry
        is accepted so the model matches its two-input ``build`` call, but it
        is currently ignored.

        Args:
            input_tensor: Two-element sequence ``[images, metafeatures]``.
            training: Optional flag forwarded to the backbone so layers such
                as BatchNorm run in the requested mode.

        Returns:
            Sigmoid predictions if ``include_top`` else pooled features.
        """
        x = self.base(input_tensor[0], training=training)
        if self.include_top:
            x = self.dense(x)
        return x

    def build_graph(self):
        """Return a functional ``tf.keras.Model`` wrapping this model.

        The graph can then be inspected with symbolic inputs (e.g. via
        ``summary()``). The second input is a ``(3,)`` placeholder for the
        currently-unused metafeatures.
        """
        image_input = self.base.input
        meta_input = tf.keras.Input(shape=(3,))
        return tf.keras.Model(
            inputs=[image_input, meta_input],
            outputs=self.call([image_input, meta_input]),
        )
I then want to take the DenseNetBase, keep the trained weights, but remove the final dense layer so that it can be used for extracting features. 然后我想采用 DenseNetBase,保留训练过的权重,但删除最后的 Dense 层以用于提取特征。 A simplified DenseClassifier looks like this: 简化的 DenseClassifier 看起来像这样:
class DenseClassifier(tf.keras.Model):
    """Sigmoid classifier on top of a reused, pretrained backbone.

    Args:
        size: Side length of the square ``(size, size, 3)`` input images.
        feature_extractor: A model whose second-to-last layer
            (``layers[-2]``) is the backbone to reuse, e.g. a trained
            ``DenseNetBase`` whose last layer is its ``predictions`` head.
            Pass the full model (backbone + head), not an already-truncated
            backbone, since ``layers[-2]`` is what gets reused. The backbone
            layer is shared, so its trained weights are kept.
    """

    def __init__(self, size, feature_extractor):
        super().__init__()
        # Call the backbone on a fresh Input instead of pairing the Input with
        # ``layers[-2].output``: that output tensor belongs to the backbone's
        # own graph, so the new Input would be disconnected from it and Keras
        # rejects the model. Calling the layer re-uses (shares) its weights.
        image_input = tf.keras.Input(shape=(size, size, 3))
        backbone = feature_extractor.layers[-2]
        self.feature_extractor = tf.keras.Model(
            inputs=image_input, outputs=backbone(image_input)
        )
        # Final layer: a new, freshly initialised single-unit sigmoid head.
        self.dense = Dense(1, activation='sigmoid', name='prediction')

    def call(self, input_tensor, training=None):
        """Classify ``input_tensor[0]`` (the image batch).

        ``input_tensor[1]`` (metafeatures) is accepted for interface parity
        with ``DenseNetBase`` but ignored. ``training`` is forwarded to the
        backbone.
        """
        x = self.feature_extractor(input_tensor[0], training=training)
        return self.dense(x)

    def build_graph(self):
        """Return a functional ``tf.keras.Model`` wrapping this model."""
        # This class has no ``self.base``; the image input lives on the
        # wrapped feature extractor.
        image_input = self.feature_extractor.input
        meta_input = tf.keras.Input(shape=(3,))
        return tf.keras.Model(
            inputs=[image_input, meta_input],
            outputs=self.call([image_input, meta_input]),
        )
Tying it together: 把它们组合在一起:
# Build the DenseNet feature extractor we have trained and load its weights.
denseBase = DenseNetBase(size, include_top=True)
# Build with the two-input [images, metafeatures] shapes the model expects.
denseBase.build([(None, size, size, 3), (None, 3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')
# Bind the instance to its own name: assigning it to ``DenseClassifier`` would
# shadow the class, so it could no longer be instantiated afterwards.
dense_classifier = DenseClassifier(size=size, feature_extractor=denseBase)
In the above example, I get an error about the model input, and I am not sure why. 在上面的例子中,我得到了一个关于输入的错误,我不知道为什么。 The expected behaviour is that I could build the latter model and compile it, and the existing DenseNetBase weights would be used for feature extraction. 预期的行为是我可以构建后一个模型并进行编译,并且现有的 DenseNetBase 权重将用于特征提取。
I have tried replacing the input section with `inputs = feature_extractor.layers[-2].input`, which does compile, but the model does not seem to evaluate to the same accuracy as denseBase, even though it is using the same weights (in the simple example above with no extra layers). 我试图用 `inputs = feature_extractor.layers[-2].input` 替换输入部分,这样可以编译,但即使它使用相同的权重(在上面没有额外层的简单示例中),它似乎也没有评估到与 denseBase 相同的准确度。
My goal/question:我的目标/问题:
Thanks!谢谢!
To answer my own question, I did some testing, looking at the values of the initialised weights using the logic from here: 为了回答我自己的问题,我使用此处的逻辑对初始化权重的值进行了一些测试:
It's what's expected. 这是预期的。 DenseBaseClassifier (using denseBase weights) and DenseBaseClassifier (using imagenet weights) both have similar prediction-layer weight initialisations. DenseBaseClassifier(使用 denseBase 权重)和 DenseBaseClassifier(使用 imagenet 权重)都具有相似的预测层权重初始化。 This is because both of these layers are randomly initialised and not trained, while the prediction layer in denseBase has been optimised and hence is different. 这是因为这两个层都是随机初始化的,没有经过训练,而 denseBase 中的预测层已经过优化,因此是不同的。
For the DenseNet section, DenseBaseClassifier (using denseBase) == denseBase (with some noise due to only saving weights), whereas the original imagenet weights are different. 对于 DenseNet 部分,DenseBaseClassifier(使用 denseBase)== denseBase(由于只保存权重而产生一些噪音),而原始 imagenet 权重是不同的。
Using `denseBase_featureextractor = tf.keras.Model(inputs = denseBase.layers[-2].input, outputs = denseBase.layers[-2].output)` does indeed preserve the weights. 使用 `denseBase_featureextractor = tf.keras.Model(inputs = denseBase.layers[-2].input, outputs = denseBase.layers[-2].output)` 确实保留了权重。
I am not sure why `self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output)` doesn't work, though. 不知道为什么 `self.feature_extractor = tf.keras.Model(inputs = tf.keras.Input(shape=(size,size,3)), outputs = feature_extractor.layers[-2].output)` 不起作用。
# Rebuild the trained model and load its weights.
denseBase = DenseNetBase(size, include_top=True)
denseBase.build([(None, size, size, 3), (None, 3)])
denseBase.load_weights('./models/DenseBaseSimple.h5')
# Functional wrapper around denseBase.layers[-2] (its DenseNet201 backbone),
# i.e. denseBase without its final ``predictions`` Dense layer.
denseBase_featureextractor = tf.keras.Model(
    inputs=denseBase.layers[-2].input,
    outputs=denseBase.layers[-2].output,
)
# NOTE(review): DenseClassifier itself takes ``layers[-2]`` of the model it is
# given. For this already-truncated functional model (and for the bare
# DenseNet201 below) that appears to be an inner DenseNet layer rather than
# the whole backbone -- confirm this is the intended extractor.
DenseClassifier_denseBase = DenseClassifier(
    size=size, feature_extractor=denseBase_featureextractor
)
DenseClassifier_denseBase.build([(None, size, size, 3), (None, 3)])

# Baseline for comparison: a classifier built on untouched ImageNet weights.
denseBase_imagenet = tf.keras.applications.DenseNet201(
    weights='imagenet',
    include_top=False,
    pooling='avg',
    input_shape=(size, size, 3),
)
DenseClassifier_imagenet = DenseClassifier(
    size=size, feature_extractor=denseBase_imagenet
)
DenseClassifier_imagenet.build([(None, size, size, 3), (None, 3)])
def get_weights_print_stats(layer, verbose=False):
    """Return ``layer.get_weights()``, optionally printing basic stats.

    Args:
        layer: Any object with a ``get_weights()`` method (e.g. a Keras
            layer or model).
        verbose: If True, print the number of weight arrays and then the
            ``.shape`` of each one. Off by default so existing callers see
            no output, as before.

    Returns:
        Exactly what ``layer.get_weights()`` returned.
    """
    weights = layer.get_weights()
    if verbose:
        print(len(weights))
        for w in weights:
            print(w.shape)
    return weights
def hist_weights(weights, title, bins=500, max_arrays=5):
    """Overlay histograms of the leading weight arrays on the current axes.

    Args:
        weights: Sequence of array-like weight tensors (e.g. the list from
            ``layer.get_weights()``).
        title: Title for the current subplot.
        bins: Number of histogram bins per array.
        max_arrays: How many leading arrays of ``weights`` to plot
            (default 5).
    """
    # Local import: numpy is used here but not imported elsewhere in view.
    import numpy as np

    for weight in weights[:max_arrays]:
        # Flatten each tensor to 1-D so plt.hist treats it as one sample set.
        plt.hist(np.ravel(weight), bins=bins)
    # Set the title once, after plotting, rather than on every iteration.
    plt.title(title)
# Compare weight histograms in a 2x3 grid: top row = final prediction layer
# (layers[1]), bottom row = DenseNet backbone (layers[0]); one column per model.
fig = plt.figure(figsize=(15, 10))
fig.subplots_adjust(hspace=0.4, wspace=0.4)

models_to_compare = [
    (denseBase, "denseBase"),
    (DenseClassifier_denseBase, "DenseBaseClassifier (using denseBase weights)"),
    (DenseClassifier_imagenet, "DenseBaseClassifier (using imagenet weights)"),
]
rows = [
    (1, "Final prediction layer weights"),
    (0, "DenseNet base first 5 weights"),
]
n_cols = len(models_to_compare)
for row, (layer_index, row_label) in enumerate(rows):
    for col, (model, title) in enumerate(models_to_compare):
        W = get_weights_print_stats(model.layers[layer_index])
        # subplot indices are 1-based and run row-major.
        plt.subplot(len(rows), n_cols, row * n_cols + col + 1)
        hist_weights(W, title)
        if col == 0:
            # Label each row once, on its left-most panel.
            plt.ylabel(row_label)
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.