[英]How to dump confusion matrix using TensorBoard logger in pytorch-lightning?
官方文檔只給出了以下範例:
>>> from pytorch_lightning.metrics import ConfusionMatrix
>>> target = torch.tensor([1, 1, 0, 0])
>>> preds = torch.tensor([0, 1, 0, 0])
>>> confmat = ConfusionMatrix(num_classes=2)
>>> confmat(preds, target)
這沒有顯示如何在框架中使用度量。
我的嘗試(方法不完整,只顯示相關部分):
def __init__(...):
    # Asker's abridged constructor: "..." stands for the omitted parameters, so
    # this snippet is not runnable as written.
    # One stateful confusion-matrix metric, with n_clusters used as the number
    # of classes, shared by every validation step.
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
def validation_step(self, batch, batch_index):
    # Asker's abridged repro; "..." is elided code that presumably unpacks
    # `batch` into orig_batch / label_batch.
    ...
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
    # update() only accumulates the metric's state; it does not go through the
    # metric's forward(), which is what fills the per-step cached value.
    self.val_confusion.update(log_probs, label_batch)
    # BUG (the reported crash): logging the Metric object with on_step=True
    # makes Lightning read its cached forward() result
    # (`self[k]._forward_cache.detach()` in the traceback). Because only
    # update() was called, that cache is None, hence
    # "'NoneType' object has no attribute 'detach'".
    # The confusion matrix is also a 2-D tensor, not a scalar, so it is better
    # drawn as a figure/image through self.logger.experiment than sent to self.log.
    self.log('validation_confusion_step', self.val_confusion, on_step=True, on_epoch=False)
def validation_step_end(self, outputs):
    # Pass-through hook: returns the validation_step output unchanged.
    # The error seen "after this returns" is raised later, in Lightning's
    # step-metric logging (log_evaluation_step_metrics in the traceback), not here.
    return outputs
def validation_epoch_end(self, outs):
    # NOTE(review): compute() yields an n_clusters x n_clusters matrix rather
    # than a scalar, so self.log is not a suitable sink for it.
    # The metric is also never reset() here, so its counts presumably carry
    # over into the next epoch — confirm.
    self.log('validation_confusion_epoch', self.val_confusion.compute())
在第 0 個 epoch 之後,出現了以下錯誤:
Traceback (most recent call last):
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 521, in train
self.train_loop.run_training_epoch()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\training_loop.py", line 588, in run_training_epoch
self.trainer.run_evaluation(test_mode=False)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\trainer.py", line 613, in run_evaluation
self.evaluation_loop.log_evaluation_step_metrics(output, batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 346, in log_evaluation_step_metrics
self.__log_result_step_metrics(step_log_metrics, step_pbar_metrics, batch_idx)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\evaluation_loop.py", line 350, in __log_result_step_metrics
cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 378, in update_logger_connector
batch_log_metrics = self.get_latest_batch_log_metrics()
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 418, in get_latest_batch_log_metrics
batch_log_metrics = self.run_batch_from_func_name("get_batch_log_metrics")
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in run_batch_from_func_name
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 414, in <listcomp>
results = [func(include_forked_originals=False) for func in results]
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 122, in get_batch_log_metrics
return self.run_latest_batch_metrics_with_func_name("get_batch_log_metrics",
*args, **kwargs)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in run_latest_batch_metrics_with_func_name
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 115, in <listcomp>
for dl_idx in range(self.num_dataloaders)
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\trainer\connectors\logger_connector\epoch_result_store.py", line 100, in get_latest_from_func_name
results.update(func(*args, add_dataloader_idx=add_dataloader_idx, **kwargs))
File "C:\code\EPMD\Kodex\Templates\Testing\venv\lib\site-packages\pytorch_lightning\core\step_result.py", line 298, in get_batch_log_metrics
result[dl_key] = self[k]._forward_cache.detach()
AttributeError: 'NoneType' object has no attribute 'detach'
它確實在訓練前通過了健全性驗證檢查。
失敗發生在 validation_step_end 返回之後,這讓我很困惑。
以完全相同的方式使用 metrics 來計算 accuracy,則可以正常運作。
如何獲得正確的混淆矩陣?
您可以使用 `self.logger.experiment.add_figure(tag, figure)` 來記錄該圖。
變量 `self.logger.experiment` 實際上是一個 `SummaryWriter`(來自 PyTorch,不是 Lightning)。此類別具有 `add_figure` 方法(參見 PyTorch 的 SummaryWriter 文檔)。
您可以按如下方式使用它: (MNIST 示例)
def validation_step(self, batch, batch_idx):
    """Score one validation batch and keep what the epoch-end hook needs.

    Returns a dict holding the NLL loss (``'loss'``), the raw model output
    (``'preds'``) and the ground-truth labels (``'target'``); the epoch-end
    hook concatenates ``'preds'`` and ``'target'`` across all batches.
    """
    inputs, labels = batch
    model_out = self(inputs)
    return {
        'loss': F.nll_loss(model_out, labels),
        'preds': model_out,
        'target': labels,
    }
def validation_epoch_end(self, outputs):
    """Plot the epoch's validation confusion matrix and log it to TensorBoard.

    ``outputs`` is the list of dicts returned by ``validation_step``; their
    ``'preds'`` (one row of per-class scores per sample) and ``'target'``
    tensors are concatenated over the whole epoch before counting.
    """
    preds = torch.cat([tmp['preds'] for tmp in outputs])
    targets = torch.cat([tmp['target'] for tmp in outputs])
    # Follow the model's output width instead of hard-coding 10 classes.
    num_classes = preds.shape[1]
    # Reduce the scores to label indices explicitly, so the result does not
    # depend on whether the installed metrics version argmaxes scores itself.
    pred_labels = preds.argmax(dim=1)
    conf_mat = pl.metrics.functional.confusion_matrix(
        pred_labels, targets, num_classes=num_classes
    )
    # .cpu() is required before .numpy() when validating on a GPU. Casting to
    # int lets the cells be printed as whole counts (fmt='d') instead of the
    # default '.2g', which turns counts >= 100 into scientific notation.
    counts = conf_mat.detach().cpu().numpy().astype(int)
    df_cm = pd.DataFrame(counts, index=range(num_classes), columns=range(num_classes))
    fig, ax = plt.subplots(figsize=(10, 7))
    sns.heatmap(df_cm, annot=True, fmt='d', cmap='Spectral', ax=ax)
    # add_figure renders the figure to an image and closes it (close=True by
    # default), so no pyplot figure stays open across epochs.
    self.logger.experiment.add_figure("Confusion matrix", fig, self.current_epoch)
這花了很多時間才找到。
這是我可以粘貼的最小的代碼,它仍然是可讀和可重現的。
我不想把整個 model 數據集和參數放在這里,因為他們對這個問題的讀者沒有興趣,只是噪音。
也就是說,以下是為每個 epoch 建立混淆矩陣並在 TensorBoard 中顯示所需的程式碼。
這是一個單幀,例如:
import pytorch_lightning as pl
import seaborn as sn
import pandas as pd
import numpy as np
import io
import matplotlib.pyplot as plt
from PIL import Image
def __init__(self, config, trained_vae, latent_dim):
    # Abridged by the answer's author: only the confusion-matrix lines are shown.
    # NOTE(review): self._config is read below but is not assigned in the
    # visible code, and config / trained_vae / latent_dim are unused here.
    # Presumably the elided part calls super().__init__() and sets
    # self._config = config first — confirm.
    # Stateful metric: validation_step accumulates into it with update(), and
    # validation_epoch_end reads it with compute().
    self.val_confusion = pl.metrics.classification.ConfusionMatrix(num_classes=self._config.n_clusters)
    # NOTE(review): newer Lightning releases expose `logger` as a read-only
    # property of LightningModule, where this assignment may fail — verify
    # against the installed version.
    self.logger: Optional[TensorBoardLogger] = None
def forward(self, x):
    # Model body elided by the answer's author ("..."). As shown, log_probs is
    # never assigned, so the real implementation must compute it.
    # Presumably log-probabilities (per the name): the validation code feeds
    # the result to self._criterion and to the confusion-matrix metric.
    ...
    return log_probs
def validation_step(self, batch, batch_index):
    """Run one validation batch and feed the epoch-level confusion matrix.

    Args:
        batch: an ``(inputs, labels)`` pair from the validation dataloader
            (assumed to hold for every dataset, not only MNIST — confirm).
        batch_index: index of the batch within the epoch (unused).

    Returns:
        dict with the batch ``"loss"`` and the ground-truth ``"labels"``.
    """
    # Unpack for every dataset and make only the MNIST flattening conditional.
    # Otherwise non-MNIST datasets would reach self.forward() with orig_batch
    # unbound.
    orig_batch, label_batch = batch
    if self._config.dataset == "mnist":
        # Flatten each 28x28 image into a 784-long vector.
        orig_batch = orig_batch.reshape(-1, 28 * 28)
    log_probs = self.forward(orig_batch)
    loss = self._criterion(log_probs, label_batch)
    # Only accumulate here; the matrix is computed once per epoch.
    self.val_confusion.update(log_probs, label_batch)
    return {"loss": loss, "labels": label_batch}
def validation_step_end(self, outputs):
    # Pass-through hook: hands the validation_step output back unchanged.
    return outputs
def validation_epoch_end(self, outs):
    """Render the epoch's accumulated confusion matrix to TensorBoard.

    Reads the counts that ``self.val_confusion`` accumulated during the epoch,
    draws them as an annotated heatmap, and logs the figure as the image
    ``"val_confusion_matrix"`` at step ``self.current_epoch``. The metric is
    then reset so the next epoch starts from zero counts.
    """
    tb = self.logger.experiment
    # np.int was removed in NumPy 1.24; the builtin int gives the same dtype.
    conf_mat = self.val_confusion.compute().detach().cpu().numpy().astype(int)
    # compute() does not clear the metric's state. Without reset(), counts from
    # earlier epochs (and the sanity-check batches) keep being summed in.
    self.val_confusion.reset()
    n_classes = self._config.n_clusters
    df_cm = pd.DataFrame(
        conf_mat,
        index=np.arange(n_classes),
        columns=np.arange(n_classes))
    # Apply the seaborn theme before creating the axes so they pick it up.
    sn.set(font_scale=1.2)
    fig, ax = plt.subplots()
    sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
    # add_figure rasterises the figure itself and closes it (close=True by
    # default). This replaces the JPEG round trip through PIL and torchvision
    # (torchvision was never imported in this snippet) and stops one pyplot
    # figure from leaking per epoch.
    tb.add_figure("val_confusion_matrix", fig, global_step=self.current_epoch)
以及建立 Trainer 的程式碼:
# TensorBoard event files go to <tb_logs_folder>/Classifier.
logger = TensorBoardLogger(
    save_dir=tb_logs_folder,
    name='Classifier',
)
# Deterministic 10-epoch run on one GPU, logging through the logger above.
# NOTE: `gpus=` is the pre-2.0 Lightning spelling; Lightning 2.x expects
# accelerator="gpu", devices=1 instead.
trainer = Trainer(
    logger=logger,
    gpus=1,
    max_epochs=10,
    deterministic=True,
    default_root_dir=classifier_checkpoints_path,
)
class IntHandler:
    """Matplotlib legend handler that draws an integer legend handle as text.

    Meant for ``handler_map={int: IntHandler()}``: the handle itself (an int)
    is converted to a string and drawn in place of a legend marker.
    """

    def legend_artist(self, legend, orig_handle, fontsize, handlebox):
        # Put the text at the handle box's (xdescent, ydescent) offset.
        label_artist = plt.matplotlib.text.Text(
            handlebox.xdescent,
            handlebox.ydescent,
            str(orig_handle),
        )
        handlebox.add_artist(label_artist)
        return label_artist
class LightningClassifier(LightningModule):
    """Classifier whose validation epoch logs a labelled confusion matrix.

    Only the members relevant to the confusion-matrix plot are shown; the rest
    of the class (constructor, ``forward``, ``_common_log``, ...) was elided by
    the answer's author.
    NOTE: Lightning 2.0 removed the ``validation_epoch_end`` hook (replaced by
    ``on_validation_epoch_end``); this code targets the older API.
    """

    ...  # other members elided by the answer's author

    def _common_step(self, batch, batch_nb, stage: str):
        """Shared forward pass and loss for the train/val/test steps.

        Returns ``(outputs, labels, loss)``. Only the first of the two values
        returned by ``self(...)`` is scored.
        """
        assert stage in ("train", "val", "test")
        augmented_image, labels = batch
        # self(...) returns a pair; the second element (aux outputs) is unused here.
        outputs, _aux_outputs = self(augmented_image)
        loss = self._criterion(outputs, labels)
        return outputs, labels, loss

    def validation_step(self, batch, batch_nb):
        """Compute and log the validation loss; keep outputs/labels for the epoch end."""
        stage = "val"
        outputs, labels, loss = self._common_step(batch, batch_nb, stage=stage)
        self._common_log(loss, stage=stage)
        return {"loss": loss, "outputs": outputs, "labels": labels}

    def validation_epoch_end(self, outs):
        """Log the epoch's validation confusion matrix to TensorBoard as an image.

        ``outs`` is the list of dicts returned by ``validation_step``. A fresh
        confusion matrix is computed from the concatenated epoch outputs and
        drawn as an annotated heatmap, with a legend mapping each class index
        to its name. It is logged under ``"val_confusion_matrix"``.
        """
        # See https://github.com/Lightning-AI/metrics/blob/ff61c482e5157b43e647565fa0020a4ead6e9d61/docs/source/pages/lightning.rst
        # The linked page warns that, inside Lightning, a metric's state can end
        # up updated more than once per forward pass, leading to wrong
        # accumulation. To sidestep that, a fresh metric is built here from all
        # of the epoch's outputs.
        tb = self.logger.experiment  # noqa
        outputs = torch.cat([tmp['outputs'] for tmp in outs])
        labels = torch.cat([tmp['labels'] for tmp in outs])
        # Use .device, not .get_device(): get_device() returns -1 for CPU
        # tensors, and -1 is not a valid target for .to().
        # NOTE(review): torchmetrics >= 0.11 also requires task="multiclass" here.
        confusion = torchmetrics.ConfusionMatrix(num_classes=self.n_labels).to(outputs.device)
        confusion(outputs, labels)
        computed_confusion = confusion.compute().detach().cpu().numpy().astype(int)
        # Rows/columns are labelled with class indices; the legend maps each
        # index back to its name.
        label_indices = list(self._label_ind_by_names.values())
        label_names = list(self._label_ind_by_names.keys())
        df_cm = pd.DataFrame(
            computed_confusion,
            index=label_indices,
            columns=label_indices,
        )
        fig, ax = plt.subplots(figsize=(10, 5))
        # Leave room on the right for the legend placed outside the axes.
        fig.subplots_adjust(left=0.05, right=.65)
        sn.set(font_scale=1.2)
        sn.heatmap(df_cm, annot=True, annot_kws={"size": 16}, fmt='d', ax=ax)
        ax.legend(
            label_indices,
            label_names,
            handler_map={int: IntHandler()},
            loc='upper left',
            bbox_to_anchor=(1.2, 1),
        )
        buf = io.BytesIO()
        # bbox_inches='tight' keeps the out-of-axes legend inside the saved image.
        fig.savefig(buf, format='jpeg', bbox_inches='tight')
        # Close explicitly; otherwise pyplot keeps one open figure per epoch.
        plt.close(fig)
        buf.seek(0)
        im = Image.open(buf)
        im = torchvision.transforms.ToTensor()(im)
        tb.add_image("val_confusion_matrix", im, global_step=self.current_epoch)
output:
也是基於此
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.