
Keras layer: re-initialize part of the weights each epoch

I am looking for a way to re-initialize a random part of a layer's weights each epoch (or every n epochs). I found this post explaining how to re-initialize a whole layer. I could use

weights = layer.get_weights() 

and then use NumPy operations to re-initialize part of the weights, or create a dummy layer, take its freshly initialized weights and apply them with set_weights. I am looking for a more elegant way to initialize only a specific (or random) part of a layer's weights.

Thanks

Keras layers have a set_weights method for setting their weights. To reset a layer's weights at each epoch, use a callback and pass it to training with model.fit(..., callbacks=[My_Callback()]).

import numpy as np
from tensorflow import keras

class My_Callback(keras.callbacks.Callback):

    def on_epoch_end(self, epoch, logs=None):
        layer_index = 0  # index of the layer you want to change
        layer = self.model.layers[layer_index]
        # get_weights() returns a list of arrays (e.g. [kernel, bias]),
        # so draw a new random array with the shape of each one
        new_weights = [np.random.randn(*w.shape) for w in layer.get_weights()]
        layer.set_weights(new_weights)
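Note that numpy.random.randn draws from a standard normal (standard deviation 1), which is usually much larger than what Keras uses by default: Dense layers default to glorot_uniform for the kernel and zeros for the bias. A minimal NumPy sketch that produces fresh weights closer to a Dense layer's defaults, assuming the list is [kernel, bias] with a 2-D kernel of shape (fan_in, fan_out); the helper name fresh_dense_weights is hypothetical:

```python
import numpy as np


def fresh_dense_weights(weights, rng=None):
    """Return new arrays shaped like a Dense layer's get_weights() output.

    Assumes weights == [kernel, bias] with kernel.shape == (fan_in, fan_out).
    The kernel follows the Glorot-uniform rule, the bias is all zeros.
    """
    rng = np.random.default_rng() if rng is None else rng
    kernel, bias = weights
    fan_in, fan_out = kernel.shape
    limit = np.sqrt(6.0 / (fan_in + fan_out))  # Glorot/Xavier uniform bound
    new_kernel = rng.uniform(-limit, limit, size=kernel.shape).astype(kernel.dtype)
    new_bias = np.zeros_like(bias)
    return [new_kernel, new_bias]


# Plain arrays standing in for layer.get_weights():
old = [np.ones((4, 3), dtype=np.float32), np.ones(3, dtype=np.float32)]
new = fresh_dense_weights(old, np.random.default_rng(0))
print([w.shape for w in new])  # [(4, 3), (3,)]
```

Inside the callback, the result can go straight to set_weights, e.g. layer.set_weights(fresh_dense_weights(layer.get_weights())), where layer is the layer being reset.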

Edit:

To reset n randomly chosen weights of a layer, use NumPy to pick random indices to reset. The on_epoch_end method then becomes:

    def on_epoch_end(self, epoch, logs=None):
        # pick a random layer among those that actually have weights
        layers_with_weights = [lyr for lyr in self.model.layers if lyr.get_weights()]
        layer = layers_with_weights[np.random.randint(len(layers_with_weights))]
        # get_weights() returns a list (e.g. [kernel, bias]); copy it to edit safely
        layer_weights = [w.copy() for w in layer.get_weights()]
        kernel = layer_weights[0]
        num = min(10, kernel.size)  # number of individual weights to reset
        # distinct flat indices, so any kernel shape works
        indexes = np.random.choice(kernel.size, num, replace=False)
        kernel.flat[indexes] = np.random.randn(num)  # one fresh value per index
        layer.set_weights(layer_weights)
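Picking distinct flat indices with choice(..., replace=False) and assigning through .flat resets individual weights regardless of the array's shape, and it can be checked outside Keras on a plain NumPy array. A minimal sketch; reset_n_entries is a hypothetical helper, and fresh values are drawn from a standard normal as in the answer:

```python
import numpy as np


def reset_n_entries(arr, n, rng=None):
    """Return a copy of arr with n randomly chosen entries redrawn from N(0, 1)."""
    rng = np.random.default_rng() if rng is None else rng
    out = np.array(arr, copy=True)
    n = min(n, out.size)  # cannot pick more distinct entries than exist
    idx = rng.choice(out.size, size=n, replace=False)  # distinct flat indices
    out.flat[idx] = rng.standard_normal(n)  # one new value per chosen entry
    return out


kernel = np.zeros((5, 4))
new_kernel = reset_n_entries(kernel, 10, np.random.default_rng(42))
print(np.count_nonzero(new_kernel))  # 10: exactly ten entries were redrawn
```

Starting from an all-zero array makes it easy to see that exactly n entries changed and that the input was left untouched.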

Similarly, to reset a random p % of a layer's weights, first compute how many weights p % corresponds to, then use NumPy to select that many indices:

    def on_epoch_end(self, epoch, logs=None):
        layers_with_weights = [lyr for lyr in self.model.layers if lyr.get_weights()]
        layer = layers_with_weights[np.random.randint(len(layers_with_weights))]
        layer_weights = [w.copy() for w in layer.get_weights()]
        kernel = layer_weights[0]
        percent = 10  # percentage of weights to reset
        # multiply before truncating: int(percent / 100.) would always be 0 here
        num = int(percent / 100. * kernel.size)
        indexes = np.random.choice(kernel.size, num, replace=False)
        kernel.flat[indexes] = np.random.randn(num)
        layer.set_weights(layer_weights)
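The question also asks about resetting only every n epochs. One way, sketched below under the same standard-normal assumption, is to gate the reset on the epoch number and compute the count from the array size before truncating (int(0.1) * 100 is 0). maybe_reset_fraction, every_n and fraction are hypothetical names; Keras passes a 0-based epoch to on_epoch_end, hence the epoch + 1:

```python
import numpy as np


def maybe_reset_fraction(weights, epoch, every_n=5, fraction=0.1, rng=None):
    """Reset `fraction` of the entries of every array in `weights`,
    but only after every `every_n`-th epoch (epoch is 0-based).

    Returns a new list; the input arrays are left untouched.
    """
    if (epoch + 1) % every_n != 0:
        return [w.copy() for w in weights]
    rng = np.random.default_rng() if rng is None else rng
    new_weights = []
    for w in weights:
        out = w.copy()
        num = int(fraction * out.size)  # multiply first, then truncate
        idx = rng.choice(out.size, size=num, replace=False)
        out.flat[idx] = rng.standard_normal(num)
        new_weights.append(out)
    return new_weights


weights = [np.zeros((10, 10)), np.zeros(10)]
print(np.count_nonzero(maybe_reset_fraction(weights, epoch=3)[0]))  # 0: epoch 4 is not a multiple of 5
print(np.count_nonzero(maybe_reset_fraction(weights, epoch=4)[0]))  # 10: 10% of 100 entries
```

In a callback this would be layer.set_weights(maybe_reset_fraction(layer.get_weights(), epoch)). Choosing an exact count with choice(..., replace=False) guarantees precisely p % of the entries change; a Bernoulli mask such as rng.random(shape) < fraction would be simpler but only hits p % on average.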
