
Cosine similarity between the same dictionary's values

I have this dictionary called queries:

{'q1': ['similar',
  'law',
  'must',
  'obey',
  'construct',
  'aeroelast',
  'model',
  'heat',
  'high',
  'speed',
  'aircraft'],
 'q2': ['structur',
  'aeroelast',
  'problem',
  'associ',
  'flight',
  'high',
  'speed',
  'aircraft'],
 'q3': ['problem', 'heat', 'conduct', 'composit', 'slab', 'solv', 'far']
...
}

I then convert it into a dictionary of vectorized arrays with this code:

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

class RetrievalSystem:
    def __init__(self, docs, num_concepts, min_df=1, alpha=1.0, beta=0.75, gamma=0.15):
        # create a doc-term matrix out of our doc collection
        self.vec = TfidfVectorizer(tokenizer=str.split, min_df=min_df)
        doc_term_mat = self.vec.fit_transform([" ".join(docs[doc_id]) for doc_id in docs])
        self.q_vecs = dict()  # query vectors

        # reduce the TF-IDF matrix to num_concepts latent dimensions
        self.svd = TruncatedSVD(n_components=num_concepts, random_state=42)

        self.doc_vecs = self.svd.fit_transform(doc_term_mat)

    def retrieve_n_rank_docs(self, queries, max_docs=-1):
        for query in queries:
            # project each query into the same TF-IDF + SVD space as the docs
            s = self.vec.transform([" ".join(queries[query])])
            s = self.svd.transform(s)
            if query not in self.q_vecs.keys():
                self.q_vecs[query] = s

The max_docs argument controls the maximum number of documents to return for each query. Now self.q_vecs looks like this:

{'q217': array([[ 0.16555858,  0.12041974,  0.10034606,  0.03249144,  0.00843294,
         0.16582048, -0.20520625, -0.05597786, -0.12666519, -0.10517737,
         0.14363559, -0.01525909, -0.16574115, -0.04112081, -0.1374631 ,
         0.05047798,  0.05825697, -0.01779095, -0.05663042, -0.14333234,
        -0.09671375, -0.02205753,  0.03309577, -0.04512224, -0.01605542,
         0.00762974,  0.02407301,  0.00426722,  0.00654344,  0.08085963,
         0.08657383, -0.09913353,  0.01492773, -0.06813004, -0.01151318,
        -0.08565942,  0.03826287, -0.00330817,  0.13141591,  0.04920131,
        -0.08375895,  0.09465868, -0.03466024,  0.01838176, -0.00336209,
         0.02372735, -0.03390722,  0.0440413 ,  0.00371048,  0.09835254,
        -0.01099799,  0.0014484 ,  0.06276236,  0.04311937, -0.0867389 ,
         0.00850617,  0.00496759, -0.17198825,  0.07988587,  0.05727097,
         0.13304752,  0.08784825, -0.06141824, -0.01383098, -0.02348199,
        -0.04522944,  0.05257815,  0.08263177, -0.01140021, -0.05829286,
        -0.04885191,  0.09377792,  0.0190092 ,  0.00947696,  0.05598195,
        -0.03815088, -0.02834209,  0.0281708 , -0.02843137, -0.03210851,
         0.04751607, -0.01162277,  0.02034976, -0.02088302,  0.07665635,
         0.0195319 , -0.0157795 ,  0.01210985, -0.03183579,  0.01161029,
         0.02409737, -0.01007874,  0.10754846,  0.01010833, -0.05662593,
        -0.01729383, -0.03097083,  0.03369774,  0.00572065,  0.02632313]]), 'q99': array([[ 0.10287323, -0.01085065, -0.00967409, -0.04218846,  0.09239141,
         0.07992809, -0.00359886, -0.03796564,  0.01250241,  0.01951022,
        -0.03673524, -0.02372439, -0.03240905, -0.03081271,  0.02817431,
         0.12468386, -0.02051108,  0.12191644,  0.00624408, -0.05094331,
         0.09598166, -0.02341246, -0.0020474 , -0.05629724,  0.03516377,
         0.09028871,  0.02806492, -0.02300581, -0.02998558, -0.00270938,
         0.01611941,  0.04106955,  0.05371339, -0.02561045, -0.01916819,
         0.08158927, -0.03353019, -0.01020131, -0.03670832,  0.02845091,
         0.07133292, -0.0944471 , -0.00662414,  0.0920997 , -0.00206586,
         0.07063442, -0.00814919, -0.00374118, -0.01353651,  0.07968094,
         0.00796783, -0.01397921, -0.07712498, -0.00308536,  0.07785687,
        -0.01220938, -0.06646712,  0.04048088,  0.01321445,  0.00041508,
        -0.04644943,  0.09307773,  0.0188646 , -0.03233048, -0.04803833,
        -0.06355723, -0.00560934, -0.05478746,  0.03196071,  0.08420215,
        -0.07706163, -0.12595219, -0.01330823, -0.00079499, -0.02515943,
         0.00087481, -0.00596035,  0.01680558,  0.0138655 , -0.01290259,
        -0.0497661 , -0.04627047, -0.00239779, -0.06377815, -0.01103349,
         0.00205314, -0.0774958 ,  0.00223332, -0.00976858,  0.02365778,
         0.02600081,  0.01212485,  0.03451618,  0.00642054, -0.00025119,
         0.00898667,  0.00749051,  0.02099796, -0.00906813, -0.06770008]])
...
}
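Note the double brackets: each value in `self.q_vecs` is a 2-D array of shape `(1, num_concepts)` (apparently `(1, 100)` here, judging by the 100 values printed for `q217`), because `transform` was given a one-element list. Depending on the SciPy version, `scipy.spatial.distance.cosine` may reject or warn about input that is not 1-D, so flattening with `.ravel()` before comparing is the safer choice. A small NumPy sketch with made-up 3-component vectors standing in for the real ones; `cosine_sim` is a hypothetical helper, not part of any library:

```python
import numpy as np

# Made-up stand-ins for two entries of self.q_vecs, each of shape (1, k)
q_vecs = {
    "q217": np.array([[0.2, 0.1, -0.3]]),
    "q99": np.array([[0.1, -0.4, 0.2]]),
}

def cosine_sim(a, b):
    """Hypothetical helper: cosine similarity, flattening (1, k) inputs to (k,)."""
    a, b = np.ravel(a), np.ravel(b)
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

print(q_vecs["q217"].shape)          # (1, 3)
print(q_vecs["q217"].ravel().shape)  # (3,)
print(cosine_sim(q_vecs["q217"], q_vecs["q99"]))
```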

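I want to take the cosine similarity between the vector representations and then sort the documents for each query in descending order of cosine similarity. So the desired output looks like this: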

{
    'q217': ['d983', 'd554', ..., 'd623'],
    'q99' : ['d716', 'd67', ..., 'd164'],
    ...
}
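A minimal sketch of one way to produce that output, assuming the goal is to compare each query vector with every row of `self.doc_vecs` rather than with the other queries. It also assumes that row `i` of `doc_vecs` belongs to the `i`-th key of `docs` (which should hold because `doc_term_mat` is built by iterating over `docs`, and dicts keep insertion order in Python 3.7+), and it treats `max_docs=-1` as "no limit". `rank_docs` and the toy data are hypothetical:

```python
import numpy as np

def rank_docs(q_vec, doc_vecs, doc_ids, max_docs=-1):
    """Hypothetical helper: doc IDs sorted by descending cosine similarity to q_vec.

    Assumes doc_vecs[i] is the vector for doc_ids[i]; max_docs <= 0 means no limit.
    """
    q = np.ravel(q_vec)                                   # (1, k) -> (k,)
    # cosine similarity of q with every document row at once
    norms = np.linalg.norm(doc_vecs, axis=1) * np.linalg.norm(q)
    sims = doc_vecs @ q / np.where(norms == 0, 1, norms)  # avoid dividing by zero
    order = np.argsort(-sims, kind="stable")              # highest similarity first
    if max_docs > 0:
        order = order[:max_docs]
    return [doc_ids[i] for i in order]

# Toy data standing in for self.doc_vecs, the keys of docs, and self.q_vecs
doc_ids = ["d1", "d2", "d3"]
doc_vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
q_vecs = {"q1": np.array([[0.9, 0.1]]), "q2": np.array([[0.0, 2.0]])}

ranked = {q: rank_docs(v, doc_vecs, doc_ids) for q, v in q_vecs.items()}
print(ranked)  # {'q1': ['d1', 'd3', 'd2'], 'q2': ['d2', 'd3', 'd1']}
```

Negating the scores before `argsort` gives descending order, and `kind="stable"` keeps tied documents in their original order.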

I wrote this code to try to produce that output using cosine similarity, but it only returns one key-value pair:

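from scipy import spatial
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD

class RetrievalSystem: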
    def __init__(self, docs, num_concepts, min_df=1, alpha=1.0, beta=0.75, gamma=0.15):
        self.alpha, self.beta, self.gamma = alpha, beta, gamma
        
        # create a doc-term matrix out of our doc collection
        self.vec = TfidfVectorizer(tokenizer=str.split, min_df=min_df)
        doc_term_mat = self.vec.fit_transform([" ".join(docs[doc_id]) for doc_id in docs])
        self.q_vecs = dict() # query vectors
        
        self.svd = TruncatedSVD(n_components = num_concepts, random_state = 42)
        
        self.doc_vecs = self.svd.fit_transform(doc_term_mat)
        # YOUR CODE HERE
        #raise NotImplementedError()

    def retrieve_n_rank_docs(self, queries, max_docs=-1):
       
        for query in queries:
            s = self.vec.transform([" ".join(queries[query])])
            s = self.svd.transform(s)
            if query not in self.q_vecs.keys():
                self.q_vecs[query] = s
        
            all_keys = list(self.q_vecs.keys())
            new_d = {}
        
            for i in range(len(all_keys)):
                for j in range(i+1,len(all_keys)):
                    new_d[query] = {1 - spatial.distance.cosine(self.q_vecs[all_keys[i]], self.q_vecs[all_keys[j]])}
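The single key-value pair comes from the loop structure: `new_d = {}` sits inside `for query in queries`, so it is emptied on every pass, and every assignment uses the same key `query`, so each pair overwrites the previous one. What survives is one entry for the last query. (The code also compares queries with each other, whereas the desired output needs queries compared with documents.) A pure-Python reproduction with the vectors and similarity values stripped out; `buggy` and `fixed` are illustrative names only:

```python
def buggy(keys):
    seen = []
    for query in keys:
        seen.append(query)                # like self.q_vecs[query] = s
        new_d = {}                        # bug 1: reset on every outer iteration
        for i in range(len(seen)):
            for j in range(i + 1, len(seen)):
                # bug 2: the key is always `query`, so each pair overwrites the last
                new_d[query] = (seen[i], seen[j])
    return new_d

def fixed(keys):
    new_d = {}                            # created once, outside the loops
    for i in range(len(keys)):
        for j in range(i + 1, len(keys)):
            # one entry per pair; the cosine similarity would be stored here
            new_d[(keys[i], keys[j])] = None
    return new_d

qs = ["q1", "q2", "q3"]
print(buggy(qs))  # {'q3': ('q2', 'q3')}
print(fixed(qs))  # {('q1', 'q2'): None, ('q1', 'q3'): None, ('q2', 'q3'): None}
```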

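Since the code isn't a minimal reproducible example, I can't fully help. But to build a dictionary of cosine similarities for every combination of keys, you can do the following: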

import itertools

import numpy as np
from scipy.spatial import distance


q_vecs = {
    "q1": np.array([0, 1]),
    "q2": np.array([1, 0]),
    "q3": np.array([0, 2]),
    "q4": np.array([10, 10])
}
new_d = {}
# every unordered pair of keys: ('q1', 'q2'), ('q1', 'q3'), ...
for k1, k2 in itertools.combinations(q_vecs, 2):
    new_d[(k1, k2)] = 1 - distance.cosine(q_vecs[k1], q_vecs[k2])

This gives you (values rounded to four decimal places):

{
    ('q1', 'q2'): 0.0,
    ('q1', 'q3'): 1.0,
    ('q1', 'q4'): 0.7071,
    ('q2', 'q3'): 0.0,
    ('q2', 'q4'): 0.7071,
    ('q3', 'q4'): 0.7071
}
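With many queries, the pairwise loop can be replaced by one matrix product: scale every vector to unit length, stack them as rows, and `mat @ mat.T` then holds all pairwise cosine similarities at once. A NumPy sketch on the same four toy vectors, assuming none of them is all zeros (a zero row would divide by zero):

```python
import itertools

import numpy as np

q_vecs = {
    "q1": np.array([0, 1]),
    "q2": np.array([1, 0]),
    "q3": np.array([0, 2]),
    "q4": np.array([10, 10]),
}

keys = list(q_vecs)
# stack into an (n, k) float matrix; ravel also handles (1, k) arrays
mat = np.vstack([np.ravel(q_vecs[k]) for k in keys]).astype(float)
mat /= np.linalg.norm(mat, axis=1, keepdims=True)  # unit-length rows
sim = mat @ mat.T  # sim[i, j] = cosine similarity of keys[i] and keys[j]

pairs = {
    (keys[i], keys[j]): float(sim[i, j])
    for i, j in itertools.combinations(range(len(keys)), 2)
}
print(pairs)
```

The result matches the loop up to floating-point rounding, and because of `np.ravel` it also accepts the `(1, k)` arrays stored in `self.q_vecs`.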

I hope this is what you're after, because I don't understand how you would generate the strings in the following:

{
    'q217': ['d983', 'd554', ..., 'd623'],
    'q99' : ['d716', 'd67', ..., 'd164'],
    ...
}
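As for where the document IDs could come from: the similarity scores only give column positions, and those positions can be mapped back to IDs with `list(docs)`, assuming `doc_term_mat` (and therefore `self.doc_vecs`) was built by iterating over `docs` in insertion order, as the question's constructor does. A batched sketch that scores every query against every document in one matrix product; `rank_all`, the toy data, and the rule "`max_docs=-1` means no limit" are assumptions for illustration:

```python
import numpy as np

def rank_all(q_vecs, doc_vecs, doc_ids, max_docs=-1):
    """Hypothetical batched ranking: {query_id: [doc_id, ...]} by descending cosine."""
    q_ids = list(q_vecs)
    # (n_queries, k) matrix; ravel turns each (1, k) array into a row
    Q = np.vstack([np.ravel(q_vecs[q]) for q in q_ids]).astype(float)
    D = np.asarray(doc_vecs, dtype=float)
    # normalize rows so that a dot product equals cosine similarity
    Q /= np.maximum(np.linalg.norm(Q, axis=1, keepdims=True), 1e-12)
    D = D / np.maximum(np.linalg.norm(D, axis=1, keepdims=True), 1e-12)
    sims = Q @ D.T                                     # (n_queries, n_docs)
    order = np.argsort(-sims, axis=1, kind="stable")   # best match first per row
    if max_docs > 0:
        order = order[:, :max_docs]
    return {q: [doc_ids[j] for j in row] for q, row in zip(q_ids, order)}

# toy collection: only the key order matters here
docs = {"d1": ["heat", "slab"], "d2": ["aircraft", "speed"], "d3": ["heat", "aircraft"]}
doc_ids = list(docs)  # same order as the rows of doc_term_mat / doc_vecs
doc_vecs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # made-up "concept" vectors
q_vecs = {"q1": np.array([[1.0, 0.2]]), "q2": np.array([[-1.0, 1.0]])}

print(rank_all(q_vecs, doc_vecs, doc_ids, max_docs=2))  # {'q1': ['d1', 'd3'], 'q2': ['d2', 'd3']}
```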

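Disclaimer: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact yoyou2525@163.com.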

 
© 2020-2024 STACKOOM.COM