繁体   English   中英

如何将两个 keras 型号连接到一个 model 中?

[英]How do I connect two keras models into one model?

Let's say I have a ResNet50 model and I wish to connect the output layer of this model to the input layer of a VGG model.

这是 ResNet50 的 ResNet model 和 output 张量:

img_shape = (164, 164, 3)
resnet50_model = ResNet50(include_top=False, input_shape=img_shape, weights = None)

print(resnet50_model.output.shape)

我得到 output:

TensorShape([Dimension(None), Dimension(6), Dimension(6), Dimension(2048)])

现在我想要一个新层,我将这个 output 张量重塑为 (64,64,18)

然后我有一个 VGG16 model:

VGG_model = VGG_model = VGG16(include_top=False, weights=None)

我希望 ResNet50 的 output 重塑为所需的张量并作为 VGG model 的输入输入。 所以本质上我想连接两个模型。 有人可以帮我这样做吗? 谢谢!

有多种方法可以做到这一点。 这是使用顺序 model API 执行此操作的一种方法。

import tensorflow as tf
from tensorflow.keras.applications import ResNet50, VGG16

model = tf.keras.Sequential()
img_shape = (164, 164, 3)
model.add(ResNet50(include_top=False, input_shape=img_shape, weights = None))

model.add(tf.keras.layers.Reshape(target_shape=(64,64,18)))
model.add(tf.keras.layers.Conv2D(3,kernel_size=(3,3),name='Conv2d'))

VGG_model = VGG16(include_top=False, weights=None)
model.add(VGG_model)

model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

model.summary()

Model总结如下

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
resnet50 (Model)             (None, 6, 6, 2048)        23587712  
_________________________________________________________________
reshape (Reshape)            (None, 64, 64, 18)        0         
_________________________________________________________________
Conv2d (Conv2D)              (None, 62, 62, 3)         489       
_________________________________________________________________
vgg16 (Model)                multiple                  14714688  
=================================================================
Total params: 38,302,889
Trainable params: 38,249,769
Non-trainable params: 53,120
_________________________________________________________________

完整代码在这里

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM