How to visualize the attention vector with a text heatmap?

I am working on an NLP research project and I want to visualize the output of the attention vector.

For example, the sample data is generated like this:

def sample_data():
    
    sent = '''the USS Ronald Reagan - an aircraft carrier docked in Japan - during his tour of the region, vowing to "defeat any attack and meet any use of conventional or nuclear weapons with an overwhelming and effective American response".'''

    words    = sent.split()
    word_num = len(words)
    attention = [(x+1.)/word_num*100 for x in range(word_num)]
    
    return {'text': words, 'attention': attention}

which returns:

{'text': ['the', 'USS', 'Ronald', 'Reagan', '-', 'an', 'aircraft', 'carrier', 'docked', 'in', 'Japan', '-', 'during', 'his', 'tour', 'of', 'the', 'region,', 'vowing', 'to', '"defeat', 'any', 'attack', 'and', 'meet', 'any', 'use', 'of', 'conventional', 'or', 'nuclear', 'weapons', 'with', 'an', 'overwhelming', 'and', 'effective', 'American', 'response".'], 'attention': [2.564102564102564, 5.128205128205128, 7.6923076923076925, 10.256410256410255, 12.82051282051282, 15.384615384615385, 17.94871794871795, 20.51282051282051, 23.076923076923077, 25.64102564102564, 28.205128205128204, 30.76923076923077, 33.33333333333333, 35.8974358974359, 38.46153846153847, 41.02564102564102, 43.58974358974359, 46.15384615384615, 48.717948717948715, 51.28205128205128, 53.84615384615385, 56.41025641025641, 58.97435897435898, 61.53846153846154, 64.1025641025641, 66.66666666666666, 69.23076923076923, 71.7948717948718, 74.35897435897436, 76.92307692307693, 79.48717948717949, 82.05128205128204, 84.61538461538461, 87.17948717948718, 89.74358974358975, 92.3076923076923, 94.87179487179486, 97.43589743589743, 100.0]}

Each token is assigned one float value (its attention score). What are the options for visualizing this data? Are there any libraries or tools available in R, Python, or JavaScript?

Answer:

A solution that handles long sentences well is to print a colored sentence in the console. You can do so by printing ANSI escape sequences: \033[38;2;255;0;0m test \033[0m prints test in red (RGB code (255, 0, 0)), and the trailing \033[0m resets the color. The 38;2;R;G;B form is a 24-bit ("true color") code, so it needs a terminal that supports true color.
Using this idea, we can build a gradient from green (low attention) to red (high attention) and print the text:

import numpy as np

data = sample_data()

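def colorFader(c1, c2, mix=0):
    # Linearly interpolate between RGB colors c1 (mix=0) and c2 (mix=1)
    return (1 - mix) * np.array(c1) + mix * np.array(c2)

def colored(c, text):
    # Wrap text in a 24-bit ANSI foreground color code, then reset the formatting
    return "\033[38;2;{};{};{}m{} \033[0m".format(int(c[0]), int(c[1]), int(c[2]), text)

# Scale attention scores to [0, 1] by dividing by the maximum score
normalizer = max(data["attention"])
output = ""
for word, attention in zip(data["text"], data["attention"]):
    # Green for low attention, red for high attention
    color = colorFader([0, 255, 0], [255, 0, 0], mix=attention/normalizer)
    output += colored(color, word)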

print(output)

This solution prints something like the following in the console: [image: colored console output]
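As a variation on the loop above, the sketch below computes the same green-to-red gradient in plain Python (no NumPy) and replaces divide-by-max with min-max scaling, so the lowest score is pure green and the highest pure red. The min-max choice is an assumption of this sketch, not part of the original answer. It matters when scores sit in a narrow range: if all scores lie between 0.02 and 0.05, dividing by the maximum maps the smallest score to 0.4 instead of 0. The function names color_fader and colorize_minmax are hypothetical.

```python
def color_fader(c1, c2, mix):
    # Linear interpolation between two RGB triples: mix=0 gives c1, mix=1 gives c2
    return tuple(round((1 - mix) * a + mix * b) for a, b in zip(c1, c2))


def colorize_minmax(words, scores, low=(0, 255, 0), high=(255, 0, 0)):
    # Min-max scaling: the lowest score maps to `low`, the highest to `high`
    lo, hi = min(scores), max(scores)
    span = (hi - lo) or 1.0  # avoid dividing by zero when all scores are equal
    parts = []
    for word, score in zip(words, scores):
        r, g, b = color_fader(low, high, (score - lo) / span)
        # 24-bit ANSI foreground color for the word, then reset the formatting
        parts.append(f"\033[38;2;{r};{g};{b}m{word}\033[0m")
    return " ".join(parts)


if __name__ == "__main__":
    print(colorize_minmax(["the", "USS", "Ronald", "Reagan"], [0.02, 0.03, 0.04, 0.05]))
```

With the question's data you would call print(colorize_minmax(data["text"], data["attention"])).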
I find this effective as a visualizer, but doing the visualization in the console can be a limitation: the colors only appear in a terminal that renders ANSI escape codes.
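Because console colors only render in a terminal, one alternative is to emit the same gradient as HTML and open it in a browser or a notebook. The sketch below is an assumption on my part rather than a library feature: the function attention_to_html is hypothetical. It colors each word's background, divides by the maximum score as the console loop does, and escapes each word with the standard-library html.escape so tokens such as "defeat remain valid HTML. Each span also gets a title attribute, so hovering over a word shows its raw score.

```python
import html


def attention_to_html(words, scores, low=(0, 255, 0), high=(255, 0, 0)):
    """Return an HTML fragment where each word's background reflects its score."""
    top = max(scores) or 1.0  # divide by the maximum; guard against an all-zero list
    spans = []
    for word, score in zip(words, scores):
        mix = score / top
        # Linear interpolation from `low` (mix=0) to `high` (mix=1)
        r, g, b = (round((1 - mix) * a + mix * c) for a, c in zip(low, high))
        spans.append(
            f'<span style="background-color: rgb({r}, {g}, {b})" '
            f'title="{score:.3f}">{html.escape(word)}</span>'
        )
    return " ".join(spans)


if __name__ == "__main__":
    words = ["the", "USS", "Ronald", "Reagan", '"defeat']
    scores = [2.56, 5.13, 7.69, 10.26, 12.82]
    print("<p>" + attention_to_html(words, scores) + "</p>")
```

In Jupyter, from IPython.display import HTML, display followed by display(HTML(attention_to_html(data["text"], data["attention"]))) renders the fragment inline. Outside a notebook, you can write the fragment into a .html file.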

Another option is a heatmap:

import matplotlib.pyplot as plt

data = sample_data()

# Create a pyplot figure
fig, ax = plt.subplots(1, 1)
# Creating the heatmap image with the <plasma> colormap
img = ax.imshow([data["attention"]], cmap='plasma', aspect='auto', extent=[-1,1,-1,1])
# Setting the x_ticks position to be in the middle of the corresponding color
ax.set_xticks([-1 + (i+0.5)*2/len(data["text"]) for i in range(len(data["text"]))])
# Setting the x_ticks labels as the text, rotated to 80° for space purpose
ax.set_xticklabels(data["text"], rotation=80)
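# Hide the meaningless y-axis ticks and add a colorbar showing the attention scale
ax.set_yticks([])
fig.colorbar(img, ax=ax)
# Leave room for the rotated tick labels, then display the heatmap
plt.tight_layout()
plt.show()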

This gives the following result, which you can customize by changing parameters such as the figure height, width, and colormap: [image: attention heatmap]

For long sentences, though, this might not be the best solution, because the tick labels become crowded and harder to read.
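One way around the crowded tick labels is to wrap the tokens into several rows and write each word inside its own cell instead of on the x-axis. This is a sketch under my own assumptions: the function wrapped_heatmap and its per_row parameter are hypothetical, not part of matplotlib. Padding cells are filled with NaN and masked, so they render blank, because matplotlib's default "bad" color for masked values is transparent. Text color switches from white to black at the bright end of the plasma colormap so the words stay readable.

```python
import math

import matplotlib.pyplot as plt
import numpy as np


def wrapped_heatmap(words, scores, per_row=10, cmap="plasma"):
    """Draw tokens as a grid of colored cells, `per_row` tokens per row."""
    n_rows = math.ceil(len(words) / per_row)
    # Pad the scores with NaN so they fill a full n_rows x per_row grid
    grid = np.full(n_rows * per_row, np.nan)
    grid[: len(scores)] = scores
    grid = grid.reshape(n_rows, per_row)

    fig, ax = plt.subplots(figsize=(per_row * 1.2, n_rows * 0.6))
    # Masked (padding) cells use the colormap's "bad" color, transparent by default
    img = ax.imshow(np.ma.masked_invalid(grid), cmap=cmap, aspect="auto")

    lo, hi = min(scores), max(scores)
    for i, (word, score) in enumerate(zip(words, scores)):
        row, col = divmod(i, per_row)
        # Dark text on the bright (high-score) end of the colormap, light text on the dark end
        shade = (score - lo) / ((hi - lo) or 1.0)
        ax.text(col, row, word, ha="center", va="center", fontsize=8,
                color="black" if shade > 0.5 else "white")

    ax.set_xticks([])
    ax.set_yticks([])
    fig.colorbar(img, ax=ax)
    return fig, ax, grid
```

With the question's data, fig, ax, grid = wrapped_heatmap(data["text"], data["attention"], per_row=8) followed by plt.show() should give a compact multi-row heatmap. Tune per_row to the sentence length and figure width.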
