
ValueError building a neural network with 2 outputs in Keras

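I am trying to build a network with a single input X (a 2-dimensional matrix of size Xa*Xb) and 2 outputs Y1 and Y2 (both 1-dimensional). Although it is not the case in the code posted below, Y1 is supposed to be a classifier that outputs a one-hot vector and Y2 is supposed to be used for regression (the original code raises the same error).

When training the network, I get the following error: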

ValueError: Shapes (None, None) and (None, 17, 29) are incompatible

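Apparently, (None, 17, 29) translates to (None, size_Xa, size_Y1), and I don't understand why Xa and Y1 should be related in the first place (independently of Xb).

Here is my code. I have tried to reduce it to the minimum so that it is easier to understand.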

import numpy as np
from keras.layers import Dense, LSTM, Input
from keras.models import Model

def dataGenerator():
    while True:
        yield makeBatch()
def makeBatch():
    """generates a batch of artificial training data"""
    x_batch, y_batch = [], {}
    x_batch = np.random.rand(batch_size, size_Xa, size_Xb)
    #x_batch = np.random.rand(batch_size, size_Xa)
    y_batch['output1'] = np.random.rand(batch_size, size_Y1)
    y_batch['output2'] = np.random.rand(batch_size, size_Y2)
    return x_batch, y_batch

def generate_model():
    input_layer = Input(shape=(size_Xa, size_Xb))
    #input_layer = Input(shape=(size_Xa))
    common_branch = Dense(128, activation='relu')(input_layer)
    branch_1  = Dense(size_Y1, activation='softmax', name='output1')(common_branch)
    branch_2  = Dense(size_Y2, activation='relu',    name='output2')(common_branch)
    model = Model(inputs=input_layer,outputs=[branch_1,branch_2])

    losses = {"output1":"categorical_crossentropy", "output2":"mean_absolute_error"}
    model.compile(optimizer="adam",
                        loss=losses,
                        metrics=['accuracy'])
    return model

batch_size=5
size_Xa = 17
size_Xb = 13
size_Y2 = 100 
size_Y1 = 29

model = generate_model()

model.fit(  x=dataGenerator(),
            steps_per_epoch=50,
            epochs=15,
            validation_data=dataGenerator(), validation_steps=50, verbose=1)

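If I uncomment the 2 commented lines in makeBatch and generate_model (they override the lines just above them), the error disappears. So the code works when the input X is 1-dimensional, but the error appears as soon as I make it 2-dimensional while keeping everything else the same.

Is this related to the architecture having 2 outputs? I think I am missing something here, and any help is welcome.

I am adding the full error log for reference: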
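The two outputs are probably not the cause. The Keras documentation states that Dense applies its kernel along the last axis of its input, so a rank-3 input keeps its middle axis all the way to the output. Below is a minimal NumPy sketch of that shape arithmetic. It assumes random matrices standing in for the layer kernels (w_common and w_out1 are illustrative names, not Keras attributes) and omits biases and activations, which do not change shapes:

```python
import numpy as np

# Sizes taken from the question's code.
batch_size, size_Xa, size_Xb, size_Y1 = 5, 17, 13, 29

rng = np.random.default_rng(0)
x = rng.random((batch_size, size_Xa, size_Xb))  # same shape as x_batch

# Hypothetical stand-ins for the kernels of Dense(128) and Dense(size_Y1).
w_common = rng.random((size_Xb, 128))
w_out1 = rng.random((128, size_Y1))

# The kernel multiplies along the last axis only; the size_Xa axis is carried through.
hidden = x @ w_common   # shape (5, 17, 128)
pred = hidden @ w_out1  # shape (5, 17, 29)

target = rng.random((batch_size, size_Y1))  # same shape as y_batch['output1']

print(pred.shape)    # (5, 17, 29)
print(target.shape)  # (5, 29)
```

The prediction is rank 3 while the target is rank 2, which matches the pair of shapes in the error message: (None, 17, 29) for the output and (None, None) for the target yielded by the generator.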

Epoch 1/15
Traceback (most recent call last):
  File "neuralnet_minimal.py", line 41, in <module>
    model.fit(  x=dataGenerator(),
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 66, in _method_wrapper
    return method(self, *args, **kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py", line 848, in fit
    tmp_logs = train_function(iterator)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 580, in __call__
    result = self._call(*args, **kwds)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 627, in _call
    self._initialize(args, kwds, add_initializers_to=initializers)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 505, in _initialize
    self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2446, in _get_concrete_function_internal_garbage_collected
    graph_function, _, _ = self._maybe_define_function(args, kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2777, in _maybe_define_function
    graph_function = self._create_graph_function(args, kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/function.py", line 2657, in _create_graph_function
    func_graph_module.func_graph_from_py_func(
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 981, in func_graph_from_py_func
    func_outputs = python_func(*func_args, **func_kwargs)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/eager/def_function.py", line 441, in wrapped_fn
    return weak_wrapped_fn().__wrapped__(*args, **kwds)
  File "/path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/func_graph.py", line 968, in wrapper
    raise e.ag_error_metadata.to_exception(e)
ValueError: in user code:

    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:571 train_function  *
        outputs = self.distribute_strategy.run(
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:951 run  **
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2290 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/distribute/distribute_lib.py:2649 _call_for_each_replica
        return fn(*args, **kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/training.py:532 train_step  **
        loss = self.compiled_loss(
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/engine/compile_utils.py:205 __call__
        loss_value = loss_obj(y_t, y_p, sample_weight=sw)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:143 __call__
        losses = self.call(y_true, y_pred)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:246 call
        return self.fn(y_true, y_pred, **self._fn_kwargs)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/losses.py:1527 categorical_crossentropy
        return K.categorical_crossentropy(y_true, y_pred, from_logits=from_logits)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/keras/backend.py:4561 categorical_crossentropy
        target.shape.assert_is_compatible_with(output.shape)
    /path/of/my/project/venv/lib/python3.8/site-packages/tensorflow/python/framework/tensor_shape.py:1117 assert_is_compatible_with
        raise ValueError("Shapes %s and %s are incompatible" % (self, other))

    ValueError: Shapes (None, None) and (None, 17, 29) are incompatible

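Oddly enough, the error disappears when I add a Flatten() layer before the network splits. So it has to do with the shape of the network, but I still don't understand the real reason behind it.

I will mark this as the correct answer since it solves the problem, unless someone else posts something. If this is not the right way to do it, please let me know.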
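As a sketch of the Flatten() workaround described above: flattening the (size_Xa, size_Xb) input into a single axis before the first Dense layer makes both outputs rank 2, so they line up with the rank-2 targets. The function name build_flattened_model and the per-output metrics dictionary are assumptions made for this illustration and are not part of the original code:

```python
from keras.layers import Dense, Flatten, Input
from keras.models import Model


def build_flattened_model(size_Xa, size_Xb, size_Y1, size_Y2):
    """Two-output model whose 2-D input is flattened before the Dense layers."""
    input_layer = Input(shape=(size_Xa, size_Xb))
    # (None, size_Xa, size_Xb) -> (None, size_Xa * size_Xb)
    flat = Flatten()(input_layer)
    common_branch = Dense(128, activation='relu')(flat)
    branch_1 = Dense(size_Y1, activation='softmax', name='output1')(common_branch)
    branch_2 = Dense(size_Y2, activation='relu', name='output2')(common_branch)
    model = Model(inputs=input_layer, outputs=[branch_1, branch_2])

    losses = {"output1": "categorical_crossentropy", "output2": "mean_absolute_error"}
    # Assumption: accuracy is tracked only for the classification output.
    model.compile(optimizer="adam", loss=losses, metrics={"output1": ["accuracy"]})
    return model


model = build_flattened_model(size_Xa=17, size_Xb=13, size_Y1=29, size_Y2=100)
output_shapes = [tuple(t.shape) for t in model.outputs]
print(output_shapes)  # [(None, 29), (None, 100)]
```

Flatten discards the distinction between the Xa and Xb axes. If that structure matters, another way of reducing the input to rank 2 may be more appropriate.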

