
Is it possible to visualize Keras embeddings in TensorBoard?

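Keras can export some of its training data in a format TensorBoard can read, using keras.callbacks.TensorBoard.

However, it does not support visualizing embeddings in TensorBoard.

Is there a way around this?

Found a solution: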

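import os

import keras
import tensorflow as tf
# TensorFlow 1.x API (tf.contrib was removed in TensorFlow 2.0)
from tensorflow.contrib.tensorboard.plugins import projector

ROOT_DIR = '/tmp/tfboard'
os.makedirs(ROOT_DIR, exist_ok=True)

OUTPUT_MODEL_FILE_NAME = os.path.join(ROOT_DIR, 'tf.ckpt')

# get the keras model (get_model() is your own model-building function)
model = get_model()
# get the variable name of the weights of the layer named 'embedding'
# (Keras 1.x exposes them as layer.W, Keras 2.x as layer.embeddings)
embedding_layer = model.get_layer('embedding')
tensor_name = embedding_layer.embeddings.name

# the vocabulary; use a plain file name, since tensor_name may contain '/' and ':'
metadata_file_name = os.path.join(ROOT_DIR, 'metadata.tsv')

embedding_df = get_embedding()  # your own DataFrame, indexed by token
# write only the index (one token per line, no header)
embedding_df.to_csv(metadata_file_name, header=False, columns=[])

# save all session variables, including the embedding matrix, to a checkpoint
saver = tf.train.Saver()
saver.save(keras.backend.get_session(), OUTPUT_MODEL_FILE_NAME)

# tf.train.SummaryWriter was renamed tf.summary.FileWriter in TensorFlow 1.0
summary_writer = tf.summary.FileWriter(ROOT_DIR)

config = projector.ProjectorConfig()
embedding = config.embeddings.add()
embedding.tensor_name = tensor_name
embedding.metadata_path = metadata_file_name
# writes projector_config.pbtxt into ROOT_DIR, next to the checkpoint
projector.visualize_embeddings(summary_writer, config)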
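The projector matches line i of the metadata file to row i of the embedding matrix, and a single-column metadata file has no header line. If the vocabulary comes from a dict rather than a DataFrame, a sketch along these lines may help; `write_vocab_metadata` is a hypothetical helper (not part of Keras), and it assumes a `word_index` dict like the one `keras.preprocessing.text.Tokenizer` builds, where row 0 is reserved for padding.

```python
def write_vocab_metadata(word_index, path, num_rows=None, padding_label="<PAD>"):
    """Write one label per line so that line i labels row i of the embedding matrix.

    word_index is a hypothetical {token: row} dict such as Tokenizer.word_index.
    num_rows should equal the Embedding layer's input_dim; by default it is
    inferred as the largest index + 1.
    """
    if num_rows is None:
        num_rows = max(word_index.values()) + 1
    # Row 0 (reserved for padding) and any row without a token keep the placeholder.
    labels = [padding_label] * num_rows
    for token, row in word_index.items():
        if row < num_rows:  # Tokenizer.word_index also lists words beyond num_words
            # A tab or newline inside a token would break the TSV layout.
            labels[row] = token.replace("\t", " ").replace("\n", " ")
    # Single-column metadata: no header line.
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(labels) + "\n")
    return num_rows
```

Passing `num_rows` explicitly (the layer's `input_dim`) keeps the line count equal to the number of embedding rows even when the tokenizer knows more words than the model uses.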

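There is a pull request for this feature, https://github.com/fchollet/keras/pull/5247, which extends the callback to create visualizations for specific embedding layers.

This can now be done directly with the keras.callbacks.TensorBoard callback: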

from keras import callbacks

# x_train, y_train, x_test and batch_size are assumed to be defined already
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=10,
          callbacks=[
              callbacks.TensorBoard(
                  batch_size=batch_size,
                  embeddings_freq=3,  # Store embeddings every 3 epochs (this can be time consuming)
                  embeddings_layer_names=['fc1', 'fc2'],  # Embeddings are taken from layers with names fc1 and fc2
                  embeddings_metadata='metadata.tsv',  # This file will describe the embeddings data (see below)
                  embeddings_data=x_test),  # Data used for the embeddings
          ])
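The projector pairs the data rows of metadata.tsv with the rows of `embeddings_data`, so the two counts should be equal. A quick sanity check before calling `fit` could look like the sketch below; `metadata_row_count` is a hypothetical helper, and it assumes the projector's convention that a header line is present only when the file has more than one column.

```python
def metadata_row_count(path):
    """Count the data rows in a projector metadata TSV file.

    Assumes a header line exists only for multi-column files,
    i.e. when the first line contains a tab.
    """
    with open(path, encoding="utf-8") as f:
        lines = f.read().splitlines()
    while lines and lines[-1] == "":  # ignore trailing blank lines
        lines.pop()
    if not lines:
        return 0
    has_header = "\t" in lines[0]
    return len(lines) - 1 if has_header else len(lines)
```

For example, `assert metadata_row_count('metadata.tsv') == len(x_test)` would catch a mismatch before any training time is spent.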


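# Write metadata.tsv before the model is trained (only the true labels are known yet).
# class_names is your list of label strings; y_test holds one-hot encoded labels.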
with open("metadata.tsv", 'w') as f:
    f.write("label\tidx\n")
    f.write('\n'.join(["{}\t{}".format(class_names[int(y.argmax())], i)
                       for i, y in enumerate(y_test)]))


# After the model is trained, you can update the metadata file to include more information, such as the predicted labels and the mistakes:
y_pred = model.predict(x_test)
with open("metadata.tsv", 'w') as f:
    f.write("label\tidx\tpredicted\tcorrect\n")
    f.write('\n'.join(["{}\t{}\t{}\t{}".format(class_names[int(y.argmax())],
                                               i,
                                               class_names[int(y_pred[i].argmax())],
                                               class_names[int(y.argmax())]==class_names[int(y_pred[i].argmax())])
                       for i, y in enumerate(y_test)]))
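The two snippets above differ only in their extra columns. One possible way to fold them into a single function is sketched below; `write_prediction_metadata` and `argmax` are hypothetical helpers, written in plain Python so they accept lists as well as NumPy arrays, and the column layout mirrors the original snippets.

```python
def argmax(values):
    """Index of the largest value (the first one on ties), like numpy.argmax."""
    return max(range(len(values)), key=lambda i: values[i])


def write_prediction_metadata(path, y_true, class_names, y_pred=None):
    """Write a multi-column metadata TSV (header line first).

    y_true and the optional y_pred are sequences of per-class scores
    (e.g. one-hot labels and softmax outputs).
    """
    header = ["label", "idx"]
    if y_pred is not None:
        header += ["predicted", "correct"]
    rows = ["\t".join(header)]
    for i, y in enumerate(y_true):
        label = class_names[argmax(y)]
        fields = [label, str(i)]
        if y_pred is not None:
            predicted = class_names[argmax(y_pred[i])]
            fields += [predicted, str(label == predicted)]
        rows.append("\t".join(fields))
    with open(path, "w", encoding="utf-8") as f:
        f.write("\n".join(rows) + "\n")
```

Call it once without `y_pred` before training and again with `model.predict(x_test)` afterwards; the file keeps the same number of rows, so the projector's pairing with `embeddings_data` is unchanged.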

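Note: TensorBoard usually looks for metadata.tsv in the logs directory. If it cannot find the file, it tells you which path it looked in; you can copy the file there and refresh TensorBoard.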
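Copying the file can be scripted so it happens right after the metadata is written. This is a minimal sketch; `copy_metadata_to_logdir` is a hypothetical helper, and `./logs` is assumed only because it is the default `log_dir` of the Keras TensorBoard callback, so pass whatever path TensorBoard actually reports.

```python
import os
import shutil


def copy_metadata_to_logdir(metadata_path, log_dir="./logs"):
    """Copy the metadata file into log_dir and return the destination path."""
    os.makedirs(log_dir, exist_ok=True)
    destination = os.path.join(log_dir, os.path.basename(metadata_path))
    # Skip the copy if the file already lives in log_dir.
    if os.path.abspath(destination) != os.path.abspath(metadata_path):
        shutil.copyfile(metadata_path, destination)
    return destination
```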


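Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you reproduce them, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM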