How to use multiple inputs in a Keras model

I want to combine four inputs into a single Keras model, but concatenating them fails because the layer requires inputs with matching shapes:

import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])

x = tf.keras.layers.Dense(2)(x)

model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)

Here is the output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3447/2584043467.py in <cell line: 6>()
      4 input4 = tf.keras.layers.Input(shape=(1,))
      5 
----> 6 x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
      7 
      8 x = tf.keras.layers.Dense(2)(x)

/usr/local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.8/site-packages/keras/layers/merging/concatenate.py in build(self, input_shape)
    112       ranks = set(len(shape) for shape in shape_set)
    113       if len(ranks) != 1:
--> 114         raise ValueError(err_msg)
    115       # Get the only rank for the set.
    116       (rank,) = ranks

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]

How can I combine the above inputs in a single model?

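The error message tells you what the problem is: every dimension except the concatenation axis has to match, and here they do not. The two image inputs are 4-D tensors of shape (batch, 28, 28, channels), while the other two are 2-D tensors of shape (batch, features), so even the ranks differ. One option is to flatten the image inputs so that every tensor has shape (batch, features), then concatenate along the last axis: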

import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

# Flatten the image inputs into new tensors. Do not overwrite input1/input2:
# the Model below must still receive the original Input tensors.
flat1 = tf.keras.layers.Flatten()(input1)  # (None, 784)
flat2 = tf.keras.layers.Flatten()(input2)  # (None, 2352)

# All four tensors are now 2-D, so they can be joined on the feature axis.
x = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])

x = tf.keras.layers.Dense(2)(x)

model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)
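
To check that a model built this way really accepts the four original input shapes, you can push a batch of random data through it. This is only a minimal sketch: the helper name build_model, the variable names, and the batch size of 4 are assumptions for illustration and are not part of the question.

```python
import numpy as np
import tensorflow as tf


def build_model():
    # Four inputs with the shapes from the question.
    image_gray = tf.keras.layers.Input(shape=(28, 28, 1))
    image_rgb = tf.keras.layers.Input(shape=(28, 28, 3))
    features = tf.keras.layers.Input(shape=(128,))
    scalar = tf.keras.layers.Input(shape=(1,))

    # Flatten the images into separate tensors so the Input tensors stay intact.
    flat_gray = tf.keras.layers.Flatten()(image_gray)  # (None, 784)
    flat_rgb = tf.keras.layers.Flatten()(image_rgb)    # (None, 2352)

    # 784 + 2352 + 128 + 1 = 3265 features after concatenation.
    merged = tf.keras.layers.Concatenate(axis=-1)(
        [flat_gray, flat_rgb, features, scalar]
    )
    outputs = tf.keras.layers.Dense(2)(merged)

    return tf.keras.Model(
        inputs=[image_gray, image_rgb, features, scalar], outputs=outputs
    )


model = build_model()

# Random placeholder data, one array per input, in the same order as `inputs`.
batch_size = 4  # arbitrary value chosen for this example
dummy_batch = [
    np.random.rand(batch_size, 28, 28, 1).astype("float32"),
    np.random.rand(batch_size, 28, 28, 3).astype("float32"),
    np.random.rand(batch_size, 128).astype("float32"),
    np.random.rand(batch_size, 1).astype("float32"),
]

predictions = model.predict(dummy_batch, verbose=0)
print(predictions.shape)  # (4, 2)
```

Flattening discards the spatial layout of the images. If that layout matters for your task, you could instead run each image input through its own convolutional branch and flatten or pool the result before concatenating.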
