
Cross-entropy validation loss comes out as a straight line

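I'm trying to compute a cross-entropy loss on the Iris dataset, but when I run the model and plot the results, both my loss and my validation loss stay at zero. I don't know what I'm doing wrong. Here is my code: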

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from tensorflow import keras
from keras import Sequential
from keras.layers import BatchNormalization, Dense, Dropout
from keras.callbacks import EarlyStopping

iris = sns.load_dataset('iris')
X = iris.iloc[:,:4]
y = iris.species.replace({'setosa': 0, 'versicolor': 1, 'virginica': 2})

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.3, random_state=69)

sc = StandardScaler()
sc.fit_transform(X_train)
sc.fit_transform(X_test)

nn_model = Sequential([Dense(4, activation='relu', input_shape=[X.shape[1]]),
                     BatchNormalization(),
                     Dropout(.3),
                     Dense(4, activation='relu'),
                     BatchNormalization(),
                     Dropout(.3),
                     Dense(1, activation='sigmoid')])

nn_model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['categorical_accuracy'])

early_stopping = EarlyStopping(min_delta=1e-3, patience=10, restore_best_weights=True)

fit = nn_model.fit(X_train, y_train, validation_data=(X_test,y_test), 
                   batch_size=16, epochs=200, callbacks=[early_stopping], verbose=1)

losses = pd.DataFrame(fit.history)

This is what the plots look like:

[Images: loss and validation-loss plots]

Any idea why this is happening?

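fit_transform() on a StandardScaler is not an in-place operation: it returns the scaled array and leaves its input unchanged, so your scaled data was being thrown away. You have to assign the result. Also fit the scaler on the training set only and reuse it to transform the test set: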

sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
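To make the in-place point concrete, here is a minimal NumPy sketch of what the scaler computes, assuming StandardScaler's default behaviour (per-column mean and population standard deviation). The helpers fit_standardizer and transform are hypothetical names, not scikit-learn functions, and the small array is made-up sample data rather than the real Iris split.

```python
import numpy as np

def fit_standardizer(X):
    """Learn per-column mean and population std (ddof=0), like StandardScaler.fit."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)
    std[std == 0] = 1.0  # simplified guard against constant columns
    return mean, std

def transform(X, mean, std):
    """Return a NEW standardized array; X itself is not modified."""
    return (X - mean) / std

# Hypothetical sample data (two feature columns only)
X_train = np.array([[5.1, 3.5], [4.9, 3.0], [6.2, 2.9], [5.9, 3.0]])
X_test = np.array([[5.0, 3.4], [6.7, 3.1]])

mean, std = fit_standardizer(X_train)           # statistics from the training set only
X_train_scaled = transform(X_train, mean, std)
X_test_scaled = transform(X_test, mean, std)    # reuse training statistics, no re-fit

print(X_train[0])  # still the raw values: [5.1 3.5]
```

If the return value is not assigned (as in the question's sc.fit_transform(X_train) line), the network trains on the raw features.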

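Also, you have 3 classes (check: y_train.value_counts()), so the output layer needs 3 units with a softmax activation. With a single output unit, Keras' categorical_crossentropy rescales each prediction so that it sums to 1 across the last axis, which turns every prediction into 1.0 and makes the loss -y * log(1), essentially zero no matter what the weights are. That is why both curves are flat. The output layer should be: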

 nn_model = Sequential([ ..., 
                       Dropout(.3),
                       Dense(3, activation='softmax')])
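A NumPy re-implementation of the probability path of categorical cross-entropy (rescale rows to sum to 1, clip with epsilon 1e-7, negative sum) shows why the single-unit model reports a zero loss. This is a sketch assuming the tf.keras behaviour for non-logit outputs; keras_style_categorical_crossentropy is a hypothetical helper name and the predictions are made up.

```python
import numpy as np

def keras_style_categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Sketch of categorical cross-entropy on probabilities (from_logits=False):
    rescale each row to sum to 1, clip to [eps, 1 - eps], then -sum(y * log p)."""
    y_pred = y_pred / y_pred.sum(axis=-1, keepdims=True)
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return -(y_true * np.log(y_pred)).sum(axis=-1)

# Original setup: one sigmoid unit, integer labels 0/1/2 used directly as targets
y_int = np.array([[0.0], [1.0], [2.0]])
sigmoid_out = np.array([[0.3], [0.8], [0.5]])
single_unit_loss = keras_style_categorical_crossentropy(y_int, sigmoid_out)
print(single_unit_loss)  # tiny values from clipping only, independent of the predictions

# Fixed setup: three softmax units with one-hot targets
y_onehot = np.eye(3)
softmax_out = np.array([[0.7, 0.2, 0.1],
                        [0.1, 0.6, 0.3],
                        [0.2, 0.2, 0.6]])
multi_unit_loss = keras_style_categorical_crossentropy(y_onehot, softmax_out)
print(multi_unit_loss)  # clearly non-zero, and it depends on the predictions
```

With one unit, dividing each prediction by itself removes any dependence on the network's output, so the gradients are zero as well and the model cannot learn.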

Finally, since your targets are integers (0, 1, 2) rather than one-hot vectors, the loss function should be sparse_categorical_crossentropy:

nn_model.compile(optimizer='sgd', 
                 loss='sparse_categorical_crossentropy', 
                 metrics=['accuracy'])
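Sparse categorical cross-entropy takes the integer class index directly and gives the same value as categorical cross-entropy on one-hot targets; the equivalent alternative would be to one-hot encode y (for example with keras.utils.to_categorical) and keep categorical_crossentropy. The NumPy sketch below checks that equivalence on made-up softmax outputs; both functions are hypothetical helpers, not the Keras implementations.

```python
import numpy as np

def sparse_categorical_crossentropy(y_int, probs, eps=1e-7):
    """Integer labels: take -log of the predicted probability of the true class."""
    probs = np.clip(probs, eps, 1 - eps)
    return -np.log(probs[np.arange(len(y_int)), y_int])

def categorical_crossentropy(y_onehot, probs, eps=1e-7):
    """One-hot labels: -sum over classes of y * log(p)."""
    probs = np.clip(probs, eps, 1 - eps)
    return -(y_onehot * np.log(probs)).sum(axis=-1)

y_int = np.array([0, 2, 1])              # integer targets, as produced by species.replace(...)
probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.3, 0.6],
                  [0.2, 0.5, 0.3]])      # rows of a hypothetical 3-unit softmax output

sparse = sparse_categorical_crossentropy(y_int, probs)
dense = categorical_crossentropy(np.eye(3)[y_int], probs)  # one-hot encode, then compare
print(np.allclose(sparse, dense))  # True
```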
