简体   繁体   English

keras model 总是预测类别而不是概率

[英]keras model always predicting classes instead of probabilities

I am trying to do a binary classification using MLP by outputting a 2 neuron Dense layer with softmax as its activation function.我正在尝试使用 MLP 进行二进制分类,方法是输出一个带有 softmax 的 2 个神经元密集层作为其激活 function。 But i keep getting predicted classes instead.但我不断得到预测的课程。

So what am i doing wrong here?那么我在这里做错了什么?

Please note that i this is a shorter version of my main NN.请注意,这是我的主要 NN 的较短版本。

inp = tf.keras.layers.Input(shape = (len(feat_cols),))
x = tf.keras.layers.Dense(32, activation = "relu")(inp)
out = tf.keras.layers.Dense(2, activation= "softmax")(x)
model = tf.keras.models.Model(inputs = inp, outputs = out)
model.compile(loss='binary_crossentropy',
          optimizer=tf.keras.optimizers.RMSprop(),
          metrics=['accuracy',"roc_auc"])
model.fit(temp[feat_cols].values.astype('float64'), np_utils.to_categorical(temp[target].values.astype('float64')), epochs = 1)
model.predict(test_df[feat_cols].values.astype("float64"))

My output for the above model is:上述 model 的我的 output 是:

750/750 [==============================] - 2s 2ms/step - loss: 948.7341 - accuracy: 0.5005
array([[0., 1.],
       [0., 1.],
       [1., 0.],
       ...,
       [1., 0.],
       [1., 0.],
       [1., 0.]], dtype=float32)

Edit:编辑:

The target is 0 or 1 labels and they are provided to the model as one hot encoded labels.目标是 0 或 1 个标签,它们作为一个热编码标签提供给 model。 And i am using my sklearn roc_auc metric passed as a py_function.我正在使用我的 sklearn roc_auc 指标作为 py_function 传递。 I have also tried by removing the.values.astype("float64") part...still so difference.我还尝试通过删除 the.values.astype("float64") 部分......仍然如此不同。

You use binary_crossentropy loss and 2 neuron output.您使用 binary_crossentropy 损失和 2 个神经元 output。 usually, binary classification means you have 1 output and tries to classify between 0 and 1. try categorical_crossentropy as your loss, you'll get the probability of prediction.通常,二进制分类意味着您有 1 个 output 并尝试在 0 和 1 之间进行分类。尝试 categorical_crossentropy 作为您的损失,您将获得预测的概率。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM