Python pandas: Finding cosine similarity of two columns

Suppose I have two columns in a pandas.DataFrame:

          col1 col2
item_1    158  173
item_2     25  191
item_3    180   33
item_4    152  165
item_5     96  108

What's the best way to take the cosine similarity of these two columns?

Is that what you're looking for?

from scipy.spatial.distance import cosine
from pandas import DataFrame


df = DataFrame({"col1": [158, 25, 180, 152, 96],
                "col2": [173, 191, 33, 165, 108]})

print(1 - cosine(df["col1"], df["col2"]))
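scipy's cosine returns the cosine distance, which is why the result is subtracted from 1 to get the similarity. For reference, the similarity itself is the dot product of the two vectors divided by the product of their Euclidean norms. A minimal NumPy sketch of that formula, assuming the same data as above (the helper name cosine_sim is made up for illustration and is not part of any library):

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({"col1": [158, 25, 180, 152, 96],
                   "col2": [173, 191, 33, 165, 108]})


def cosine_sim(a, b):
    """Cosine similarity: dot(a, b) / (||a|| * ||b||). Hypothetical helper."""
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))


similarity = cosine_sim(df["col1"], df["col2"])
print(similarity)
```

This should agree with 1 - cosine(df["col1"], df["col2"]) up to floating-point rounding.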

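You can also use cosine_similarity or other similarity metrics from sklearn.metrics.pairwise. It expects 2-D inputs with one sample per row, so reshape each column into a single row first:

from sklearn.metrics.pairwise import cosine_similarity

cosine_similarity(df["col1"].to_numpy().reshape(1, -1),
                  df["col2"].to_numpy().reshape(1, -1))
Out[4]: array([[0.7498213]])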

In my case the situation was a bit more complicated: the two columns I wanted to compare had different lengths (in other words, one of them contained NaN values). In that case the method in the accepted answer doesn't work as is (it outputs nan).

I used the following small trick to handle it. First, concatenate the two columns of interest into a new data frame. Then drop the rows that contain NaN. After that, the two columns contain only the rows they have in common, and you can compare them with cosine distance or any other pairwise distance you wish.

import pandas as pd
from scipy.spatial import distance

index = ['item_1', 'item_2', 'item_3', 'item_4', 'item_5']
cols = [pd.Series([158, 25, 180, 152, 96], index=index, name='col1'),
        pd.Series([173, 191, 33, 165, 108], index=index, name='col2'),
        pd.Series([183, 204, 56], index=['item_1', 'item_4', 'item_5'], name='col3')]
df = pd.concat(cols, axis=1)
print(df)
print(distance.cosine(df['col2'], df['col3']))

Output:

        col1  col2   col3
item_1   158   173  183.0
item_2    25   191    NaN
item_3   180    33    NaN
item_4   152   165  204.0
item_5    96   108   56.0
nan

What you do instead is:

tdf = pd.concat([df['col2'], df['col3']], axis=1).dropna()
print(tdf)
print(distance.cosine(tdf['col2'], tdf['col3']))

Output is:

        col2   col3
item_1   173  183.0
item_4   165  204.0
item_5   108   56.0
0.02741129579408741
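If you need this for several column pairs, the concat-and-dropna steps can be wrapped in a small function. The sketch below is one possible way to do that, not part of pandas or scipy: the name cosine_similarity_dropna is hypothetical, and it computes the similarity directly with NumPy instead of calling scipy's distance.cosine, so the distance shown above is 1 minus its result.

```python
import numpy as np
import pandas as pd


def cosine_similarity_dropna(a: pd.Series, b: pd.Series) -> float:
    """Cosine similarity over the rows where both Series have a value.

    Hypothetical helper for illustration only.
    """
    # Align the two Series on their index, then drop rows missing in either.
    pair = pd.concat([a, b], axis=1).dropna()
    x = pair.iloc[:, 0].to_numpy(dtype=float)
    y = pair.iloc[:, 1].to_numpy(dtype=float)
    return float(np.dot(x, y) / (np.linalg.norm(x) * np.linalg.norm(y)))


index = ["item_1", "item_2", "item_3", "item_4", "item_5"]
col2 = pd.Series([173, 191, 33, 165, 108], index=index, name="col2")
col3 = pd.Series([183, 204, 56],
                 index=["item_1", "item_4", "item_5"], name="col3")

similarity = cosine_similarity_dropna(col2, col3)
print(similarity)      # cosine similarity over the 3 shared rows
print(1 - similarity)  # the corresponding cosine distance
```

Note that if the two Series share no rows, the norms are zero and the result is nan, so you may want to check for that case before relying on the value.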
