I have a Keras model like this:
And for this model, I have a callback function like this:
import tensorflow as tf
import numpy as np
class WriteLayerValCallback(tf.keras.callbacks.Callback):
    """Print the output of one named layer at the end of every training epoch.

    Note: ``on_epoch_end`` is triggered by ``model.fit`` only; it is not
    called by ``model.predict``. To probe a layer during prediction, override
    ``on_predict_batch_end`` instead and pass the callback through
    ``model.predict(..., callbacks=[...])``.

    Args:
        data: Array fed directly to the probed layer. When omitted, a single
            ``np.random.rand(1, 10)`` sample is drawn, as in the original
            code. NOTE(review): assumes the probed layer accepts a
            (1, 10) input — confirm against the model definition.
        layer_name: Name of the layer to probe, looked up with
            ``self.model.get_layer``. Defaults to ``'activation'``.
    """

    def __init__(self, data=None, layer_name='activation'):
        # Initialise the Keras base class so the attributes it sets up
        # (such as ``model``) exist before Keras attaches the callback.
        super().__init__()
        self.data = np.random.rand(1, 10) if data is None else data
        self.layer_name = layer_name

    def on_epoch_end(self, epoch, logs=None):
        """Run ``self.data`` through the named layer and print input/output."""
        # Lookup by name; ``self.model.layers[i]`` would select by position.
        dns_layer = self.model.get_layer(self.layer_name)
        outputs = dns_layer(self.data)
        tf.print(f'\n input: {self.data}')
        tf.print(f'\n output: {outputs}')
And I run predictions with my model like this:
# Plain inference call: no callbacks are passed here, so
# WriteLayerValCallback is not invoked.
yhat = model.predict(X)
I would like this callback to be called during Keras prediction (`model.predict`).
Can anyone help me with how to do that?
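If I understand correctly (IIUC), you need to implement `on_predict_batch_end` so that the callback is invoked during Keras prediction. `on_epoch_end` is only called while training with `fit`, so it never runs inside `predict`. You must also pass the callback explicitly: `model.predict(..., callbacks=[...])`.
(See the Keras callbacks documentation: a callback can override different hooks, such as `on_predict_begin`, `on_predict_end`, `on_predict_batch_end`, `on_test_end`, ...)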
import tensorflow as tf
import numpy as np
class CustomCallback(tf.keras.callbacks.Callback):
    """Print the output of one model layer after every prediction batch.

    Keras calls ``on_predict_batch_end`` once per batch while
    ``model.predict(..., callbacks=[CustomCallback(data)])`` runs, so the
    probe executes during inference.

    Args:
        data: Array fed directly to the selected layer. Its trailing size must
            match that layer's input. In the example model below, layer 6 is
            ``dense_1``, which receives the 10-wide output of ``dropout_2``;
            hence the example uses ``np.random.rand(1, 10)``.
        layer_index: Position in ``self.model.layers`` of the layer to probe.
            Defaults to 6, the index the original code hard-coded.
    """

    def __init__(self, data, layer_index=6):
        # Initialise the Keras base class so the attributes it sets up
        # (such as ``model``) exist before Keras attaches the callback.
        super().__init__()
        self.data = data
        self.layer_index = layer_index

    def on_predict_batch_end(self, batch, logs=None):
        """Run ``self.data`` through the probed layer and print the result."""
        dns_layer = self.model.layers[self.layer_index]
        # The layer is called standalone on the same fixed probe input each
        # time, so the printed output repeats for every batch.
        outputs = dns_layer(self.data)
        tf.print(f'\n batch: {batch}')
        tf.print(f'\n input: {self.data}')
        tf.print(f'\n output: {outputs}')
# Toy training data: 10 sequences of 32 timesteps x 32 features, with integer
# class labels in [0, 10), which is what SparseCategoricalCrossentropy expects.
x_train = tf.random.normal((10, 32, 32))
y_train = tf.random.uniform((10, 1), maxval=10, dtype=tf.int32)

model = tf.keras.models.Sequential()
model.add(tf.keras.layers.LSTM(256, input_shape=(x_train.shape[1], x_train.shape[2]), return_sequences=True))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.LSTM(256))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(5, activation='softmax'))
model.add(tf.keras.layers.Dropout(0.2))
model.add(tf.keras.layers.Dense(10, activation='softmax'))
model.compile(
    optimizer='adam',
    # The final layer already applies softmax, so the model outputs
    # probabilities, not logits.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
)
model.summary()
for layer in model.layers:
    print(layer)

model.fit(x_train, y_train, epochs=1, batch_size=32)

# predict() fires the callback's on_predict_batch_end hook once per batch;
# with batch_size=1 that is once per test sample (5 times here).
x_test = tf.random.normal((5, 32, 32))
yhat = model.predict(x_test, batch_size=1, callbacks=[CustomCallback(np.random.rand(1, 10))])
Output:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
lstm (LSTM) (None, 32, 256) 295936
dropout (Dropout) (None, 32, 256) 0
lstm_1 (LSTM) (None, 256) 525312
dropout_1 (Dropout) (None, 256) 0
dense (Dense) (None, 10) 2570
dropout_2 (Dropout) (None, 10) 0
dense_1 (Dense) (None, 5) 55
dropout_3 (Dropout) (None, 5) 0
dense_2 (Dense) (None, 10) 60
=================================================================
Total params: 823,933
Trainable params: 823,933
Non-trainable params: 0
_________________________________________________________________
<keras.layers.recurrent_v2.LSTM object at 0x7fd9bf3a1690>
<keras.layers.core.dropout.Dropout object at 0x7fd9b74f60d0>
<keras.layers.recurrent_v2.LSTM object at 0x7fd9b32cb090>
<keras.layers.core.dropout.Dropout object at 0x7fd9b325c110>
<keras.layers.core.dense.Dense object at 0x7fd9b31f4690>
<keras.layers.core.dropout.Dropout object at 0x7fd9b315a890>
<keras.layers.core.dense.Dense object at 0x7fd9b32a1050>
<keras.layers.core.dropout.Dropout object at 0x7fd9b3239f50>
<keras.layers.core.dense.Dense object at 0x7fd9b317e1d0>
1/1 [==============================] - 12s 12s/step - loss: 2.3411
batch: 0
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 1
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 2
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 3
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
batch: 4
input: [[0.77380147 0.84481026 0.0211125 0.2323317 0.27449231 0.7934265
0.82050726 0.28287153 0.56995795 0.65609332]]
output: [[0.07206324 0.29164252 0.41107002 0.10761011 0.11761407]]
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.