Method to transfer weights between nested keras models
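I am trying to build a hybrid model step by step, adding submodels iteratively.
I first build and train a simple model. Then I build a slightly more complex model that contains the whole original model plus extra layers. I want to move the trained weights from the first model into the new one. How can I do this? The first model is nested inside the second.
Here is a dummy MWE: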
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model, Input, backend

# data: 100 samples with one feature each
x = np.random.normal(size=(100, 1))
y = np.sin(x) + np.random.normal(size=(100, 1))

# model 1
def make_model_1():
    inp = Input(shape=(1,))
    l1 = Dense(5, activation='relu')(inp)
    out1 = Dense(1)(l1)
    model1 = Model(inp, out1)
    return model1

model1 = make_model_1()
model1.compile(optimizer=tf.keras.optimizers.SGD(),
               loss=tf.keras.losses.mean_squared_error)
model1.fit(x, y, epochs=3, batch_size=10)

# make model 2: contains model 1's layers plus a second branch
def make_model_2():
    inp = Input(shape=(1,))
    l1 = Dense(5, activation='relu')(inp)
    out1 = Dense(1)(l1)
    l2 = Dense(15, activation='sigmoid')(inp)
    out2 = Dense(1)(l2)
    # stack the two branch outputs and combine them with a final Dense(1)
    bucket = tf.stack([out1, out2], axis=2)
    out = backend.squeeze(Dense(1)(bucket), axis=2)
    model2 = Model(inp, out)
    return model2

model2 = make_model_2()
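How can I transfer the weights from model1 to model2 automatically, in a way that is completely agnostic to the nature of the two models apart from the fact that they are nested?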
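You can simply load the trained weights into the specific part of the new model you are interested in. I do this by creating a new instance of model1 inside model2 and then loading the trained weights into that instance.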
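The answer relies on `get_weights()` / `set_weights()` copying values between two models with identical architecture. The following is a minimal sketch of that step on its own. It assumes both instances come from the same `make_model_1` factory; `trained` and `fresh` are illustrative names, and here `trained` is untrained and only stands in for the fitted model1.

```python
import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense


def make_model_1():
    # Same architecture as model1 in the question: 1 -> Dense(5, relu) -> Dense(1)
    inp = Input(shape=(1,))
    hidden = Dense(5, activation="relu")(inp)
    out = Dense(1)(hidden)
    return Model(inp, out)


trained = make_model_1()  # stands in for the already-trained model1
fresh = make_model_1()    # a second, independently initialised instance

# get_weights() returns the kernels and biases as a list of NumPy arrays;
# set_weights() writes such a list into a model with the same layer shapes.
fresh.set_weights(trained.get_weights())
copied_ok = all(
    np.array_equal(a, b)
    for a, b in zip(trained.get_weights(), fresh.get_weights())
)

# The values are copied, not shared: overwriting `fresh` leaves `trained` untouched.
fresh.set_weights([np.zeros_like(w) for w in fresh.get_weights()])
still_independent = any(np.any(w != 0) for w in trained.get_weights())
```

Because the arrays are copied, the nested copy can be fine-tuned inside model2 without changing the original model1.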
這里是完整的例子
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model, Input

# data: 100 samples with one feature each
x = np.random.normal(size=(100, 1))
y = np.sin(x) + np.random.normal(size=(100, 1))

# model 1
def make_model_1():
    inp = Input(shape=(1,))
    l1 = Dense(5, activation='relu')(inp)
    out1 = Dense(1)(l1)
    model1 = Model(inp, out1)
    return model1

model1 = make_model_1()
model1.compile(optimizer=tf.keras.optimizers.SGD(),
               loss=tf.keras.losses.mean_squared_error)
model1.fit(x, y, epochs=3, batch_size=10)

# make model 2
def make_model_2(trained_model):
    inp = Input(shape=(1,))
    # fresh instance of model 1, initialised with the trained weights
    m = make_model_1()
    m.set_weights(trained_model.get_weights())
    # nest it: m is called like a single layer inside model 2
    out1 = m(inp)
    l2 = Dense(15, activation='sigmoid')(inp)
    out2 = Dense(1)(l2)
    bucket = tf.stack([out1, out2], axis=2)
    out = tf.keras.backend.squeeze(Dense(1)(bucket), axis=2)
    model2 = Model(inp, out)
    return model2

model2 = make_model_2(model1)
model2.summary()
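If you would rather not pass the trained model into the factory, here is one possible variant. It is a sketch and not part of the original answer. Give model1 an explicit name, build model2 first, and then look up the nested submodel with `get_layer` and copy the weights into it. The name `"base"` is a hypothetical choice. `Concatenate` followed by `Dense(1)` is swapped in for the `tf.stack` / `squeeze` pair from the question; it computes the same weighted combination of the two branch outputs. An untrained `model1` stands in for the fitted one.

```python
import numpy as np
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Concatenate, Dense


def make_model_1(name="base"):
    # model1 from the question, with an explicit (hypothetical) name so the
    # submodel can be found again after it is nested in another model
    inp = Input(shape=(1,))
    hidden = Dense(5, activation="relu")(inp)
    out = Dense(1)(hidden)
    return Model(inp, out, name=name)


def make_model_2():
    inp = Input(shape=(1,))
    out1 = make_model_1()(inp)  # nested submodel, shows up as one layer
    l2 = Dense(15, activation="sigmoid")(inp)
    out2 = Dense(1)(l2)
    # Concatenate + Dense(1): w1 * out1 + w2 * out2 + b, as in the question
    merged = Concatenate()([out1, out2])
    out = Dense(1)(merged)
    return Model(inp, out)


model1 = make_model_1()  # stands in for the trained model1
model2 = make_model_2()

# Copy the trained weights into the nested submodel after model2 is built.
model2.get_layer("base").set_weights(model1.get_weights())
```

`get_layer` only searches the top-level layers, so deeper nesting requires chained lookups (for example `outer.get_layer("a").get_layer("b")`). The copied arrays must also match the submodel's weight shapes exactly.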