
How to use multiple inputs in a Keras model

I want to combine four inputs into a single Keras model, but Concatenate requires inputs with matching shapes:

import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])

x = tf.keras.layers.Dense(2)(x)

model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)

This is the output:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/tmp/ipykernel_3447/2584043467.py in <cell line: 6>()
      4 input4 = tf.keras.layers.Input(shape=(1,))
      5 
----> 6 x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
      7 
      8 x = tf.keras.layers.Dense(2)(x)

/usr/local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
     65     except Exception as e:  # pylint: disable=broad-except
     66       filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67       raise e.with_traceback(filtered_tb) from None
     68     finally:
     69       del filtered_tb

/usr/local/lib/python3.8/site-packages/keras/layers/merging/concatenate.py in build(self, input_shape)
    112       ranks = set(len(shape) for shape in shape_set)
    113       if len(ranks) != 1:
--> 114         raise ValueError(err_msg)
    115       # Get the only rank for the set.
    116       (rank,) = ranks

ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]
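To see why the call fails, here is a simplified re-implementation of the shape rule stated in the error message. The helper concat_output_shape is hypothetical and is not Keras' actual code. Like the check visible in the traceback, it ignores unknown (None) sizes such as the batch axis.

```python
def concat_output_shape(shapes, axis=-1):
    """Hypothetical helper: return the shape produced by concatenating
    tensors of the given shapes along `axis`, or raise ValueError if the
    shapes break the rule "same rank, and equal sizes on every axis except
    the concatenation axis". Unknown sizes (None) are not compared."""
    ranks = {len(shape) for shape in shapes}
    if len(ranks) != 1:
        raise ValueError(f"rank mismatch: {sorted(ranks)}")
    rank = ranks.pop()
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} out of range for rank {rank}")
    axis %= rank  # normalise negative axes, e.g. -1 -> rank - 1

    # Every axis except the concatenation axis must agree on its known size
    for dim in range(rank):
        if dim == axis:
            continue
        known = {shape[dim] for shape in shapes if shape[dim] is not None}
        if len(known) > 1:
            raise ValueError(f"size mismatch on axis {dim}: {sorted(known)}")

    # Non-concatenation sizes are taken from the first shape; the
    # concatenated size is unknown if any input size on that axis is unknown
    out = list(shapes[0])
    sizes = [shape[axis] for shape in shapes]
    out[axis] = None if None in sizes else sum(sizes)
    return tuple(out)


# Shapes from the error message above (batch axis first)
question_shapes = [(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]

for axis in (1, -1):
    try:
        concat_output_shape(question_shapes, axis=axis)
    except ValueError as err:
        print(f"axis={axis}: {err}")
# axis=1: rank mismatch: [2, 4]
# axis=-1: rank mismatch: [2, 4]
```

Because two inputs are rank 4 and two are rank 2, changing the axis does not help: the inputs have to be brought to a common rank before they can be concatenated.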

How can I combine the above inputs in a single model?

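The error message is actually telling you what the problem is: all dimensions except the one being concatenated must match, and yours do not. Two of your inputs are rank 4 (images) and two are rank 2 (vectors), so no choice of axis can work. Flattening the image inputs makes every input rank 2, (batch, features), after which all four can be joined along the last axis. You could try something like this: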

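import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

# Flatten the image inputs to rank 2 so every tensor is (batch, features):
# (None, 28, 28, 1) -> (None, 784), (None, 28, 28, 3) -> (None, 2352)
flat1 = tf.keras.layers.Flatten()(input1)
flat2 = tf.keras.layers.Flatten()(input2)

# Only the last axis differs now, so concatenate along it: (None, 3265)
x = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])

x = tf.keras.layers.Dense(2)(x)

# Use new names for the flattened tensors so the model's inputs are still
# the original Input layers, not the Flatten outputs
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)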
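As a quick check of the fixed code's shapes without building the model, here is a NumPy sketch with dummy data (the batch size of 4 and the random values are arbitrary assumptions). It relies on two equivalences: Flatten keeps the batch axis and collapses the rest, like reshape(batch, -1), and np.concatenate enforces the same "all dimensions except the axis must match" rule as Concatenate. The final matrix product stands in for Dense(2), whose default activation is linear.

```python
import numpy as np

batch = 4  # arbitrary batch size for the dummy data
rng = np.random.default_rng(0)

# Dummy arrays with the same per-sample shapes as the four Input layers
img1 = rng.random((batch, 28, 28, 1))
img2 = rng.random((batch, 28, 28, 3))
vec3 = rng.random((batch, 128))
vec4 = rng.random((batch, 1))

# Flatten() keeps the batch axis and collapses the rest: reshape(batch, -1)
flat1 = img1.reshape(batch, -1)  # (4, 784)
flat2 = img2.reshape(batch, -1)  # (4, 2352)

# All four arrays are now rank 2 and agree on axis 0, so joining on axis -1 works
x = np.concatenate([flat1, flat2, vec3, vec4], axis=-1)
print(x.shape)  # (4, 3265)

# Dense(2) with its default linear activation computes x @ W + b
W = rng.random((x.shape[-1], 2))
b = np.zeros(2)
out = x @ W + b
print(out.shape)  # (4, 2)
```

With the real model, the four arrays are passed as a list in the same order as inputs=[input1, input2, input3, input4], for example model.predict([img1, img2, vec3, vec4]).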

