
Calculate cosine similarity between 2 words using BERT

I'm trying to calculate the cosine similarity between two given words using BERT, but I get the following error:

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)

on the line:

similarity = torch.cosine_similarity(word1_embedding, word2_embedding)

Below is the code I have so far. Does anyone know what the problem is?

from transformers import BertTokenizer, BertModel
import torch

# Load the BERT model and tokenizer
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the input words
word1 = "cat"
word2 = "dog"
input_ids = torch.tensor([tokenizer.encode(word1, word2, add_special_tokens=True)])

# Get the BERT embeddings for the input words
output = model(input_ids)[0]

# Get the first and second word embeddings
word1_embedding = output[0, 1, :]
word2_embedding = output[0, 2, :]

# Calculate the cosine similarity between the two words
similarity = torch.cosine_similarity(word1_embedding, word2_embedding)
print(similarity)
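The error can be reproduced without BERT. Below is a minimal sketch, assuming random 768-dimensional tensors as stand-ins for the two token vectors: `output[0, 1, :]` is a 1-D tensor, and `torch.cosine_similarity` compares along `dim=1` by default, a dimension that 1-D input does not have. Passing `dim=0`, or adding a batch dimension with `unsqueeze(0)`, avoids the error.

```python
import torch

# Stand-ins for two BERT token vectors: 1-D tensors of hidden size 768
# (random values, used only to illustrate the shapes)
torch.manual_seed(0)
word1_embedding = torch.randn(768)
word2_embedding = torch.randn(768)

# The default dim=1 does not exist for 1-D input, so this raises IndexError
try:
    torch.cosine_similarity(word1_embedding, word2_embedding)
except IndexError as err:
    print("IndexError:", err)

# Fix 1: compare along the only dimension, dim=0, giving a 0-D (scalar) tensor
sim_a = torch.cosine_similarity(word1_embedding, word2_embedding, dim=0)

# Fix 2: add a batch dimension so both inputs are (1, 768) and the default dim=1 works
sim_b = torch.cosine_similarity(word1_embedding.unsqueeze(0), word2_embedding.unsqueeze(0))

# The two values agree up to floating-point rounding
print(sim_a.item(), sim_b.item())
```

Fix 1 returns a scalar, while Fix 2 returns a tensor of shape (1,).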

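You need to change the tokenizer call and the indexing into output that you pass to cosine_similarity. In your code, each embedding is a 1-D tensor, while torch.cosine_similarity compares along dim=1 by default, which causes the IndexError. In addition, tokenizer.encode(word1, word2) encodes the two words as a sentence pair, [CLS] cat [SEP] dog [SEP], so position 2 holds [SEP], not dog.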

If we change word2 from dog to cat, cosine_similarity returns: tensor([1.0000, 1.0000, 1.0000], grad_fn=<SumBackward1>)
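The identical-word result follows from the definition of cosine similarity, written here without the small epsilon PyTorch uses to avoid division by zero: for vectors $\mathbf{a}$ and $\mathbf{b}$,

```latex
\cos(\mathbf{a}, \mathbf{b}) = \frac{\mathbf{a} \cdot \mathbf{b}}{\lVert \mathbf{a} \rVert_2 \, \lVert \mathbf{b} \rVert_2},
\qquad
\cos(\mathbf{a}, \mathbf{a}) = \frac{\lVert \mathbf{a} \rVert_2^2}{\lVert \mathbf{a} \rVert_2^2} = 1
```

Since the same word produces the same token vectors at every position, each of the three scores is 1.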

# !pip install transformers
from transformers import BertTokenizer, BertModel
import torch

# Load the BERT model and tokenizer
model = BertModel.from_pretrained('bert-base-uncased')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Tokenize the two words as a batch (not as a sentence pair)
word1 = "cat"
word2 = "dog"
input_ids = torch.tensor(tokenizer([word1, word2], add_special_tokens=True)['input_ids'])

# Get the BERT embeddings for the input words
output = model(input_ids)[0]

# output has shape (2, 3, 768): one row per word, each tokenized as [CLS] word [SEP]
word1_embedding = output[0]  # shape (3, 768)
word2_embedding = output[1]  # shape (3, 768)

# With 2-D inputs, cosine_similarity compares along dim=1 (the hidden size),
# giving one score per token position: [CLS], word, [SEP]
similarity = torch.cosine_similarity(word1_embedding, word2_embedding)

print(similarity)
# tensor([0.9665, 0.7953, 0.9809], grad_fn=<SumBackward1>)
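The result has three entries because each word is encoded as `[CLS] word [SEP]`, so `output[0]` and `output[1]` are (3, 768) matrices and the comparison runs row by row. The minimal sketch below uses a random tensor of shape (2, 3, 768) as a hypothetical stand-in for the BERT output. It checks the per-row scores against the definition and extracts a single word-level score from position 1. Taking position 1 assumes each word is a single WordPiece token, as "cat" and "dog" are. For words split into several pieces, averaging the word's token vectors is one common alternative, not shown here.

```python
import torch

# Hypothetical stand-in for `model(input_ids)[0]` from the answer:
# 2 words, each tokenized as [CLS] word [SEP] (3 positions), hidden size 768
torch.manual_seed(0)
output = torch.randn(2, 3, 768)

word1_embedding = output[0]  # shape (3, 768)
word2_embedding = output[1]  # shape (3, 768)

# Default dim=1 runs over the hidden size: one score per token position
per_token = torch.cosine_similarity(word1_embedding, word2_embedding)

# The same scores from the definition a.b / (|a| |b|), computed row by row
manual = (word1_embedding * word2_embedding).sum(dim=1) / (
    word1_embedding.norm(dim=1) * word2_embedding.norm(dim=1)
)

# A single word-to-word score: keep only position 1, the word token itself
# (assumes each word is a single WordPiece token)
word_score = torch.cosine_similarity(output[0, 1], output[1, 1], dim=0)

print(per_token.shape, word_score.item())
```

With the real BERT output, `word_score` corresponds to the middle entry of the answer's three-element result.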


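Statement: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost them, please credit this site's URL or the original URL. For any questions, contact: yoyou2525@163.com.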

Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM