[英]How to get the weights of layer in a keras model for each input
I know that you can you Model.layer[layer_number].getWeights() to get weight of layer from a keras model at a certain point.我知道您可以通过 Model.layer[layer_number].getWeights() 在某个点从 keras model 获取层的权重。 I am only to get those weights for an epoch or a batch using callbacks during training.
我只是在训练期间使用回调来获得一个时期或一批的权重。
But I want to get the weights of the layer for each input in the training part.但我想获得训练部分中每个输入的层权重。 Or if possible the activation of a layer for each input instead of an epoch.
或者如果可能的话,为每个输入激活一个层而不是一个纪元。
Is there a way to achieve that?有没有办法做到这一点?
This is a small example.这是一个小例子。 You can use
custom callbacks
inside which you can access model's weights by layers (including activations ( layers.Activation
)).您可以使用
custom callbacks
,您可以在其中按层访问模型的权重(包括激活( layers.Activation
))。 Just change based on your needs.只需根据您的需要进行更改。
This will print the weights after each epoch, you can plot them/ save them too or run any operations on them if you want.这将在每个 epoch 后打印权重,您可以 plot 它们/也保存它们或根据需要对它们运行任何操作。
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model, Sequential
from tensorflow.keras.optimizers import Adam
import tensorflow as tf
import numpy as np
from keras.callbacks import LambdaCallback
model=Sequential()
model.add(Dense(32,activation='linear',input_shape=(37,10)))
model.add(Dense(32,activation='linear'))
model.add(Dense(10,activation='linear'))
model.compile(loss='mse',optimizer=Adam(lr=.001),metrics=['accuracy'])
model.summary()
class MyCustomCallback(tf.keras.callbacks.Callback):
def on_train_batch_begin(self, batch, logs=None):
print(model.layers[0].get_weights())
def on_train_batch_end(self, batch, logs=None):
print(model.layers[0].get_weights())
def on_test_batch_begin(self, batch, logs=None):
pass
def on_test_batch_end(self, batch, logs=None):
pass
X_train = np.zeros((10,37,10))
y_train = np.zeros((10,37,10))
weight_print = MyCustomCallback()
model.fit(X_train,
y_train,
batch_size=32,
epochs=5,
callbacks = [weight_print])
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.