简体   繁体   English

多任务 CNN model 提前停止,带有验证损失

[英]Multi tasking CNN model early stop with validation loss

I am trying to fit a multitasking model with validation data and tracing validation loss for early stopping.我正在尝试为多任务 model 安装验证数据并跟踪验证损失以提前停止。 Is there any way to trace and early stop with validation loss?有什么方法可以追踪验证损失并提前停止? My demo code is following it shows a warning that validation loss is not available.我的演示代码紧随其后,它显示了验证丢失不可用的警告。


    def main_model(height, width): 
            input_img = Input(shape = (height, width, 1))
            conv01_1 = Conv2D(64, (3, 3), activation='relu', padding='same')(input_img)
            pool01_1 = AveragePooling2D(pool_size=(2, 2),strides=None, padding="same")( conv01_1)
            batch_nor01_1= BatchNormalization()(pool01_1)
            drout01_1= Dropout(0.1)(batch_nor01_1)
            flatten_layer = Flatten()(drout01_1)
            x1_dense = Dense(512,activation='relu')( flatten_layer )
            out_1=Dense(6,activation = "softmax",name='activity_output')( x1_dense)
            out_2=Dense(1,activation='linear',name='energy_output')( x1_dense)
            model = Model(inputs=input_img,outputs = [out_1,out_2])
            model.compile(optimizer=SGD(lr=0.001,momentum=0.9),loss={'activity_output':'categorical_crossentropy', 'energy_output': 'mse'},loss_weights={'activity_output': 0.5, 'energy_output': 0.5},metrics=['accuracy','mae'])
            model.summary()
            return model
    
    model_name=s+'_best_model.h5'
    mc = ModelCheckpoint(model_name, monitor='validation_loss', mode='auto', verbose=1, save_best_only=True)
    es = EarlyStopping(monitor='validation_loss',min_delta=0,patience=20,verbose=0, mode='auto')
    ```
    
    batch_size=500
    epochs=1
    model=main_model(height, width)
    History = model.fit(x_train,[y_train,y_train_1],epochs = epochs, validation_data = (x_valid,y_valid,y_valid_1),verbose = 1,callbacks=[callback_test,es,lrs,mc,])

'''


        

I have got the solution.我有解决办法。 Basically, I have replaced the validation_loss with val_loss so the code is now:基本上,我已经用val_loss替换了validation_loss ,所以现在的代码是:

mc = ModelCheckpoint(model_name, monitor='val_loss', mode='min', verbose=1, 
                     save_best_only=True)

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM