Why does a Keras model predict all ones when used with one-hot labels, categorical_crossentropy and softmax output?
I have a simple tf.keras model:
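from tensorflow import keras
from tensorflow.keras import layers

# `init` is the kernel initializer passed in through the wrapper (GlorotUniform() in the constructor below)
inputs = keras.Input(shape=(9824,))
dense = layers.Dense(512, activation=keras.activations.relu, kernel_initializer=init)
x = dense(inputs)
x = layers.Dense(512, activation=keras.activations.relu)(x)
# softmax turns each output row into a probability distribution over the 3 classes
outputs = layers.Dense(3, activation=keras.activations.softmax)(x)
model = keras.Model(inputs=inputs, outputs=outputs)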
When I compile it with sparse categorical crossentropy and integer class labels, it works as expected. But when I one-hot encode the labels (with tf.keras.utils.to_categorical) and use categorical_crossentropy (so that I can use recall and precision as metrics during training), the model predicts everything as ones:
>>>print(predictions)
[[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]
...
[1. 1. 1.]
[1. 1. 1.]
[1. 1. 1.]]
If I understand correctly, the softmax activation in the output layer should make every output lie in the range (0, 1) and each row sum to 1. So how is it possible that the class predictions are all 1? I have been searching for an answer for hours, to no avail.
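The reasoning in the question holds: a row-wise softmax yields non-negative values that sum to 1, so a row of [1, 1, 1] cannot come straight out of the output layer. A minimal NumPy sketch (a hand-written softmax, not Keras's implementation; the logits are made-up values) checks this:

```python
import numpy as np

def softmax(logits):
    """Row-wise softmax, shifted by each row's max for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=1, keepdims=True)

# Hypothetical logits for 4 samples and 3 classes (not from the question).
logits = np.array([[2.0, 1.0, 0.1],
                   [0.5, 2.5, -1.0],
                   [10.0, 10.0, 10.0],
                   [-3.0, 0.0, 3.0]])
probs = softmax(logits)

print(probs.sum(axis=1))  # each row sums to 1 (up to float rounding)
# One entry can approach 1 only while the others approach 0,
# so a row of [1, 1, 1] is never a softmax output.
```

If every row reports 1 for every class, the values were produced after the model, by whatever turns probabilities back into labels.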
Here is a minimal example.
I forgot to mention that I use the SciKeras package. Based on the examples in the documentation, I assume the model is compiled implicitly. Here is the classifier constructor:
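from scikeras.wrappers import KerasClassifier
from tensorflow.keras.initializers import GlorotUniform
from tensorflow.keras.losses import SparseCategoricalCrossentropy
from tensorflow.keras.metrics import SparseCategoricalAccuracy
from tensorflow.keras.optimizers import Adam

clf = KerasClassifier(
    model=keras_model_target,  # function that builds the model shown above
    loss=SparseCategoricalCrossentropy(),
    name="model_target",
    optimizer=Adam(),
    init=GlorotUniform(),  # routed by SciKeras to the model function's `init` parameter
    metrics=[SparseCategoricalAccuracy()],
    epochs=5,
    batch_size=128
)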
I fit the model with:
result = clf.fit(x_train, y_train)
and predict with:
predictions = clf.predict(x)
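For reference, the two setups compared in the question should train the same model: sparse categorical cross-entropy takes integer labels, categorical cross-entropy takes one-hot rows, and both compute minus the log-probability of the true class. The NumPy sketch below is a simplified illustration with hand-written loss functions (Keras's own losses also clip probabilities, which is omitted here); the labels and probabilities are made-up values, and `np.eye(n)[y]` is used to produce the same one-hot layout that tf.keras.utils.to_categorical gives for 1-D integer labels.

```python
import numpy as np

def one_hot(labels, num_classes):
    """Integer labels -> one-hot rows (same layout as to_categorical for 1-D int labels)."""
    return np.eye(num_classes)[labels]

def sparse_categorical_crossentropy(labels, probs):
    # Pick each sample's predicted probability for its true class.
    return -np.log(probs[np.arange(len(labels)), labels])

def categorical_crossentropy(onehot, probs):
    # Sum over classes; only the true-class term is non-zero.
    return -(onehot * np.log(probs)).sum(axis=1)

y = np.array([0, 2, 1])                      # hypothetical integer labels
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.25, 0.5, 0.25]])        # hypothetical softmax outputs

sparse = sparse_categorical_crossentropy(y, probs)
dense = categorical_crossentropy(one_hot(y, 3), probs)
print(np.allclose(sparse, dense))  # True
```

Because the per-sample losses match, switching to one-hot labels should not by itself change what the network learns; the difference has to come from how the wrapper handles the labels.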
This was a bug in SciKeras that was fixed in the v0.3.1 release. Updating to the latest version should fix the issue.
As for the bug itself, it was caused by how we were indexing NumPy arrays; see this diff for details.
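SciKeras's `KerasClassifier.predict` returns class labels decoded back into the format of the `y` passed to `fit` (probabilities come from `predict_proba`), so with one-hot `y` the output is 0/1 rows rather than softmax values. The linked diff is not reproduced here; as a hypothetical illustration only (not the actual SciKeras code), this NumPy sketch shows how one indexing mistake while rebuilding one-hot rows from predicted class indices produces exactly the all-ones output from the question:

```python
import numpy as np

# Hypothetical softmax outputs for 4 samples; every class wins at least once.
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1],
                  [0.2, 0.3, 0.5],
                  [0.6, 0.3, 0.1]])
pred_idx = probs.argmax(axis=1)  # predicted class index per row: [0, 1, 2, 0]

# Buggy decoding: `[:, pred_idx]` selects whole columns, so every column that
# is predicted for ANY row gets set to 1 in EVERY row.
buggy = np.zeros_like(probs)
buggy[:, pred_idx] = 1

# Correct decoding: pair each row index with that row's own class index.
fixed = np.zeros_like(probs)
fixed[np.arange(len(pred_idx)), pred_idx] = 1

print(buggy)  # all ones
print(fixed)  # exactly one 1 per row
```

With the paired (row, column) index arrays, NumPy's advanced indexing touches one element per row, which is what one-hot decoding needs.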