繁体   English   中英

尝试连接两个模型并适合Keras时出现AssertionError

[英]AssertionError while trying to concatenate two models and fit in Keras

我正在尝试开发图像字幕模型。 我指的是这个Github存储库 我有三种方法,它们执行以下操作:

  1. 生成图像模型
  2. 生成字幕模型
  3. 将图像和字幕模型连接在一起

由于代码很长,因此我创建了一个Gist来显示方法

这是我的图像模型和标题模型摘要

但是然后我运行代码,却收到此错误:

TraceTraceback (most recent call last):
  File "trainer.py", line 99, in <module>
    model.fit([images, encoded_captions], one_hot_captions, batch_size = 1, epochs = 5)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 950, in fit
    batch_size=batch_size)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 671, in _standardize_user_data
    self._set_inputs(x)
  File "/Library/Frameworks/Python.framework/Versions/3.6/lib/python3.6/site-packages/keras/engine/training.py", line 575, in _set_inputs
    assert len(inputs) == 1
AssertionError

由于错误来自Keras库,所以我不知道如何调试它。 但是当我尝试将它们连接在一起时出了点问题。

我想知道我是否在这里错过了什么

您需要使用output属性获取模型的output ,然后使用Keras功能API进行连接(通过Concatenate层或其等效功能接口concatenate )并创建最终模型:

from keras.models import Model

image_model = get_image_model()
language_model = get_language_model(vocab_size)

merged = concatenate([image_model.output, language_model.output])
x = LSTM(256, return_sequences = False)(merged)
x = Dense(vocab_size)(x)
out = Activation('softmax')(x)

model = Model([image_model.input, language_model.input], out)
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')
model.fit([images, encoded_captions], one_hot_captions, ...)

就像现在在代码中一样,您还可以为模型创建逻辑定义一个函数:

def get_concatenated_model(image_model, language_model, vocab_size):
    merged = concatenate([image_model.output, language_model.output])
    x = LSTM(256, return_sequences = False)(merged)
    x = Dense(vocab_size)(x)
    out = Activation('softmax')(x)

    model = Model([image_model.input, language_model.input], out)
    return model

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM