I am training a CNN followed by a DNN to classify images. I want to plot the average of the weights at each layer, together with the training accuracy, at the end of every epoch. I added the following code:
avg_weight = []
print_weights = LambdaCallback(on_epoch_end=lambda batch, logs: avg_weight.append([(model.layers[0].get_weights())[0].mean(),logs.get('acc')]))
# Fit the model
history = model.fit_generator(train_gen,
epochs = epochs,
steps_per_epoch = X_train.shape[0] // batch_size,
validation_data = test_gen,
validation_steps = X_test.shape[0] // batch_size,
callbacks = [print_weights])
print(avg_weight)
As you can see above, avg_weight is a list of [mean weight, accuracy] pairs: at the end of every epoch it stores the average of the layer 0 weights along with the accuracy. I created a LambdaCallback to do this. I do get the mean of the layer 0 weights at the end of every epoch, but I am not getting the training accuracy in the callback (logs.get('acc') returns None). Is there a way to get the training accuracy value inside the LambdaCallback function?
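Note that model.layers[0].get_weights()[0] only covers the first weight array (usually the kernel) of the first layer. If you want one value per layer, as the first sentence of the question describes, a minimal sketch is a helper that loops over the layers and skips those whose get_weights() returns an empty list (for example pooling, flatten, or dropout layers). The helper name layer_weight_means and the StubLayer class are hypothetical; the sketch only assumes each layer has a name attribute and a get_weights() method returning a list of NumPy arrays, as Keras layers do, and the stubs exist only so it runs without TensorFlow.

```python
import numpy as np


class StubLayer:
    """Hypothetical stand-in for a Keras layer: exposes `name` and `get_weights()`."""

    def __init__(self, name, weights):
        self.name = name
        self._weights = weights

    def get_weights(self):
        return self._weights


def layer_weight_means(layers):
    """Return {layer name: mean of its first weight array} for layers that have weights.

    Layers without weights return an empty list from get_weights() and are
    skipped instead of raising IndexError.
    """
    means = {}
    for layer in layers:
        weights = layer.get_weights()
        if weights:  # empty list -> the layer has no weights
            means[layer.name] = float(np.mean(weights[0]))  # weights[0] is usually the kernel
    return means


layers = [
    StubLayer("conv", [np.array([[1.0, 3.0]]), np.zeros(2)]),
    StubLayer("pool", []),
    StubLayer("dense", [np.array([[2.0], [4.0]]), np.zeros(1)]),
]
print(layer_weight_means(layers))  # {'conv': 2.0, 'dense': 3.0}
```

In a real callback you would pass model.layers instead of the stub list, e.g. avg_weight.append([layer_weight_means(model.layers), logs.get('accuracy')]).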
I solved the problem by using logs.get('accuracy') instead of logs.get('acc'). Since Keras 2.3, metrics are logged under the exact name passed to compile(), so with metrics=['accuracy'] the key is 'accuracy', not 'acc'. Final code:
avg_weight = []
print_weights = LambdaCallback(on_epoch_end=lambda batch, logs: avg_weight.append([(model.layers[0].get_weights())[0].mean(),logs.get('accuracy')]))
# Fit the model
history = model.fit_generator(train_gen,
epochs = epochs,
steps_per_epoch = X_train.shape[0] // batch_size,
validation_data = test_gen,
validation_steps = X_test.shape[0] // batch_size,
callbacks = [print_weights])
print(avg_weight)
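The log key depends on the Keras version and on the metric name passed to compile(), so a small lookup that tries 'accuracy' first and falls back to 'acc' can make the callback work in both cases. This is a sketch; get_accuracy is a hypothetical helper, and logs is the plain dict Keras passes to on_epoch_end.

```python
def get_accuracy(logs):
    """Return the training accuracy from a Keras `logs` dict, or None if absent.

    Older standalone Keras (< 2.3) reported metrics=['accuracy'] as 'acc';
    newer versions use the name given to compile(), e.g. 'accuracy'.
    """
    logs = logs or {}
    for key in ("accuracy", "acc"):
        if key in logs:
            return logs[key]
    return None


print(get_accuracy({"loss": 0.3, "accuracy": 0.91}))  # 0.91
print(get_accuracy({"loss": 0.3, "acc": 0.88}))       # 0.88
print(get_accuracy({"loss": 0.3}))                    # None
```

Inside the lambda, logs.get('accuracy') would then become get_accuracy(logs).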
When I copy and run your code, I get the following error after the first epoch:
list index out of range
I'm not sure why. As a check, I created a similar callback, and it works fine:
acc=[]
lc = tf.keras.callbacks.LambdaCallback(on_epoch_end=lambda epochs, logs: acc.append([logs.get('accuracy'), logs.get('val_loss')]))
I ran it for 2 epochs, printed the acc list, and got
[[0.936170220375061, 0.8074023723602295], [0.9381402730941772, 0.6878575086593628]]
which is correct. Note that on_epoch_end is called with the arguments (epochs, logs), but your lambda names them (batch, logs). Because the arguments are passed positionally, the name is only misleading, and indeed I changed it in your code but still got the index error. You could of course create two callbacks, one to save the weights in a list and another to save the accuracy. I also ran your code without logs.get('accuracy') and still got the index error. I will try a few more things to see if I can find the problem.
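One possible cause of the list index out of range error, assuming the model was built with the functional API or starts with a layer that has no weights: in a functional model, model.layers[0] is the InputLayer, and layers such as InputLayer, pooling, Flatten, or Dropout return an empty list from get_weights(). Indexing that list with [0] raises IndexError the first time on_epoch_end fires, i.e. after the first epoch. The sketch below reproduces this with stub layers and then follows the two-callback idea: one hook records the mean of the first layer that actually has weights, the other records the metrics. The names StubLayer, record_weights, and record_metrics are hypothetical, and the logs values are made-up dummies.

```python
import numpy as np


class StubLayer:
    """Hypothetical stand-in for a Keras layer (only get_weights() is needed)."""

    def __init__(self, weights):
        self._weights = weights

    def get_weights(self):
        return self._weights


# Simulates a functional model: layers[0] is an InputLayer with no weights.
layers = [StubLayer([]), StubLayer([np.array([[0.5, 1.5]]), np.zeros(2)])]

try:
    layers[0].get_weights()[0]  # what the original callback does
except IndexError as err:
    print("IndexError:", err)  # IndexError: list index out of range

avg_weight, metrics = [], []


def record_weights(epoch, logs):
    """on_epoch_end hook: mean of the first weight array of the first layer that has weights."""
    first = next(layer for layer in layers if layer.get_weights())  # with a real model: model.layers
    avg_weight.append(float(first.get_weights()[0].mean()))


def record_metrics(epoch, logs):
    """on_epoch_end hook: training accuracy and validation loss from the logs dict."""
    logs = logs or {}
    metrics.append([logs.get("accuracy"), logs.get("val_loss")])


# Keras calls each hook as on_epoch_end(epoch, logs); simulate two epochs with dummy logs.
for epoch, logs in enumerate([{"accuracy": 0.5, "val_loss": 0.4},
                              {"accuracy": 0.6, "val_loss": 0.3}]):
    record_weights(epoch, logs)
    record_metrics(epoch, logs)

print(avg_weight)  # [1.0, 1.0]
print(metrics)     # [[0.5, 0.4], [0.6, 0.3]]
```

With a real model, the hooks would read model.layers instead of the stub list and could be registered as callbacks=[LambdaCallback(on_epoch_end=record_weights), LambdaCallback(on_epoch_end=record_metrics)].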