Visualizing ViT attention maps after fine-tuning on a medical dataset
I imported the ViT-B32 model and fine-tuned it for a classification task on echo images. Now I want to visualize the attention maps so that I can see which part of the image the model focuses on when classifying.
However, I get an error when I try to visualize the attention maps after fine-tuning the model.
Below is the code:
!pip install --quiet vit-keras
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_addons as tfa
from vit_keras import vit
vit_model = vit.vit_b32(
image_size = IMAGE_SIZE,
activation = 'softmax',
pretrained = True,
include_top = False,
pretrained_top = False,
classes = 3)
When I visualize the attention map before any fine-tuning, it works without any error:
from vit_keras import visualize
x = test_gen.next()
image = x[0]
attention_map = visualize.attention_map(model = vit_model, image = image)
# Plot results
fig, (ax1, ax2) = plt.subplots(ncols = 2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)
In the code below I added some classification layers on top of the model and fine-tuned it:
model = tf.keras.Sequential([
vit_model,
tf.keras.layers.Flatten(),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(11, activation = tfa.activations.gelu),
tf.keras.layers.BatchNormalization(),
tf.keras.layers.Dense(3, 'softmax')
],
name = 'vision_transformer')
model.summary()
Below is the output of the above cell:
> Model: "vision_transformer"
> _________________________________________________________________
> Layer (type)                 Output Shape              Param #
> =================================================================
> vit-b32 (Functional)         (None, 768)               87455232
> _________________________________________________________________
> flatten_1 (Flatten)          (None, 768)               0
> _________________________________________________________________
> batch_normalization_2 (Batch (None, 768)               3072
> _________________________________________________________________
> dense_2 (Dense)              (None, 11)                8459
> _________________________________________________________________
> batch_normalization_3 (Batch (None, 11)                44
> _________________________________________________________________
> dense_3 (Dense)              (None, 3)                 36
> =================================================================
> Total params: 87,466,843
> Trainable params: 87,465,285
> Non-trainable params: 1,558
> _________________________________________________________________
Then I trained the model on my own medical dataset:
learning_rate = 1e-4
optimizer = tfa.optimizers.RectifiedAdam(learning_rate = learning_rate)
model.compile(optimizer = optimizer,
loss = tf.keras.losses.CategoricalCrossentropy(label_smoothing = 0.2),
metrics = ['accuracy'])
STEP_SIZE_TRAIN = train_gen.n // train_gen.batch_size
STEP_SIZE_VALID = valid_gen.n // valid_gen.batch_size
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor = 'val_accuracy',
factor = 0.2,
patience = 2,
verbose = 1,
min_delta = 1e-4,
min_lr = 1e-6,
mode = 'max')
earlystopping = tf.keras.callbacks.EarlyStopping(monitor = 'val_accuracy',
min_delta = 1e-4,
patience = 5,
mode = 'max',
restore_best_weights = True,
verbose = 1)
checkpointer = tf.keras.callbacks.ModelCheckpoint(filepath = './model.hdf5',
monitor = 'val_accuracy',
verbose = 1,
save_best_only = True,
save_weights_only = True,
mode = 'max')
callbacks = [earlystopping, reduce_lr, checkpointer]
model.fit(x = train_gen,
steps_per_epoch = STEP_SIZE_TRAIN,
validation_data = valid_gen,
validation_steps = STEP_SIZE_VALID,
epochs = EPOCHS,
callbacks = callbacks)
model.save_weights('model.h5')
After training, when I try to visualize the attention map of the fine-tuned model, I get an error:
from vit_keras import visualize
x = test_gen.next()
image = x[0]
attention_map = visualize.attention_map(model = model, image = image)
# Plot results
fig, (ax1, ax2) = plt.subplots(ncols = 2)
ax1.axis('off')
ax2.axis('off')
ax1.set_title('Original')
ax2.set_title('Attention Map')
_ = ax1.imshow(image)
_ = ax2.imshow(attention_map)
The error is:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-13-f208f2d2b771> in <module>
4 image = x[0]
5
----> 6 attention_map = visualize.attention_map(model = model, image = image)
7
8 # Plot results
/opt/conda/lib/python3.7/site-packages/vit_keras/visualize.py in attention_map(model, image)
14 """
15 size = model.input_shape[1]
---> 16 grid_size = int(np.sqrt(model.layers[5].output_shape[0][-2] - 1))
17
18 # Prepare the input
TypeError: 'NoneType' object is not subscriptable
Please suggest a way to fix this error and visualize the attention maps of the fine-tuned model.
You can visualize the attention maps as follows:
attention_map = visualize.attention_map(model=model.layers[0], image=image)
Since attention_map expects a ViT model as its model argument, you need to pass the first layer of the fine-tuned model, which is defined as a tf.keras.Sequential.
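The traceback in the question is consistent with this. visualize.py reads `model.layers[5].output_shape[0][-2]`, and in the Sequential wrapper index 5 is the final Dense layer, whose output shape (per the summary) is `(None, 3)`, so `output_shape[0]` is `None`. Below is a minimal sketch in plain Python that reproduces the indexing on shape tuples only, without TensorFlow. The ViT shapes are assumptions: they presume a 224x224 input (IMAGE_SIZE is not shown in the question) and that the layer at index 5 of the bare ViT returns a pair of tokens and attention weights, as the indexing in the traceback suggests.

```python
import math

# Output shape of layers[5] in the Sequential wrapper: the final Dense layer,
# taken from the model summary in the question.
sequential_layer5_shape = (None, 3)

# Assumed output shapes of layers[5] in the bare ViT: a transformer block that
# returns [tokens, attention_weights]. The value 50 assumes a 224x224 input
# with 32x32 patches (7 * 7 patches + 1 class token).
vit_layer5_shape = [(None, 50, 768), (None, 12, 50, 50)]


def grid_size_from(output_shape):
    """Mirror the indexing on line 16 of visualize.py from the traceback."""
    return int(math.sqrt(output_shape[0][-2] - 1))


try:
    grid_size_from(sequential_layer5_shape)
    error_message = None
except TypeError as exc:
    # Indexing None, exactly as in the question's traceback.
    error_message = str(exc)

print(error_message)
print(grid_size_from(vit_layer5_shape))
```

Passing `model.layers[0]` avoids the problem because the indexing then runs against the ViT's own layer list rather than the wrapper's.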
Maybe it's a little late, but I have a working solution.
I have the path to the image in a string, open it with the OpenCV library, and have previously loaded a fine-tuned ViT model.
I think you only need to use the get_layer method to select your ViT: since the whole ViT is used inside your Sequential model, it behaves as a single layer.
import cv2
import matplotlib.pyplot as plt
from vit_keras import visualize
path='/content/drive/MyDrive/TFM/Harvard_procesado/ISIC_0025612.jpg'
img = cv2.imread(path)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
res = cv2.resize(img, dsize=(224,224), interpolation=cv2.INTER_CUBIC)
attention_map1 = visualize.attention_map(model = vit_model_t.get_layer('vit_model'), image = res)
fig = plt.figure(figsize=(20,20))
ax = plt.subplot(1, 2, 1)
ax.axis('off')
ax.set_title('Original')
_ = ax.imshow(res)
ax = plt.subplot(1, 2, 2)
ax.axis('off')
ax.set_title('Attention Map')
_ = ax.imshow(attention_map1)
I hope this helps.
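The get_layer approach depends on knowing the ViT layer's name, and `model.layers[0]` depends on its position. If neither is convenient, one option is to search the wrapper for the first layer that is itself a model. The sketch below is only an illustration: `find_nested_model` is a hypothetical helper, not part of vit_keras or Keras, and it relies on the assumption that Keras models expose a `layers` attribute while ordinary layers generally do not. Stand-in objects replace the real model so that it runs without TensorFlow.

```python
from types import SimpleNamespace


def find_nested_model(model):
    """Return the first layer of `model` that itself exposes a `layers` list.

    Hypothetical helper: assumes nested models have a `layers` attribute
    and plain layers do not.
    """
    for layer in model.layers:
        if hasattr(layer, "layers"):
            return layer
    raise ValueError("no nested model found in model.layers")


# Stand-in objects that imitate the structure in the question's summary,
# so the sketch runs without TensorFlow installed.
vit_stub = SimpleNamespace(name="vit-b32", layers=[])
wrapper_stub = SimpleNamespace(
    name="vision_transformer",
    layers=[
        vit_stub,
        SimpleNamespace(name="flatten_1"),
        SimpleNamespace(name="dense_3"),
    ],
)

found = find_nested_model(wrapper_stub)
print(found.name)
```

With the real fine-tuned model, the result would be passed on as `visualize.attention_map(model=find_nested_model(model), image=image)`. Going by the summary in the question, `model.get_layer('vit-b32')` should refer to the same layer there.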