How to use multiple inputs in a Keras model
I want to combine four inputs into a single Keras model, but the Concatenate layer requires inputs with matching shapes:
import tensorflow as tf
input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))
x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
x = tf.keras.layers.Dense(2)(x)
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)
This is the output:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
/tmp/ipykernel_3447/2584043467.py in <cell line: 6>()
4 input4 = tf.keras.layers.Input(shape=(1,))
5
----> 6 x = tf.keras.layers.Concatenate(axis=1)([input1, input2, input3, input4])
7
8 x = tf.keras.layers.Dense(2)(x)
/usr/local/lib/python3.8/site-packages/keras/utils/traceback_utils.py in error_handler(*args, **kwargs)
65 except Exception as e: # pylint: disable=broad-except
66 filtered_tb = _process_traceback_frames(e.__traceback__)
---> 67 raise e.with_traceback(filtered_tb) from None
68 finally:
69 del filtered_tb
/usr/local/lib/python3.8/site-packages/keras/layers/merging/concatenate.py in build(self, input_shape)
112 ranks = set(len(shape) for shape in shape_set)
113 if len(ranks) != 1:
--> 114 raise ValueError(err_msg)
115 # Get the only rank for the set.
116 (rank,) = ranks
ValueError: A `Concatenate` layer requires inputs with matching shapes except for the concatenation axis. Received: input_shape=[(None, 28, 28, 1), (None, 28, 28, 3), (None, 128), (None, 1)]
How can I combine the above inputs in a single model?
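The error message tells you exactly what is wrong: all dimensions except the concatenation axis must match, and here they don't. input1 and input2 have rank 4, (None, 28, 28, C), while input3 and input4 have rank 2, (None, k), so they cannot be concatenated directly. Flatten the image inputs to rank 2 first, then concatenate along the last axis. You could try something like this: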
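import tensorflow as tf
# Four inputs with different ranks
input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))
# Flatten the image inputs to rank 2: (None, 784) and (None, 2352)
# Use new names so the original Input tensors stay available for Model()
flat1 = tf.keras.layers.Flatten()(input1)
flat2 = tf.keras.layers.Flatten()(input2)
# All tensors are now (None, k), so they can be joined along the feature axis
x = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])
x = tf.keras.layers.Dense(2)(x)
# Pass the original Input tensors, not the flattened ones
model = tf.keras.models.Model(inputs=[input1, input2, input3, input4], outputs=x)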
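To show how such a model is actually fed, here is a minimal sketch that rebuilds the flatten-then-concatenate model from the answer (keeping the original Input tensors as the model inputs) and runs fit and predict on it. The data, the batch size of 8, the MSE loss and the Adam optimizer are all assumptions chosen only for illustration; the key point is that the four arrays are passed as a list in the same order as inputs=[...].

```python
import numpy as np
import tensorflow as tf

# Rebuild the model: flatten the image inputs, then concatenate on the last axis.
input1 = tf.keras.layers.Input(shape=(28, 28, 1))
input2 = tf.keras.layers.Input(shape=(28, 28, 3))
input3 = tf.keras.layers.Input(shape=(128,))
input4 = tf.keras.layers.Input(shape=(1,))
flat1 = tf.keras.layers.Flatten()(input1)  # (None, 784)
flat2 = tf.keras.layers.Flatten()(input2)  # (None, 2352)
merged = tf.keras.layers.Concatenate(axis=-1)([flat1, flat2, input3, input4])
outputs = tf.keras.layers.Dense(2)(merged)
model = tf.keras.Model(inputs=[input1, input2, input3, input4], outputs=outputs)

# 784 + 2352 + 128 + 1 = 3265 features after concatenation
print(merged.shape[-1])

# Hypothetical dummy data: 8 samples, regression targets of size 2
n = 8
rng = np.random.default_rng(0)
x1 = rng.random((n, 28, 28, 1), dtype=np.float32)
x2 = rng.random((n, 28, 28, 3), dtype=np.float32)
x3 = rng.random((n, 128), dtype=np.float32)
x4 = rng.random((n, 1), dtype=np.float32)
y = rng.random((n, 2), dtype=np.float32)

model.compile(optimizer="adam", loss="mse")
# The input arrays go in a list, in the same order as inputs=[...]
model.fit([x1, x2, x3, x4], y, epochs=1, batch_size=4, verbose=0)
preds = model.predict([x1, x2, x3, x4], verbose=0)
print(preds.shape)  # (8, 2)
```

Flattening discards the spatial structure of the images; if that matters for your task, a common alternative is to run each image input through its own convolutional branch and flatten only at the end, before the Concatenate layer.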