ValueError: Unknown loss function: custom_loss_function. Please ensure this object is passed to the `custom_objects` argument
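I'm trying to train my model with this custom loss function: [formula image not available]
where S(p_n; ω) is the prediction (y_pred) and MOS_n is the target (y_true), so I wrote it like this: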
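import keras.backend as K

def custom_loss_function(y_true, y_pred):
    # |S(p_n; ω) - MOS_n| for every sample; y_true and y_pred already hold
    # the whole batch, so no explicit loop over n is needed.
    l = K.abs(y_pred - y_true)
    # Average over the last axis; Keras then averages over the batch.
    l = K.mean(l, axis=-1)
    return l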
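The formula image is missing, so the following assumes it is the usual mean absolute error over N samples, L = (1/N) Σ_n |S(p_n; ω) - MOS_n|, which matches the variable names and the K.abs / K.mean code. Under that assumption, a minimal NumPy sketch shows that the explicit sum over n and a vectorized mean give the same value, which is why the loss function does not need its own loop. The function names and the example values are hypothetical.

```python
import numpy as np

def l1_loss_loop(y_true, y_pred):
    """Explicit sum over n: (1/N) * sum_n |S(p_n; w) - MOS_n|."""
    n_samples = len(y_true)
    total = 0.0
    for n in range(n_samples):
        total += abs(y_pred[n] - y_true[n])
    return total / n_samples

def l1_loss_vectorized(y_true, y_pred):
    """Same quantity on whole arrays, like K.abs and K.mean on tensors."""
    return float(np.mean(np.abs(np.asarray(y_pred) - np.asarray(y_true))))

# Hypothetical MOS targets and model predictions.
mos = [3.2, 4.1, 2.5, 4.8]
pred = [3.0, 4.5, 2.7, 4.0]

print(l1_loss_loop(mos, pred))
print(l1_loss_vectorized(mos, pred))
```

Both calls print (up to floating-point rounding) the same mean absolute error, 0.4 for these values.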
Then I built my model:
# Model definition
from keras import models
from keras import layers

def build_model():
    model = models.Sequential()
    model.add(layers.Conv2D(32, (5, 5), activation='relu', input_shape=(32, 32, 1)))
    model.add(layers.MaxPooling2D((2, 2)))
    model.add(layers.Conv2D(32, (5, 5), activation='relu'))
    model.add(layers.MaxPooling2D((10, 10)))
    # Flatten the (1, 1, 32) feature map so the final output has shape (batch, 1)
    model.add(layers.Flatten())
    model.add(layers.Dense(250, activation='relu'))
    model.add(layers.Dense(250, activation='relu'))
    model.add(layers.Dense(1))
    # The loss is given as a string (its name), which is what triggers the error below
    model.compile(optimizer='rmsprop', loss='custom_loss_function', metrics=['mae'])
    return model

model = build_model()
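The feature-map sizes in build_model can be worked out by hand to check what reaches the Dense layers. This plain-Python sketch uses hypothetical helper names (conv2d_valid, max_pool, dense, flatten) and assumes the Keras defaults: padding='valid' with stride 1 for Conv2D, and strides equal to the pool size for MaxPooling2D. It shows that MaxPooling2D((10, 10)) leaves a 1 x 1 x 32 map, and that Dense layers applied to that 4-D tensor give (None, 1, 1, 1) instead of (None, 1), which is why a Flatten layer belongs before the first Dense layer.

```python
def conv2d_valid(shape, filters, kernel):
    # Conv2D, padding='valid', stride 1: each spatial side shrinks by kernel - 1.
    batch, h, w, _ = shape
    return (batch, h - kernel + 1, w - kernel + 1, filters)

def max_pool(shape, pool):
    # MaxPooling2D, default strides (= pool size), 'valid' padding.
    batch, h, w, c = shape
    return (batch, h // pool, w // pool, c)

def dense(shape, units):
    # Dense acts on the last axis only; all other axes are kept.
    return shape[:-1] + (units,)

def flatten(shape):
    # Flatten keeps the batch axis and multiplies the remaining axes.
    size = 1
    for dim in shape[1:]:
        size *= dim
    return (shape[0], size)

shape = (None, 32, 32, 1)           # input_shape=(32, 32, 1)
shape = conv2d_valid(shape, 32, 5)  # (None, 28, 28, 32)
shape = max_pool(shape, 2)          # (None, 14, 14, 32)
shape = conv2d_valid(shape, 32, 5)  # (None, 10, 10, 32)
shape = max_pool(shape, 10)         # (None, 1, 1, 32)
before_dense = shape

# Dense stack without Flatten: the output keeps rank 4.
no_flatten = dense(dense(dense(before_dense, 250), 250), 1)
# Dense stack with Flatten: one value per sample.
with_flatten = dense(dense(dense(flatten(before_dense), 250), 250), 1)

print(no_flatten)    # (None, 1, 1, 1)
print(with_flatten)  # (None, 1)
```

With a (None, 1, 1, 1) prediction and a (None, 1) target, y_pred - y_true inside the loss would broadcast to an unintended shape instead of pairing each prediction with its own label.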
But when I run the training:
num_epochs = 20
history = model.fit(train_data, train_labels, epochs=num_epochs, batch_size=None, verbose=0)
I get this error:
ValueError Traceback (most recent call last)
<ipython-input-16-eb44c7fa4ec7> in <module>()
1 num_epochs = 20
----> 2 history = model.fit(train_data, train_labels, epochs=num_epochs, batch_size=None, verbose=0)
9 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/func_graph.py in wrapper(*args, **kwargs)
992 except Exception as e: # pylint:disable=broad-except
993 if hasattr(e, "ag_error_metadata"):
--> 994 raise e.ag_error_metadata.to_exception(e)
995 else:
996 raise
ValueError: in user code:
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:853 train_function *
return step_function(self, iterator)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:842 step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:1286 run
return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:2849 call_for_each_replica
return self._call_for_each_replica(fn, args, kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/distribute/distribute_lib.py:3632 _call_for_each_replica
return fn(*args, **kwargs)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:835 run_step **
outputs = model.train_step(data)
/usr/local/lib/python3.7/dist-packages/keras/engine/training.py:789 train_step
y, y_pred, sample_weight, regularization_losses=self.losses)
/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py:184 __call__
self.build(y_pred)
/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py:133 build
self._losses = tf.nest.map_structure(self._get_loss_object, self._losses)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py:869 map_structure
structure[0], [func(*x) for x in entries],
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/nest.py:869 <listcomp>
structure[0], [func(*x) for x in entries],
/usr/local/lib/python3.7/dist-packages/keras/engine/compile_utils.py:273 _get_loss_object
loss = losses_mod.get(loss)
/usr/local/lib/python3.7/dist-packages/keras/losses.py:2136 get
return deserialize(identifier)
/usr/local/lib/python3.7/dist-packages/keras/losses.py:2095 deserialize
printable_module_name='loss function')
/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py:709 deserialize_keras_object
.format(printable_module_name, object_name))
ValueError: Unknown loss function: custom_loss_function. Please ensure this object is passed to the `custom_objects` argument. See https://www.tensorflow.org/guide/keras/save_and_serialize#registering_the_custom_object for details.
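I read the documentation on the `custom_objects` argument and tried to apply it, but I still can't figure it out. How exactly do I pass my custom function to the `custom_objects` argument?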
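In Python, a function is itself an object (a callable), so you can pass it directly. In this case, pass the function instead of its name, i.e. write loss=custom_loss_function without the single quotes:
model.compile(optimizer='rmsprop', loss=custom_loss_function, metrics=['mae'])
When loss is a string, Keras treats it as the name of a known loss and looks it up; custom_loss_function is not registered under that name, which is what raises the "Unknown loss function" error. Here is more information about callable objects and custom loss functions in TensorFlow.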
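The difference between passing a name and passing a function can be illustrated with a simplified, hypothetical resolver. get_loss and BUILTIN_LOSSES are made-up names, not Keras's API or its real implementation; the sketch only assumes this behavior: a callable is used as-is, while a string is only a name that must be found in a lookup table (built-in losses plus any custom objects supplied).

```python
def mean_absolute_error(y_true, y_pred):
    return sum(abs(p - t) for t, p in zip(y_true, y_pred)) / len(y_true)

# Hypothetical table of losses known by name.
BUILTIN_LOSSES = {"mean_absolute_error": mean_absolute_error,
                  "mae": mean_absolute_error}

def get_loss(identifier, custom_objects=None):
    """Hypothetical resolver, not Keras code."""
    if callable(identifier):
        # A function object is used directly; no name lookup happens.
        return identifier
    if isinstance(identifier, str):
        # A string is only a name, so it must appear in the lookup table.
        table = {**BUILTIN_LOSSES, **(custom_objects or {})}
        if identifier in table:
            return table[identifier]
        raise ValueError(f"Unknown loss function: {identifier}. Please ensure "
                         "this object is passed to the `custom_objects` argument.")
    raise TypeError(f"Could not interpret loss identifier: {identifier!r}")

def custom_loss_function(y_true, y_pred):
    return mean_absolute_error(y_true, y_pred)

try:
    get_loss("custom_loss_function")  # name only: not in the table
except ValueError as err:
    print(err)

print(get_loss(custom_loss_function) is custom_loss_function)  # True
print(get_loss("custom_loss_function",
               custom_objects={"custom_loss_function": custom_loss_function})
      is custom_loss_function)  # True
```

In real Keras, passing loss=custom_loss_function to compile() avoids the name lookup entirely. The custom_objects argument mainly matters when a saved model that uses the custom loss is loaded again, for example keras.models.load_model(path, custom_objects={'custom_loss_function': custom_loss_function}), where path is a placeholder for the saved model's location.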
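Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please cite this site's URL or the original address. For any questions, contact yoyou2525@163.com.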