In Keras, how to use custom callbacks during prediction
I have a Keras model like this:
And for this model, I have a callback function like this:
import tensorflow as tf
import numpy as np

class WriteLayerValCallback(tf.keras.callbacks.Callback):
    def __init__(self):
        self.data = np.random.rand(1,10)

    def on_epoch_end(self, epoch, logs=None):
        #dns_layer = self.model.layers[6]
        dns_layer = self.model.get_layer('activation')
        outputs = dns_layer(self.data)
        tf.print(f'\n input: {self.data}')
        tf.print(f'\n output: {outputs}')
And I am running prediction with my model like this:
yhat = model.predict(X)
I would like to call this callback function during Keras prediction.
Can anyone please help me with how I can do that?
If I understand correctly (IIUC), you need to implement on_predict_batch_end to have the callback called during Keras prediction.
(A callback can implement different hook methods: on_predict_begin, on_predict_end, on_predict_batch_end, on_test_end, ...)
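To see when the prediction hooks listed above fire, here is a minimal sketch. The names `HookOrderCallback`, `tiny_model`, and `hook_cb` are made up for illustration and are not part of the question; the model is a throwaway single Dense layer chosen only so the example runs quickly.

```python
import numpy as np
import tensorflow as tf


class HookOrderCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: records each prediction hook as Keras calls it."""

    def __init__(self):
        super().__init__()
        self.calls = []

    def on_predict_begin(self, logs=None):
        self.calls.append("on_predict_begin")

    def on_predict_batch_begin(self, batch, logs=None):
        self.calls.append(f"on_predict_batch_begin:{batch}")

    def on_predict_batch_end(self, batch, logs=None):
        self.calls.append(f"on_predict_batch_end:{batch}")

    def on_predict_end(self, logs=None):
        self.calls.append("on_predict_end")


# Throwaway model (an assumption for this sketch, not the asker's model).
tiny_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(2),
])

hook_cb = HookOrderCallback()
# Three samples with batch_size=1, so the batch hooks should run three times.
tiny_model.predict(
    np.zeros((3, 4), dtype="float32"),
    batch_size=1,
    verbose=0,
    callbacks=[hook_cb],
)
print(hook_cb.calls)
```

The on_test_* hooks are not triggered here, since they belong to model.evaluate rather than model.predict.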
# Full example: call one layer on fixed data after every prediction batch
import tensorflow as tf
import numpy as np

class CustomCallback(tf.keras.callbacks.Callback):
    def __init__(self, data):
        super().__init__()
        self.data = data

    def on_predict_batch_end(self, batch, logs=None):
        # layers[6] is the Dense(5) layer, which expects 10 input features
        dns_layer = self.model.layers[6]
        outputs = dns_layer(self.data)
        tf.print(f'\n batch: {batch}')
        tf.print(f'\n input: {self.data}')
        tf.print(f'\n output: {outputs}')

# Random data, used only to build and fit the model
x_train = tf.random.normal((10, 32, 32))
y_train = tf.random.uniform((10, 1), maxval=10)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LSTM(256, input_shape=(x_train.shape[1], x_train.shape[2]), return_sequences=True))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.LSTM(256))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(5, activation='softmax'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.compile(optimizer='adam', loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False))
model.summary()

for layer in model.layers:
    print(layer)

model.fit(x_train, y_train, epochs=1, batch_size=32)

# Callbacks are passed to predict() the same way as to fit()
yhat = model.predict(tf.random.normal((5, 32, 32)), batch_size=1, callbacks=[CustomCallback(np.random.rand(1,10))])
Output:
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #
=================================================================
 lstm (LSTM)                 (None, 32, 256)           295936
 dropout (Dropout)           (None, 32, 256)           0
 lstm_1 (LSTM)               (None, 256)               525312
 dropout_1 (Dropout)         (None, 256)               0
 dense (Dense)               (None, 10)                2570
 dropout_2 (Dropout)         (None, 10)                0
 dense_1 (Dense)             (None, 5)                 55
 dropout_3 (Dropout)         (None, 5)                 0
 dense_2 (Dense)             (None, 10)                60
=================================================================
Total params: 823,933
Trainable params: 823,933
Non-trainable params: 0
_________________________________________________________________
<keras.layers.recurrent_v2.LSTM object at 0x7fd9bf3a1690>
<keras.layers.core.dropout.Dropout object at 0x7fd9b74f60d0>
<keras.layers.recurrent_v2.LSTM object at 0x7fd9b32cb090>
<keras.layers.core.dropout.Dropout object at 0x7fd9b325c110>
<keras.layers.core.dense.Dense object at 0x7fd9b31f4690>
<keras.layers.core.dropout.Dropout object at 0x7fd9b315a890>
<keras.layers.core.dense.Dense object at 0x7fd9b32a1050>
<keras.layers.core.dropout.Dropout object at 0x7fd9b3239f50>
<keras.layers.core.dense.Dense object at 0x7fd9b317e1d0>
1/1 [==============================] - 12s 12s/step - loss: 2.3411
batch: 0
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 1
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 2
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 3
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 4
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
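Note that the input and output printed above are identical for every batch, because the callback always feeds the same fixed array to the layer. If the goal is instead to look at what the model produced for each batch, the sketch below reads the batch predictions from the `logs` argument. It assumes the installed Keras version passes them under the `"outputs"` key, which is worth checking on your version; `CollectOutputsCallback`, `demo_model`, and the other names are hypothetical and the model is a throwaway.

```python
import numpy as np
import tensorflow as tf


class CollectOutputsCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: stores the predictions of each batch."""

    def __init__(self):
        super().__init__()
        self.batch_outputs = []

    def on_predict_batch_end(self, batch, logs=None):
        # Assumption: Keras passes the batch predictions as logs["outputs"].
        outputs = (logs or {}).get("outputs")
        if outputs is not None:
            self.batch_outputs.append(np.asarray(outputs))


# Throwaway model (an assumption for this sketch, not the model above).
demo_model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(3, activation="softmax"),
])

collect_cb = CollectOutputsCallback()
x_demo = np.random.rand(5, 4).astype("float32")
yhat_demo = demo_model.predict(x_demo, batch_size=1, verbose=0, callbacks=[collect_cb])

if collect_cb.batch_outputs:
    # Stacking the per-batch outputs should reproduce what predict() returned.
    stacked = np.concatenate(collect_cb.batch_outputs, axis=0)
    print(stacked.shape, np.allclose(stacked, yhat_demo, atol=1e-6))
else:
    print("This Keras version did not pass 'outputs' in logs.")
```

This avoids calling a layer by hand inside the callback, at the cost of only seeing the model's final outputs rather than an intermediate layer.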