![](/img/trans.png)
[英]How to plot the ROC curve for an ANN with 10-fold cross-validation in Keras using Python?
[英]How can we build an ROC curve for a customized ANN model in Python?
我正在嘗試用 Python 構建自定義的 ANN 模型。我構建模型的方法如下:
def binary_class(x_train, nodes, activation, n):
    """Build a customized feed-forward Keras model for binary classification.

    Every entry of ``nodes`` becomes one hidden ``Dense`` layer that uses
    ``activation`` together with a kernel initializer suited to it.  Each
    hidden layer is followed by ``Dropout(n)``.  A single-unit sigmoid output
    layer is appended last, so the model emits one probability per sample.

    Args:
        x_train: Training features, indexable by row.  The length of the
            first row is used as the input dimension of the first layer.
        nodes: Hidden-layer widths, in order (first entry = first layer).
        activation: One of 'sigmoid', 'relu', 'tanh', 'softmax', 'elu',
            'softplus'.
        n: Dropout rate applied after every hidden layer.

    Returns:
        An uncompiled ``Sequential`` model.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # Kernel initializer paired with each supported hidden-layer activation.
    initializers = {
        'sigmoid': 'glorot_uniform',
        'relu': 'he_uniform',
        'tanh': 'glorot_normal',
        'softmax': 'glorot_uniform',
        'elu': 'he_normal',
        'softplus': 'he_normal',
    }
    if activation not in initializers:
        # Previously an unknown name silently skipped every hidden Dense layer
        # and produced a model made only of Dropout + output; fail loudly.
        raise ValueError(
            f"Unsupported activation {activation!r}; "
            f"expected one of {sorted(initializers)}"
        )

    model = Sequential()
    for i, units in enumerate(nodes):
        init = initializers[activation]
        layer_kwargs = {}
        if i == 0:
            # Only the first layer declares the input size.  Row 0 is used
            # (was row 1), so a single-sample x_train no longer raises IndexError.
            layer_kwargs['input_dim'] = len(x_train[0])
            if activation == 'softmax':
                # NOTE(review): the original used 'glorot_normal' only for a
                # softmax *first* layer (hidden ones use 'glorot_uniform');
                # kept to preserve behaviour - confirm this is intended.
                init = 'glorot_normal'
        model.add(Dense(units=units, kernel_initializer=init,
                        activation=activation, **layer_kwargs))
        model.add(Dropout(n))

    # Output layer: one sigmoid unit -> probability of the positive class.
    model.add(Dense(units=1, kernel_initializer='glorot_uniform', activation='sigmoid'))
    return model
我的優化器函數如下:
def optibin(model, opt, x_train, y_train, spl, bs, epochs, x_test, y_test):
    """Interactively configure an optimizer, then compile and train ``model``.

    Hyper-parameters for the chosen optimizer are read from stdin.  The model is
    compiled with binary cross-entropy loss and an accuracy metric, then fitted
    on ``x_train``/``y_train``.  A fraction ``spl`` of the training data is held
    out for validation.

    Args:
        model: Keras model to compile and train.
        opt: Optimizer name: 'sgd', 'Adam', 'Adamax', 'Nadam', 'RMSprop'
            or 'Adagrad'.
        x_train, y_train: Training features and binary labels.
        spl: ``validation_split`` fraction passed to ``model.fit``.
        bs: Batch size.
        epochs: Number of training epochs.
        x_test, y_test: Currently unused; kept for call compatibility.

    Returns:
        Tuple ``(history, model)``: the ``History`` returned by ``fit``, and
        the trained model.

    Raises:
        ValueError: If ``opt`` is not a supported optimizer name.
    """
    # Optimizers whose only prompted hyper-parameter is the learning rate.
    # The opt string equals the keras.optimizers class name, so getattr works.
    lr_only = ('Adam', 'RMSprop', 'Adagrad')
    # Optimizers that additionally prompt for beta_1 / beta_2.
    with_betas = ('Adamax', 'Nadam')

    if opt == 'sgd':
        print("Enter Momentum:")
        mom = float(input())
        lr = float(input("Enter value of Learning rate:"))
        opti = keras.optimizers.SGD(learning_rate=lr, momentum=mom, nesterov=False)
    elif opt in lr_only:
        lr = float(input("Enter value of Learning rate:"))
        opti = getattr(keras.optimizers, opt)(learning_rate=lr)
    elif opt in with_betas:
        lr = float(input("Enter value of Learning rate:"))
        beta_1 = float(input("Enter value of beta 1 (Generally close to 1)"))
        beta_2 = float(input("Enter value of beta 2 (Generally close to 1)"))
        opti = getattr(keras.optimizers, opt)(learning_rate=lr, beta_1=beta_1, beta_2=beta_2)
    else:
        # Previously fell through with `opti` unbound and crashed in compile()
        # with UnboundLocalError; report the real problem instead.
        raise ValueError(
            f"Unsupported optimizer {opt!r}; expected one of "
            f"{['sgd', *lr_only, *with_betas]}"
        )

    model.compile(optimizer=opti, loss='binary_crossentropy', metrics=['accuracy'])
    model_history = model.fit(x_train, y_train, validation_split=spl,
                              batch_size=bs, epochs=epochs)
    return model_history, model
我需要為模型計算若干性能指標,其中包括 ROC 曲線和 AUC。我已經用 sklearn 計算了混淆矩陣、特異性和敏感性,但還需要繪製 ROC 曲線。該如何基於上述模型繪製 ROC 曲線?
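類似下面的代碼應該可以解決問題。注意:predicted_values 應是模型對正類輸出的概率分數(例如 model.predict(x_test).ravel()),而不是已二值化的 0/1 類別標籤,否則 ROC 曲線只剩一個工作點: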
# Plot the ROC curve of a binary classifier and show its AUC in the legend.
#
# Expects these to be defined beforehand (placeholders from the answer):
#   true_values      - ground-truth labels, with 1 marking the positive class
#   predicted_values - scores/probabilities for the positive class, e.g.
#                      model.predict(x_test).ravel(); thresholded 0/1 labels
#                      would reduce the curve to a single interior point.
# NOTE(review): assumes matplotlib.pyplot is already imported as `plt`.
from sklearn import metrics

fpr, tpr, thresholds = metrics.roc_curve(true_values, predicted_values, pos_label=1)
roc_auc = metrics.auc(fpr, tpr)

lw = 2
plt.figure()
# Faint diagonal: the reference line of a random-guess classifier.
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--', alpha=0.15)
# Was `{roc_auc: 0.2f}`: the ' ' sign flag put an extra space before the value.
plt.plot(fpr, tpr, lw=lw, label=f'ROC curve (area = {roc_auc:0.2f})')
plt.xlabel('(1–Specificity) - False Positive Rate')
plt.ylabel('Sensitivity - True Positive Rate')
plt.title('Receiver Operating Characteristic')
plt.legend(loc="lower right")
plt.show()
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.