ValueError building a neural network with 2 outputs in Keras
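I am trying to build a network with a single input X (a 2D matrix of size Xa*Xb) and 2 outputs, Y1 and Y2 (both 1D). Even though this is not the case in the code I posted below, Y1 is meant to be a classifier that outputs a one-hot vector, and Y2 is meant for regression (the original code raised the same error).

When training the network, I get the following error: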
```
ValueError: Shapes (None, None) and (None, 17, 29) are incompatible
```
Apparently, `(None, 17, 29)` corresponds to `(None, size_Xa, size_Y1)`, and I don't understand why Xa and Y1 should be related in the first place (independently of Xb).
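The 17 comes from how `Dense` treats inputs of rank greater than 2: per the Keras documentation, it computes the dot product between the inputs and its kernel along the last axis only, so the Xa axis passes through both `Dense` layers untouched. A minimal NumPy sketch of that behavior, assuming random kernels and zero biases in place of the trained Keras weights (the `softmax` helper is written here for illustration, it is not a Keras function):

```python
import numpy as np

# Sizes copied from the question's script.
batch_size, size_Xa, size_Xb, size_Y1 = 5, 17, 13, 29
rng = np.random.default_rng(0)

x = rng.random((batch_size, size_Xa, size_Xb))  # like x_batch: (5, 17, 13)
y1 = rng.random((batch_size, size_Y1))          # like y_batch['output1']: (5, 29)

# Hypothetical random kernels; biases are taken as zero for brevity.
w_common = rng.random((size_Xb, 128))  # kernel rows = size of the LAST input axis
w_out1 = rng.random((128, size_Y1))

def softmax(z):
    """Softmax over the last axis, like Keras' softmax activation by default."""
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

# A matrix product with a rank-3 input contracts only the last axis,
# so the Xa axis (17) is carried through both layers unchanged.
common = np.maximum(x @ w_common, 0.0)  # Dense(128, activation='relu')
out1 = softmax(common @ w_out1)         # Dense(size_Y1, activation='softmax')

print(common.shape)  # (5, 17, 128)
print(out1.shape)    # (5, 17, 29), but y1.shape is (5, 29)
```

So the prediction for `output1` has one value per class for each of the 17 rows of X, while the target has one value per class for the whole sample.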
Here is my code. I tried to reduce it to the bare minimum so it is easier to understand.
```python
import numpy as np
from keras.layers import Dense, LSTM, Input
from keras.models import Model


def dataGenerator():
    while True:
        yield makeBatch()


def makeBatch():
    """generates a batch of artificial training data"""
    x_batch, y_batch = [], {}
    x_batch = np.random.rand(batch_size, size_Xa, size_Xb)
    #x_batch = np.random.rand(batch_size, size_Xa)
    y_batch['output1'] = np.random.rand(batch_size, size_Y1)
    y_batch['output2'] = np.random.rand(batch_size, size_Y2)
    return x_batch, y_batch


def generate_model():
    input_layer = Input(shape=(size_Xa, size_Xb))
    #input_layer = Input(shape=(size_Xa,))
    common_branch = Dense(128, activation='relu')(input_layer)
    branch_1 = Dense(size_Y1, activation='softmax', name='output1')(common_branch)
    branch_2 = Dense(size_Y2, activation='relu', name='output2')(common_branch)
    model = Model(inputs=input_layer, outputs=[branch_1, branch_2])
    losses = {"output1": "categorical_crossentropy", "output2": "mean_absolute_error"}
    model.compile(optimizer="adam",
                  loss=losses,
                  metrics=['accuracy'])
    return model


batch_size = 5
size_Xa = 17
size_Xb = 13
size_Y2 = 100
size_Y1 = 29

model = generate_model()
model.fit(x=dataGenerator(),
          steps_per_epoch=50,
          epochs=15,
          validation_data=dataGenerator(), validation_steps=50, verbose=1)
```
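If I uncomment the 2 commented-out lines in makeBatch and generate_model, the error goes away. So it works if the input X is 1-dimensional, but when I change it to 2 dimensions (keeping everything else the same), the error appears.

Is this related to the architecture with 2 outputs? I think I am missing something here; any help is welcome.

I have added the full error log for reference: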
```
Epoch 1/15
Traceback (most recent call last):
  File "neuralnet_minimal.py", line 41, in <module>
    model.fit( x=dataGenerator(),
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:571 train_function *
        outputs = self.distribute_strategy.run(
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:532 train_step **
        loss = self.compiled_loss(
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_utils.py:205 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:143 __call__
        losses = self.call(y_true, y_pred)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:246 call
        return self.fn(y_true, y_pred, **self._fn_kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:1527 categorical_crossentropy
        return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/backend.py:4561 categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))
    ValueError: Shapes (None, None) and (None, 17, 29) are incompatible
```
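The last frames show where the check happens: Keras' `categorical_crossentropy` calls `target.shape.assert_is_compatible_with(output.shape)`. TensorFlow treats two shapes as compatible only when they have the same rank and every pair of dimensions is either equal or unknown (`None`); a shape of unknown rank is compatible with anything. So the rank-2 target `(None, None)` can never match the rank-3 prediction `(None, 17, 29)`, whatever the sizes are. A minimal pure-Python sketch of that rule (`is_compatible` is a hypothetical helper, not TensorFlow's implementation):

```python
def is_compatible(shape_a, shape_b):
    """Sketch of TensorShape compatibility: same rank, and each pair of
    dimensions is equal or at least one of them is unknown (None).
    A shape given as None stands for "unknown rank"."""
    if shape_a is None or shape_b is None:
        return True
    if len(shape_a) != len(shape_b):
        return False
    return all(a is None or b is None or a == b for a, b in zip(shape_a, shape_b))


target = (None, None)        # y_true for output1, as reported in the error
prediction = (None, 17, 29)  # y_pred produced by the 'output1' Dense layer

print(is_compatible(target, prediction))  # False: rank 2 vs rank 3
print(is_compatible(target, (None, 29)))  # True: what a (batch, 29) output would give
```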
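Oddly, when I add a `Flatten()` layer before the network splits, the error disappears. It has to do with the shape of the network, but I still don't understand the real reason behind it.

I will mark this as the correct answer since it solves the problem, unless someone else posts something. If this is not the right approach, please let me know.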
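This is consistent with the shape rules behind the error: `Flatten` reshapes each (17, 13) sample into a vector of 221 values, so `Dense(size_Y1)` then produces `(None, 29)`, the same rank and size as the targets. In the model this would look something like `common_branch = Dense(128, activation='relu')(Flatten()(input_layer))`, with `Flatten` added to the `keras.layers` import. A NumPy sketch of the resulting shapes, again assuming random kernels and zero biases rather than real Keras weights:

```python
import numpy as np

batch_size, size_Xa, size_Xb, size_Y1 = 5, 17, 13, 29
rng = np.random.default_rng(0)

x = rng.random((batch_size, size_Xa, size_Xb))  # (5, 17, 13), like x_batch
y1 = rng.random((batch_size, size_Y1))          # (5, 29), like y_batch['output1']

# Flatten keeps the batch axis and merges the rest: (5, 17, 13) -> (5, 221).
flat = x.reshape(x.shape[0], -1)

# Hypothetical random kernels, zero biases; the common layer now sees 17 * 13 = 221 features.
w_common = rng.random((size_Xa * size_Xb, 128))
w_out1 = rng.random((128, size_Y1))

common = np.maximum(flat @ w_common, 0.0)  # Dense(128, activation='relu')
logits = common @ w_out1                   # Dense(size_Y1) before its softmax
out1 = np.exp(logits - logits.max(axis=-1, keepdims=True))
out1 /= out1.sum(axis=-1, keepdims=True)   # softmax over the 29 classes

print(flat.shape)  # (5, 221)
print(out1.shape)  # (5, 29), the same shape as y1
```

`Flatten` is the simplest way to remove the extra axis. If the Xa axis is really a sequence (for example, time steps), a layer that summarizes it, such as `LSTM` (already imported in the script) or `GlobalAveragePooling1D`, is a common alternative; which one fits depends on what the data represents.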