
Using NearestNeighbors and word2vec to detect sentence similarity

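I have trained a word2vec model on my corpus using Python and gensim.

I then computed the mean word2vec vector of each sentence (averaging the vectors of all the words in the sentence) and stored it in a pandas DataFrame. The columns of the DataFrame df are:

  • sentence
  • book title (the book the sentence comes from)
  • mean_vector (the mean of the sentence's word2vec vectors, of size 100)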
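A minimal sketch of how a column like mean_vector might be built. It assumes pre-tokenised sentences and a mapping from words to vectors. The word_vectors dict and the sentence_vector helper are hypothetical stand-ins: gensim's model.wv supports the same "in" and [] operations. Skipping out-of-vocabulary words and returning zeros for a sentence with no known words are choices made here, not something the question specifies.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-in for gensim's model.wv (3-D instead of 100-D):
# any mapping word -> vector that supports `in` and `[]` works the same way.
word_vectors = {
    "the": np.array([0.1, 0.2, 0.3], dtype=np.float32),
    "cat": np.array([0.5, 0.1, 0.0], dtype=np.float32),
    "sat": np.array([0.0, 0.4, 0.2], dtype=np.float32),
}
dim = 3  # 100 in the question


def sentence_vector(tokens, vectors, dim):
    """Average the vectors of in-vocabulary tokens; return zeros if none are known."""
    known = [vectors[t] for t in tokens if t in vectors]
    if not known:
        return np.zeros(dim, dtype=np.float32)
    return np.mean(known, axis=0)


sentences = ["the cat sat", "the dog"]  # "dog" is out of vocabulary and skipped
df = pd.DataFrame({"sentence": sentences, "book": ["A", "B"]})
# Each cell holds one 1-D NumPy array, so the column has dtype object.
df["mean_vector"] = [sentence_vector(s.split(), word_vectors, dim) for s in sentences]
print(df["mean_vector"][1])  # the vector for "the" alone
```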

I am trying to use scikit-learn's NearestNeighbors to detect sentence similarity. (I could use doc2vec instead, but one of my goals is to compare this method with doc2vec.)

Here is my code:

from sklearn.neighbors import NearestNeighbors

X = df['mean_vector'].values
nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)

I get the following error:

ValueError: setting an array element with a sequence.

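I think I need to iterate over the vectors somehow so that I can compute the nearest neighbors of each row (one row == one sentence), but that seems beyond my current (limited) Python skills.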
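The per-row loop described above should not be needed. Once NearestNeighbors has been fitted on a 2-D matrix with one row per sentence, a single kneighbors call returns the neighbors of every row. Below is a minimal sketch using four made-up 2-D vectors in place of the real 100-D sentence vectors.

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

# Four toy "sentence vectors" (2-D instead of 100-D for readability):
# rows 0 and 1 are close to each other, as are rows 2 and 3.
X = np.array([[0.0, 0.0],
              [1.0, 0.0],
              [10.0, 10.0],
              [11.0, 10.0]])

nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X)

# One call answers the query for every row at once; no Python loop is needed.
distances, indices = nbrs.kneighbors(X)
print(indices)
# Column 0 is each row itself (distance 0); column 1 is its nearest other row:
# rows [0 1], [1 0], [2 3], [3 2]
```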

Here is the data in the first cell, df['mean_vector'][0]. It is a full sentence vector of size 100: the mean of the word vectors in that sentence.

array([ -2.14208905e-02,   2.42093615e-02,  -5.78106642e-02,
     1.32915592e-02,  -2.43393257e-02,  -1.41872400e-02,
     2.83471867e-02,  -2.02910602e-02,  -5.49359620e-02,
    -6.70913085e-02,  -5.56188896e-02,  -2.95186806e-02,
     4.97652516e-02,   7.16793686e-02,   1.81338750e-02,
    -1.50108105e-02,   1.79438610e-02,  -2.41483524e-02,
     4.97504435e-02,   2.91026086e-02,  -6.87966943e-02,
     3.27585079e-02,   5.10644279e-02,   1.97029337e-02,
     7.73109496e-02,   3.23865712e-02,  -2.81659551e-02,
    -9.69715789e-03,   5.23059331e-02,   3.81100960e-02,
    -3.62489261e-02,  -3.40068117e-02,  -4.90736961e-02,
     8.72346922e-04,   2.27111522e-02,   1.06063476e-02,
    -3.93234752e-02,  -1.10617064e-01,   8.05142429e-03,
     4.56497036e-02,  -1.73281748e-02,   2.35153548e-02,
     5.13465842e-03,   1.88336968e-02,   2.40451116e-02,
     3.79024050e-03,  -4.83284928e-02,   2.10295208e-02,
    -4.92134318e-03,   1.01532964e-02,   8.02216958e-03,
    -6.74675079e-03,  -1.39653292e-02,  -2.07276996e-02,
     9.73508134e-03,  -7.37899616e-02,  -2.58320477e-02,
    -1.10700730e-05,  -4.53227758e-02,   2.31859135e-03,
     1.40053956e-02,   1.61973312e-02,   3.01702786e-02,
    -6.96818605e-02,  -3.47468331e-02,   4.79541793e-02,
    -1.78820305e-02,   5.99209731e-03,  -5.92620336e-02,
     7.34678581e-02,  -5.23381204e-05,  -5.07357903e-02,
    -2.55154949e-02,   5.06089740e-02,  -3.70467864e-02,
    -2.04878468e-02,  -7.62404222e-03,  -5.38200373e-03,
     7.68705690e-03,  -3.27000804e-02,  -2.18365286e-02,
     2.34392099e-03,  -3.02998684e-02,   9.42565035e-03,
     3.24523374e-02,  -1.10793915e-02,   3.06244520e-03,
    -1.82240941e-02,  -5.70741761e-03,   3.13486941e-02,
    -1.15621388e-02,   1.10221673e-02,  -3.55655849e-02,
    -4.56304513e-02,   5.54837054e-03,   4.38252240e-02,
     1.57828294e-02,   2.65670624e-02,   8.08797963e-03,
     4.55569401e-02], dtype=float32)

I also tried:

for vec in df['mean_vector']:
    X = vec
    nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)

But all I get is the following warning:

DeprecationWarning: Passing 1d arrays as data is deprecated in 0.17 and will raise ValueError in 0.19. Reshape your data either using X.reshape(-1, 1) if your data has a single feature or X.reshape(1, -1) if it contains a single sample.
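The warning's two suggestions mean different things. The sketch below shows the shapes involved, assuming v is one 100-dimensional sentence vector such as df['mean_vector'][0] (zeros here for illustration). A single sentence is one sample with 100 features, so reshape(1, -1) is the relevant option. Note that even with the reshape, fitting inside the loop rebuilds the index from a single sentence on every iteration, so it never sees the other sentences. The index has to be fitted once on all rows.

```python
import numpy as np

v = np.zeros(100, dtype=np.float32)  # one sentence vector, like df['mean_vector'][0]

print(v.shape)                 # (100,)   1-D: ambiguous to scikit-learn
print(v.reshape(1, -1).shape)  # (1, 100) one sample with 100 features (what is wanted here)
print(v.reshape(-1, 1).shape)  # (100, 1) 100 samples with one feature each (wrong here)
print(np.array([v]).shape)     # (1, 100) same result as reshape(1, -1)
```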

If there is an example on GitHub that uses word2vec and NearestNeighbors in a similar scenario, I'd be glad to see it.

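Your edit throws an error because sklearn expects 2D input, with each example on its own row. You can use either X.reshape(1, -1) or [X]; the former is better practice. Without the original data or a proper MWE (minimal working example) it is hard to say exactly what is going wrong, but my guess is that something goes wrong when the data is put into or taken out of the DataFrame. Check whether X.shape makes sense to you.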
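Here is one likely reading of the original ValueError. It is consistent with the guess above, but the asker has not confirmed it. When each cell of mean_vector holds a NumPy array, df['mean_vector'].values is a 1-D array of dtype object whose elements are arrays, not an (n_sentences, 100) float matrix, so scikit-learn cannot convert it to floats. The sketch below uses a hypothetical four-row df with random vectors. It prints the shapes and stacks the rows with np.vstack before fitting.

```python
import numpy as np
import pandas as pd
from sklearn.neighbors import NearestNeighbors

# Hypothetical stand-in for the question's df: one 100-D array per row.
rng = np.random.default_rng(0)
df = pd.DataFrame({
    "sentence": ["s0", "s1", "s2", "s3"],
    "mean_vector": [rng.normal(size=100).astype(np.float32) for _ in range(4)],
})

X_obj = df["mean_vector"].values
print(X_obj.dtype, X_obj.shape)  # object (4,)  an array *of arrays*, not 2-D

# Stack the per-row vectors into a real (n_sentences, 100) matrix.
X = np.vstack(df["mean_vector"].tolist())
print(X.dtype, X.shape)          # float32 (4, 100)

nbrs = NearestNeighbors(n_neighbors=2, algorithm="ball_tree").fit(X)
```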

Below is the example I used to check that everything works on my side:

from sklearn.neighbors import NearestNeighbors
from gensim.models import Word2Vec
import numpy as np

a = """Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore
magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea 
commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla 
pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est 
laborum."""
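# One "sentence" per line; split() also drops the empty tokens left by stray spaces
a = [x.split() for x in a.split('\n') if len(x)]
model = Word2Vec(a, min_count=1)

# Get the average of all of the words to get data for a sentence
# (model.wv[word] is the word-vector accessor; model[word] was removed in gensim 4.0)
b = np.array([np.mean([model.wv[xx] for xx in x], axis=0) for x in a])
# Check it's the correct shape: (n_sentences, vector_size), here (5, 100) with gensim defaults
print(b.shape)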

nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(b)
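The example above stops at fit. Here is a hedged sketch of the next step: looking up each sentence's most similar other sentence. Made-up 3-D vectors and sentences stand in for real mean vectors. The sketch switches to cosine distance, which is a common (but not the only) choice for word2vec vectors. BallTree does not support the cosine metric, so the sketch uses algorithm="brute".

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors

sentences = ["a cat sat", "a cat slept", "stock prices fell", "markets dropped"]
# Hypothetical precomputed mean vectors, one row per sentence (normally 100-D).
X = np.array([[1.0, 0.1, 0.0],
              [0.9, 0.2, 0.0],
              [0.0, 0.1, 1.0],
              [0.1, 0.0, 0.9]])

# BallTree has no cosine metric, so use brute-force search for cosine distance.
nbrs = NearestNeighbors(n_neighbors=2, metric="cosine", algorithm="brute").fit(X)
distances, indices = nbrs.kneighbors(X)

# Column 0 is each sentence itself; column 1 is its most similar other sentence.
for i, j in enumerate(indices[:, 1]):
    print(f"{sentences[i]!r} -> {sentences[j]!r} (cosine distance {distances[i, 1]:.3f})")
```

Alternatively, calling nbrs.kneighbors() with no argument returns neighbors of the fitted points without counting each point as its own neighbor, so column 0 does not have to be dropped by hand.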
