How to use multiple inputs in the keras model

I want to combine four inputs into a single Keras model, but Concatenate requires inputs with matching shapes:

import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])

x = tf.keras.layers.Dense(2)(x)

model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)

Here is the output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3447/2584043467.py in <cell line: 6>()
      4 input4 = tf.keras.layers.Input(shape=(1,))
      5 
----> 6 x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
      7 
      8 x = tf.keras.layers.Dense(2)(x)

/usr/local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.8/site-packages/keras/layers/merging/concatenate.py in build(self, input_shape)
    112       ranks = set(len(shape) for shape in shape_set)
    113       if len(ranks) != 1:
--> 114         raise ValueError(err_msg)
    115       # Get the only rank for the set.
    116       (rank,) = ranks

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]
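The rule behind this message can be sketched in plain Python. The helper concat_output_shape below is hypothetical (it is not a Keras function) and only approximates the check Keras runs when building a Concatenate layer. Every input must have the same rank, and every dimension other than the concatenation axis must agree, with None (an unknown size such as the batch dimension) matching anything. It also shows that axis=1 is wrong for the two image inputs even on its own.

```python
# Hypothetical helper (not a Keras API): approximates the shape check that
# tf.keras.layers.Concatenate performs, using plain tuples of ints/None.
def concat_output_shape(shapes, axis=-1):
    """Return the concatenated shape, or raise ValueError on a mismatch."""
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError(f"inputs have different ranks: {sorted(ranks)}")
    (rank,) = ranks
    axis %= rank  # turn a negative axis such as -1 into rank - 1
    out = []
    for i in range(rank):
        dims = [s[i] for s in shapes]
        if i == axis:
            # Sizes along the concatenation axis add up (unknown if any is None).
            out.append(None if None in dims else sum(dims))
            continue
        # Every other dimension must agree; None (e.g. batch size) matches anything.
        known = {d for d in dims if d is not None}
        if len(known) > 1:
            raise ValueError(f"dimension {i} does not match: {sorted(known)}")
        out.append(known.pop() if known else None)
    return tuple(out)


# The shapes from the error message: ranks 4 and 2 cannot be concatenated.
try:
    concat_output_shape([(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)])
except ValueError as err:
    print(err)  # inputs have different ranks: [2, 4]

# Even the two image inputs alone fail on axis=1, because the channel axis differs.
try:
    concat_output_shape([(None, 28, 28, 1), (None, 28, 28, 3)], axis=1)
except ValueError as err:
    print(err)  # dimension 3 does not match: [1, 3]

# After flattening, everything is rank 2 and the last axis simply adds up.
print(concat_output_shape([(None, 784), (None, 2352), (None, 128), (None, 1)]))  # (None, 3265)
```

Concatenating the images along the last axis (axis=-1) would work, giving (None, 28, 28, 4), but the rank-2 inputs still cannot join them without first changing the ranks.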

How can I combine the above inputs in a single model?

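The error message is actually telling you what the problem is: all dimensions except the one you concatenate along must be the same, and here they are not. The two image inputs are rank 4, (None, 28, 28, C), while the other two inputs are rank 2, so Keras cannot line them up. Flatten the image inputs first so that every tensor has the shape (None, features), then concatenate along the last axis. You can try something like this: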

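import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

# Flatten the rank-4 image inputs to rank 2: (None, 784) and (None, 2352).
# Store the results under new names so the original Input tensors can still
# be passed to Model(inputs=...) below.
flat1 = tf.keras.layers.Flatten()(input1)
flat2 = tf.keras.layers.Flatten()(input2)

# Every tensor is now (None, features), so concatenate along the last axis.
x = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])  # (None, 3265)

x = tf.keras.layers.Dense(2)(x)

model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)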
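To check that the corrected model builds and accepts data, here is a minimal sketch that feeds it one random dummy batch. It assumes TensorFlow 2.x (tf.keras) and NumPy are installed; the batch size of 4 and the random arrays are placeholders, not real data. The arrays are passed as a list in the same order as the inputs argument of Model.

```python
import numpy as np
import tensorflow as tf

# Same architecture as the answer: flatten the images, keep the Inputs as model inputs.
input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))
flat1 = tf.keras.layers.Flatten()(input1)
flat2 = tf.keras.layers.Flatten()(input2)
merged = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])
outputs = tf.keras.layers.Dense(2)(merged)
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=outputs)

# Placeholder batch: one array per Input, each with a leading batch axis.
batch = 4
x1 = np.random.rand(batch, 28, 28, 1).astype("float32")
x2 = np.random.rand(batch, 28, 28, 3).astype("float32")
x3 = np.random.rand(batch, 128).astype("float32")
x4 = np.random.rand(batch, 1).astype("float32")

preds = model.predict([x1, x2, x3, x4], verbose=0)
print(merged.shape[-1])  # 3265 = 784 + 2352 + 128 + 1
print(preds.shape)       # (4, 2)
```

Flattening is the simplest way to make the ranks match, but it discards the spatial structure of the images. For real image data you may prefer to pass each image input through its own convolutional layers (for example ending in GlobalAveragePooling2D) and concatenate the resulting rank-2 feature vectors instead.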

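Statement: the technical posts on this site are published under the CC BY-SA 4.0 license. If you repost them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.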

© 2020-2024 STACKOOM.COM