How to use multiple inputs in a Keras model
I want to combine four inputs into a single Keras model, but the Concatenate layer requires inputs with matching shapes:
import tensorflow as tf
input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))
x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
x = tf.keras.layers.Dense(2)(x)
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)
This is the output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_3447/2584043467.py in <cell line: 6>()
4 input4 = tf.keras.layers.Input(shape=(1,))
5
----> 6 x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
7
8 x = tf.keras.layers.Dense(2)(x)
/usr/local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.8/site-packages/keras/layers/merging/concatenate.py in build(self, input_shape)
112 ranks = set(len(shape) for shape in shape_set)
113 if len(ranks) != 1:
--> 114 raise ValueError(err_msg)
115 # Get the only rank for the set.
116 (rank,) = ranks
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]
How can I combine the above inputs in a single model?
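The error message is telling you exactly what the problem is: all dimensions except the concatenation axis must match, and here they don't. input1 and input2 are rank 4, (batch, 28, 28, channels), while input3 and input4 are rank 2, (batch, features). Flatten the image inputs to rank 2 first, then concatenate everything along the last axis. You could try something like this: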
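import tensorflow as tf

input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))

# Flatten the image inputs to rank 2 so every tensor is (batch, features).
# Use new names: the original Input tensors are still needed for Model(inputs=...).
flat1 = tf.keras.layers.Flatten()(input1)  # (None, 784)
flat2 = tf.keras.layers.Flatten()(input2)  # (None, 2352)

# All tensors now differ only in the last axis, so they can be concatenated there.
x = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])  # (None, 3265)
x = tf.keras.layers.Dense(2)(x)

# Pass the original Input tensors, not the flattened ones.
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)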
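To see why flattening fixes the error, the shape rule can be checked by hand. The following is a pure-Python sketch that only imitates the check described in the error message (same rank, all dimensions equal except the concatenation axis); it is not Keras's actual implementation, and the helper names flatten_shape and concat_output_shape are hypothetical.

```python
from math import prod


def flatten_shape(shape):
    """Imitate Flatten: keep the batch dimension, collapse the rest into one."""
    batch, *rest = shape
    return (batch, prod(rest))


def concat_output_shape(shapes, axis=-1):
    """Imitate the Concatenate shape check (hypothetical helper, not Keras code)."""
    ranks = {len(s) for s in shapes}
    if len(ranks) != 1:
        raise ValueError(f"inputs have different ranks: {shapes}")
    rank = ranks.pop()
    axis = axis % rank  # turn -1 into the last index
    for i in range(rank):
        if i == axis:
            continue
        # None (unknown batch size) is treated as compatible with anything.
        dims = {s[i] for s in shapes if s[i] is not None}
        if len(dims) > 1:
            raise ValueError(f"dimension {i} does not match: {shapes}")
    out = list(shapes[0])
    out[axis] = sum(s[axis] for s in shapes)  # sizes add up along the concat axis
    return tuple(out)


original = [(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]
try:
    concat_output_shape(original)
except ValueError as e:
    print("fails:", e)  # ranks 4 and 2 are mixed

fixed = [flatten_shape(original[0]), flatten_shape(original[1]), original[2], original[3]]
print(concat_output_shape(fixed))  # (None, 3265)
```

After flattening, the four shapes are (None, 784), (None, 2352), (None, 128) and (None, 1), so the concatenated feature size is 784 + 2352 + 128 + 1 = 3265, which is what the Dense layer receives.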