
How to get the weights of a layer in a Keras model for each input

I know that you can use model.layers[layer_number].get_weights() to get the weights of a layer from a Keras model at a certain point. Using callbacks during training, I am only able to get those weights per epoch or per batch.
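For reference, here is a minimal sketch of the per-layer weight access described above, assuming a tiny hypothetical one-layer model. For a Dense layer, get_weights() returns a list of NumPy arrays, with the kernel first and the bias second.

```python
import numpy as np
import tensorflow as tf

# Hypothetical toy model, used only to show what get_weights() returns.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3),
])

# For a Dense layer, get_weights() returns [kernel, bias] as NumPy arrays.
kernel, bias = model.layers[0].get_weights()
print(kernel.shape, bias.shape)  # (4, 3) (3,)
```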

But I want to get the weights of the layer for each input sample during training, or, if possible, the activations of a layer for each input rather than once per epoch.

Is there a way to achieve that?
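For the activation part of the question: unlike weights, activations are computed per input, so one common Keras approach (not taken from the answer below) is to build a second Model that reuses the trained layers but outputs an intermediate layer. This is a sketch under assumptions: the model, its sizes and the layer name "hidden" are all hypothetical.

```python
import numpy as np
import tensorflow as tf

# Hypothetical model: the layer name "hidden" and all sizes are illustrative.
inputs = tf.keras.Input(shape=(10,))
hidden = tf.keras.layers.Dense(32, activation="relu", name="hidden")(inputs)
outputs = tf.keras.layers.Dense(10, name="out")(hidden)
model = tf.keras.Model(inputs, outputs)

# A second model that reuses the same layer objects (so it always sees the
# current weights) but returns the hidden layer's output instead.
activation_model = tf.keras.Model(inputs=model.inputs,
                                  outputs=model.get_layer("hidden").output)

x = np.random.rand(5, 10).astype("float32")    # 5 input samples
acts = activation_model.predict(x, verbose=0)  # one row of activations per sample
print(acts.shape)  # (5, 32)
```

Because the layers are shared, activation_model can also be called from inside a training callback to inspect activations as training progresses.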

Here is a small example. You can use a custom callback, inside which you can access the model's weights layer by layer (this includes activation layers such as layers.Activation, although those have no weights of their own). Adapt it to your needs.

The callback below prints the first layer's weights at the beginning and end of every training batch. You can also plot them, save them, or run any other operations on them.

from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np

model = Sequential()
model.add(Dense(32, activation='linear', input_shape=(37, 10)))
model.add(Dense(32, activation='linear'))
model.add(Dense(10, activation='linear'))
# 'learning_rate' replaces the deprecated 'lr' argument;
# 'accuracy' is dropped because it is not meaningful for an MSE regression
model.compile(loss='mse', optimizer=Adam(learning_rate=0.001))
model.summary()

class MyCustomCallback(tf.keras.callbacks.Callback):
  # Keras sets self.model when the callback is passed to fit()

  def on_train_batch_begin(self, batch, logs=None):
    # weights of the first layer before this batch's update
    print(self.model.layers[0].get_weights())

  def on_train_batch_end(self, batch, logs=None):
    # weights of the first layer after this batch's update
    print(self.model.layers[0].get_weights())


X_train = np.zeros((10, 37, 10))
y_train = np.zeros((10, 37, 10))

weight_print = MyCustomCallback()
model.fit(X_train,
          y_train,
          batch_size=32,  # batch_size=1 gives one weight update (and printout) per input sample
          epochs=5,
          callbacks=[weight_print])
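Keras updates the weights once per batch, so one way to get weights "for each input" is to train with batch_size=1 and record a snapshot after every batch instead of printing it. The sketch below assumes that interpretation. The class WeightRecorder, its layer_index parameter and the toy data are hypothetical, and batch_size=1 makes training considerably slower.

```python
import numpy as np
import tensorflow as tf

class WeightRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback: stores one layer's weights after every training batch."""

    def __init__(self, layer_index=0):
        super().__init__()
        self.layer_index = layer_index
        self.history = []  # one [kernel, bias] list per batch

    def on_train_batch_end(self, batch, logs=None):
        # Copy the arrays so later updates cannot alter stored snapshots.
        weights = self.model.layers[self.layer_index].get_weights()
        self.history.append([np.copy(w) for w in weights])

# Hypothetical toy regression data: 6 samples with 4 features each.
x = np.random.rand(6, 4).astype("float32")
y = np.random.rand(6, 1).astype("float32")

model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])
model.compile(loss="mse", optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3))

recorder = WeightRecorder(layer_index=0)
# batch_size=1: every batch is a single sample, so each snapshot is the
# weights right after training on one input (6 samples x 2 epochs = 12).
model.fit(x, y, batch_size=1, epochs=2, shuffle=False, verbose=0,
          callbacks=[recorder])

kernels = np.stack([w[0] for w in recorder.history])  # shape (12, 4, 1)
```

The stacked kernels array can then be plotted per sample or written to disk with np.save for later analysis.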
