
How do I save the embeddings my model creates during training?

I am working on an image similarity problem and want to save the image embeddings the model creates during training. Is there a way to capture the embeddings before they are passed into the loss function?

Here is my model:

import tensorflow as tf
import tensorflow_addons as tfa

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(filters=128, kernel_size=2, padding='same', activation='relu', input_shape=(32,32,3)),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=2),
    tf.keras.layers.Dropout(0.3),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation=None), # No activation on final dense layer
    tf.keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1)) # L2 normalize embeddings
])

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss(),
    metrics = ["accuracy"])

history = model.fit(train_dataset, epochs=3, validation_data=test_dataset)

To be clear, I do not want just the outputs of the final layer shown here; I want to save the final resulting vector that is output by my model.

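You can create a custom callback that, at the end of each training batch, calls your model on a batch of input images, gets the output and saves it. Note that the `batch` argument Keras passes to `on_train_batch_end` is the integer index of the batch, not the input data, so the callback has to be given the images it should embed; here it uses one fixed batch taken from the training set.

import tensorflow as tf
import tensorflow_addons as tfa

class SaveEmbeddingCallback(tf.keras.callbacks.Callback):
    def __init__(self, images):
        super().__init__()
        self.images = images  # fixed batch of input images to embed
        self.embeddings = []  # one (batch_size, 32) array per training step

    def on_train_batch_end(self, batch, logs=None):
        # `batch` is only an index, so embed the stored images instead
        embedding = self.model(self.images, training=False)
        # At this stage you have the output of the model; it is kept in
        # memory here, but you could also write it to disk
        self.embeddings.append(embedding.numpy())


# Assumes train_dataset yields (images, labels) batches
sample_images, _ = next(iter(train_dataset))
embedding_callback = SaveEmbeddingCallback(sample_images)

model.compile(
    optimizer=tf.keras.optimizers.Adam(0.001),
    loss=tfa.losses.TripletSemiHardLoss(),
    metrics = ["accuracy"])

history = model.fit(train_dataset,
    epochs=3,
    validation_data=test_dataset,
    callbacks=[embedding_callback])

Because the callback calls the model with `training=False`, dropout is switched off, so these are inference-mode embeddings, not exactly the vectors the loss received during that step.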
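To actually write collected embeddings to disk, a minimal NumPy sketch could store one array per training step in a single compressed `.npz` file. The helper names `save_embeddings` and `load_embeddings` are hypothetical, and the input is assumed to be a list of 2-D arrays, such as the per-batch list a callback accumulates.

```python
import numpy as np


def save_embeddings(path, batches_per_step):
    """Save per-step embedding batches to one compressed .npz file.

    `batches_per_step` is assumed to be a list of 2-D arrays of shape
    (batch_size, dim), e.g. one entry per training batch.
    """
    # One named array per step: step_0, step_1, ...
    arrays = {f"step_{i}": np.asarray(b) for i, b in enumerate(batches_per_step)}
    np.savez_compressed(path, **arrays)


def load_embeddings(path):
    """Load the arrays written by save_embeddings, in step order."""
    with np.load(path) as data:
        # Sort numerically so step_10 comes after step_9, not after step_1
        keys = sorted(data.files, key=lambda k: int(k.split("_")[1]))
        return [data[k] for k in keys]
```

A single `.npz` keeps both the step order and each step's shape; if holding everything in memory is a concern, you could instead write each batch to its own `.npy` file with `np.save` inside `on_train_batch_end`.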

For more information on Keras custom callbacks, see the TensorFlow guide on writing your own callbacks.
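If you need the embeddings exactly as the loss saw them (training mode, dropout active), one option is to replace `model.fit` with a hand-written training loop built on `tf.GradientTape`, so each step can return the embeddings it passed to the loss. This is a minimal sketch, assuming the dataset yields `(images, labels)` batches; the function name `train_and_record_embeddings` is hypothetical, `loss_fn` could be the `tfa.losses.TripletSemiHardLoss()` instance from the question, and validation and metrics are left out for brevity.

```python
import numpy as np
import tensorflow as tf


def train_and_record_embeddings(model, loss_fn, optimizer, dataset, epochs=1):
    """Train `model` with a custom loop and keep every embedding batch exactly
    as it was passed to `loss_fn` (training mode, so dropout is active).

    Returns a list with one NumPy array of shape (num_examples, dim) per epoch.
    """

    @tf.function
    def train_step(images, labels):
        with tf.GradientTape() as tape:
            embeddings = model(images, training=True)  # what the loss sees
            loss = loss_fn(labels, embeddings)
        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss, embeddings

    per_epoch = []
    for epoch in range(epochs):
        batches = []
        for images, labels in dataset:
            loss, embeddings = train_step(images, labels)
            batches.append(embeddings.numpy())  # copy this step's embeddings to host memory
        per_epoch.append(np.concatenate(batches, axis=0))
        print(f"epoch {epoch + 1}: last batch loss = {float(loss):.4f}")
    return per_epoch
```

Returning the embeddings from the `tf.function` step keeps the training step compiled while still handing the exact loss inputs back to Python; if the dataset is shuffled, the rows in each epoch's array follow that epoch's shuffled order.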
