
Update sample weights in a Keras callback?

I am trying to write a custom callback that updates the sample weights at the end of each epoch. I initialize the callback with the original weights, but I am not sure how to make Keras actually use the new sample weights defined in the callback when fitting the model. Here is a simple example of my code (the `* 2` is just a placeholder and wouldn't do anything useful in practice).

import tensorflow as tf

class update_weights(tf.keras.callbacks.Callback):
    def __init__(self, sample_weight):
        super().__init__()
        self.sample_weight = sample_weight

    def on_epoch_end(self, epoch, logs=None):
        # Placeholder update: double the weights after every epoch.
        self.sample_weight = self.sample_weight * 2

It seems possible to access the model itself from a callback through self.model, but I am having a hard time accessing the arguments of the fit call. Is that even possible? Would I be able to modify the arguments of fit while the model is training?
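For context, here is a minimal sketch of how the callback above would be wired into fit. The model, data, and the initial_weights array are made up purely for illustration. Note that fit receives sample_weight once, when the call starts, and the callback only rebinds its own attribute, which is exactly the uncertainty the question is about.

import numpy as np
import tensorflow as tf

# Hypothetical data and model, only to show how the callback is attached.
x_train = np.random.rand(100, 5).astype("float32")
y_train = np.random.rand(100, 1).astype("float32")
initial_weights = np.ones(len(x_train), dtype="float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(5,)),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")

cb = update_weights(initial_weights)

# sample_weight is captured by fit() at the start of the call; the callback
# later rebinds self.sample_weight, and it is unclear whether training sees it.
model.fit(
    x_train,
    y_train,
    sample_weight=initial_weights,
    epochs=5,
    callbacks=[cb],
)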

Now that I understand what you want to accomplish, it is an interesting question. To find out whether the sample weights used by model.fit are actually changing (I doubt that they are), you could pick an epoch in on_epoch_end and set the sample weights to all zeros at that epoch. If this really changes the values model.fit uses, the loss on the next epoch should be dramatically different. Try something like

    def on_epoch_end(self, epoch, logs=None):
        if epoch == 3:
            # Zero out the weights; if fit() picks this up, later losses should collapse.
            self.sample_weight = self.sample_weight * 0

On epoch 4, if the sample weights are truly changing, I would expect the loss to be zero. Let me know what happens.
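That test can be scripted end to end. The sketch below is one guess at what is being suggested, reusing the hypothetical x_train, y_train, initial_weights, and model from the earlier snippet; if the losses logged after epoch 3 stay in the same range, the weights held by the callback are evidently not the ones fit is using.

class zero_weights_test(tf.keras.callbacks.Callback):
    """Diagnostic callback: zero out the sample weights after epoch 3."""

    def __init__(self, sample_weight):
        super().__init__()
        self.sample_weight = sample_weight

    def on_epoch_end(self, epoch, logs=None):
        if epoch == 3:
            # If fit() really re-reads this attribute, every later loss
            # should collapse to zero because all samples get weight 0.
            self.sample_weight = self.sample_weight * 0


cb = zero_weights_test(initial_weights)
history = model.fit(
    x_train,
    y_train,
    sample_weight=initial_weights,
    epochs=8,
    callbacks=[cb],
    verbose=2,
)
print(history.history["loss"])  # compare losses before and after epoch 3

If the losses do not change, a common workaround is to call fit one epoch at a time in a Python loop and pass the freshly updated sample weights to each call, so the new weights are guaranteed to reach the training step.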

I am trying to do the same thing. Could you please share what worked for you? Thank you!
