
Define which neighbours a Ball tree should return

I have a dataframe with several locations, and I want to find each location's nearest neighbours.

To do this, I am using a Ball tree. However, the query compares every location with every other location, including itself. For example, with locations A, B, C, ..., the output lists A as a neighbour of A.
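This is expected behaviour rather than a bug: when a tree is queried with the same points it was built on, each point's nearest neighbour is itself, at distance 0. A minimal sketch with three made-up points:

```python
import numpy as np
from sklearn.neighbors import BallTree

# Three made-up points as [latitude, longitude] in degrees
points_deg = np.array([[51.51, -0.13], [51.52, -0.13], [61.53, -0.13]])
points_rad = np.deg2rad(points_deg)  # the haversine metric expects radians

tree = BallTree(points_rad, metric='haversine')
# Query the same points the tree was built on
distances, indices = tree.query(points_rad, k=2)

# Column 0 is each point matching itself at distance 0
print(indices[:, 0])    # [0 1 2]
print(distances[:, 0])  # [0. 0. 0.]
```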

I also have a time column that I want to use in my analysis. I set it as the index before fitting the Ball tree, but the output still returns A at time 1, A at time 2 and A at time 3 as neighbours of A.
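Setting time as the index does not change this, because BallTree only receives the array of coordinates passed to it (`.values`); the index is ignored. The same location at another timestamp is therefore just a duplicate point at distance 0. A minimal sketch with made-up data:

```python
import numpy as np
import pandas as pd
from sklearn.neighbors import BallTree

# Location A observed at two timestamps (rows 0 and 1); time is the index
df = pd.DataFrame(
    {'latitude': [51.51, 51.51, 61.53], 'longitude': [-0.13, -0.13, -0.13]},
    index=pd.to_datetime(['2019-03-10 11:00', '2019-03-10 11:10', '2019-03-10 11:00']),
)

# Only the coordinate values reach the tree; the time index plays no part
coords = np.deg2rad(df[['latitude', 'longitude']].to_numpy())
tree = BallTree(coords, metric='haversine')
distances, indices = tree.query(coords, k=2)

# Rows 0 and 1 are the same place, so each is the other's neighbour at distance 0
# (the order of tied matches is not guaranteed)
print(distances[:2])
```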

I created a smaller dataframe with fake data that mirrors my own (shown below). With this small dataset I can query more neighbours than I actually need and then remove the 'wrong' neighbours from the output.

However, this method is too computationally expensive to use with my real data.

Is there a way to define which neighbours the Ball tree returns?

Sample code:

```python
from sklearn.neighbors import BallTree
import numpy as np
import pandas as pd

test_data = pd.DataFrame({'latitude':[51.51, 51.52,61.53,61.54,71.55, 71.56,
                                      51.51, 51.52,61.53,61.54,71.55, 71.56,
                                      51.51, 51.52,61.53,61.54,71.55, 71.56],
                         'longitude':[-0.13,-0.13,-0.13,-0.14,-0.13,-0.13,
                                      -0.13,-0.13,-0.13,-0.14,-0.13,-0.13,
                                      -0.13,-0.13,-0.13,-0.14,-0.13,-0.13],
                         'id':['A','B','C','D','E','F',
                               'A','B','C','D','E','F',
                               'A','B','C','D','E','F'],
                         'target':[35,410,1,100,114,78,
                                   14,254,101,278,3578,435,
                                   254,254,37,47,38,101],
                         'time':['2019-03-10 11:00:00','2019-03-10 11:00:00','2019-03-10 11:00:00','2019-03-10 11:00:00','2019-03-10 11:00:00','2019-03-10 11:00:00',
                                 '2019-03-10 11:10:00','2019-03-10 11:10:00','2019-03-10 11:10:00','2019-03-10 11:10:00','2019-03-10 11:10:00','2019-03-10 11:10:00',
                                 '2019-03-10 11:20:00','2019-03-10 11:20:00','2019-03-10 11:20:00','2019-03-10 11:20:00','2019-03-10 11:20:00','2019-03-10 11:20:00',
                                 ]})

# --- STEP 1) Preparing the data
test_data = test_data.reset_index()

# Convert latitude and longitude to radians (required by the haversine metric)
for column in ['latitude', 'longitude']:
    test_data[column] = np.deg2rad(test_data[column].values)

# Create a duplicate of the time column; one copy will be set as the index
test_data['time2'] = test_data['time']

# Convert time to datetime and use it as the index
# (the coordinates stay numeric; converting them to strings is not needed)
test_data['time'] = pd.to_datetime(test_data['time'])
test_data = test_data.set_index('time')

# --- STEP 2) Fitting the Ball tree
locations_a = test_data
locations_b = test_data
latitude = 'latitude'
longitude = 'longitude'

# Make the Ball tree
ball = BallTree(locations_a[[latitude, longitude]].values, metric='haversine')

# The number of neighbours to return
k = 6
# Calculate distances
distances, indices = ball.query(locations_b[[latitude, longitude]].values, k=k)

# --- STEP 3) Merging results into a dataframe
dists = pd.DataFrame(distances).stack()
rel = pd.DataFrame(indices).stack()

# Create the dataframe
neighbor_info_df = pd.merge(dists.rename('distance'), rel.rename('neighbor_idx'),
                            right_index=True, left_index=True)
# Reset and rename the indexes
neighbor_info_df = neighbor_info_df.reset_index(level=1).rename({'level_1': 'neighbor_number'}, axis=1)
neighbor_info_df = neighbor_info_df.reset_index().rename({'index': 'id_index_no'}, axis=1)
neighbor_info_df.head(10)
```

As I suggested, query k+1 neighbours and drop the first match, which is the point itself (distance 0). However, you also need to group by time first, so that each location is only compared with the locations recorded at the same timestamp:

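```python
import numpy as np
import pandas as pd
from sklearn.neighbors import BallTree

# Setup: keep the default numeric index (0..N-1), don't use time as the index,
# and leave latitude/longitude in degrees (the function converts them)
# test_data = pd.DataFrame(...)
# test_data['time'] = pd.to_datetime(test_data['time'])

def find_k_neighbors(df, k=6):
    """Find the k nearest neighbors of each point, excluding the point itself.

    Each group must contain at least k+1 points.
    """

    # Prepare the data and query k+1 neighbors (the first match is the point itself)
    coords = np.radians(df[['latitude', 'longitude']])
    tree = BallTree(coords, metric='haversine')
    distances, indices = tree.query(coords, k=k+1)

    # Drop column 0 (the point itself), flatten all arrays, then build the dataframe.
    # `indices` are positions within this group; df.iloc maps them back to row labels.
    distances = np.ravel(distances[:, 1:])
    indices = np.ravel(indices[:, 1:])
    return (pd.DataFrame({'id_index_no': df.index.repeat(k),
                          'neighbor_number': np.tile(range(1, k+1), len(df)),
                          'distance': distances,
                          'neighbor_idx': df.iloc[indices].index}))

# For demonstration purposes, k is limited to 3
neighbor_info_df = test_data.groupby('time').apply(find_k_neighbors, k=3).droplevel(1)
```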
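Dropping column 0 assumes the point itself is always the first match. If several rows in the same time group share identical coordinates, the ties at distance 0 can come back in any order, so column 0 may be a duplicate rather than the point itself. A sketch of a hypothetical helper, `drop_self` (not part of scikit-learn), that removes each row's own index explicitly instead; it assumes the tree was queried with the same points it was built on, in the same order:

```python
import numpy as np
from sklearn.neighbors import BallTree

def drop_self(distances, indices):
    """Hypothetical helper: remove each row's own index from k+1 query results.

    Assumes row i of the query is point i of the tree.
    """
    n, k_plus_1 = indices.shape
    is_self = indices == np.arange(n)[:, None]
    # If a row's own index was not returned (possible when more than k+1
    # points share its coordinates), drop the last, farthest match instead
    missing = ~is_self.any(axis=1)
    is_self[missing, -1] = True
    keep = ~is_self  # exactly k True values per row
    return (distances[keep].reshape(n, k_plus_1 - 1),
            indices[keep].reshape(n, k_plus_1 - 1))

# Rows 0 and 1 are the same place
coords = np.deg2rad([[51.51, -0.13], [51.51, -0.13], [61.53, -0.13]])
tree = BallTree(coords, metric='haversine')
d, i = tree.query(coords, k=3)  # 2 neighbours plus the point itself
d2, i2 = drop_self(d, i)
print(i2[0], i2[1])  # [1 2] [0 2]
```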

Output:

```
>>> neighbor_info_df
                     id_index_no  neighbor_number  distance  neighbor_idx
time                                                                     
2019-03-10 11:00:00            0                1  0.000175             1
2019-03-10 11:00:00            0                2  0.174882             2
2019-03-10 11:00:00            0                3  0.175057             3
2019-03-10 11:00:00            1                1  0.000175             0
2019-03-10 11:00:00            1                2  0.174707             2
2019-03-10 11:00:00            1                3  0.174882             3
2019-03-10 11:00:00            2                1  0.000193             3
2019-03-10 11:00:00            2                2  0.174707             1
2019-03-10 11:00:00            2                3  0.174882             0
2019-03-10 11:00:00            3                1  0.000193             2
2019-03-10 11:00:00            3                2  0.174707             4
2019-03-10 11:00:00            3                3  0.174882             5
2019-03-10 11:00:00            4                1  0.000175             5
2019-03-10 11:00:00            4                2  0.174707             3
2019-03-10 11:00:00            4                3  0.174882             2
2019-03-10 11:00:00            5                1  0.000175             4
2019-03-10 11:00:00            5                2  0.174882             3
2019-03-10 11:00:00            5                3  0.175057             2
2019-03-10 11:10:00            6                1  0.000175             7
2019-03-10 11:10:00            6                2  0.174882             8
2019-03-10 11:10:00            6                3  0.175057             9
2019-03-10 11:10:00            7                1  0.000175             6
2019-03-10 11:10:00            7                2  0.174707             8
2019-03-10 11:10:00            7                3  0.174882             9
2019-03-10 11:10:00            8                1  0.000193             9
2019-03-10 11:10:00            8                2  0.174707             7
2019-03-10 11:10:00            8                3  0.174882             6
2019-03-10 11:10:00            9                1  0.000193             8
2019-03-10 11:10:00            9                2  0.174707            10
2019-03-10 11:10:00            9                3  0.174882            11
2019-03-10 11:10:00           10                1  0.000175            11
2019-03-10 11:10:00           10                2  0.174707             9
2019-03-10 11:10:00           10                3  0.174882             8
2019-03-10 11:10:00           11                1  0.000175            10
2019-03-10 11:10:00           11                2  0.174882             9
2019-03-10 11:10:00           11                3  0.175057             8
2019-03-10 11:20:00           12                1  0.000175            13
2019-03-10 11:20:00           12                2  0.174882            14
2019-03-10 11:20:00           12                3  0.175057            15
2019-03-10 11:20:00           13                1  0.000175            12
2019-03-10 11:20:00           13                2  0.174707            14
2019-03-10 11:20:00           13                3  0.174882            15
2019-03-10 11:20:00           14                1  0.000193            15
2019-03-10 11:20:00           14                2  0.174707            13
2019-03-10 11:20:00           14                3  0.174882            12
2019-03-10 11:20:00           15                1  0.000193            14
2019-03-10 11:20:00           15                2  0.174707            16
2019-03-10 11:20:00           15                3  0.174882            17
2019-03-10 11:20:00           16                1  0.000175            17
2019-03-10 11:20:00           16                2  0.174707            15
2019-03-10 11:20:00           16                3  0.174882            14
2019-03-10 11:20:00           17                1  0.000175            16
2019-03-10 11:20:00           17                2  0.174882            15
2019-03-10 11:20:00           17                3  0.175057            14
```
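With `metric='haversine'`, BallTree expects `[latitude, longitude]` in radians and returns distances in radians, i.e. as an angle on the unit sphere. The 0.000175 between A and B is simply 0.01° of latitude, and multiplying by the Earth's radius converts it to kilometres. A small check with a hand-written haversine function (the name `haversine_rad` and the 6371 km mean radius are assumptions for illustration):

```python
import numpy as np

def haversine_rad(lat1, lon1, lat2, lon2):
    """Great-circle distance in radians between two points given in degrees."""
    lat1, lon1, lat2, lon2 = np.deg2rad([lat1, lon1, lat2, lon2])
    a = (np.sin((lat2 - lat1) / 2) ** 2
         + np.cos(lat1) * np.cos(lat2) * np.sin((lon2 - lon1) / 2) ** 2)
    return 2 * np.arcsin(np.sqrt(a))

EARTH_RADIUS_KM = 6371.0  # assumed mean Earth radius

d = haversine_rad(51.51, -0.13, 51.52, -0.13)  # A -> B
print(round(d, 6))                    # 0.000175
print(round(d * EARTH_RADIUS_KM, 2))  # 1.11 (km)
```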

Input dataframe:

```
>>> test_data
    latitude  longitude id  target                time
0      51.51      -0.13  A      35 2019-03-10 11:00:00
1      51.52      -0.13  B     410 2019-03-10 11:00:00
2      61.53      -0.13  C       1 2019-03-10 11:00:00
3      61.54      -0.14  D     100 2019-03-10 11:00:00
4      71.55      -0.13  E     114 2019-03-10 11:00:00
5      71.56      -0.13  F      78 2019-03-10 11:00:00
6      51.51      -0.13  A      14 2019-03-10 11:10:00
7      51.52      -0.13  B     254 2019-03-10 11:10:00
8      61.53      -0.13  C     101 2019-03-10 11:10:00
9      61.54      -0.14  D     278 2019-03-10 11:10:00
10     71.55      -0.13  E    3578 2019-03-10 11:10:00
11     71.56      -0.13  F     435 2019-03-10 11:10:00
12     51.51      -0.13  A     254 2019-03-10 11:20:00
13     51.52      -0.13  B     254 2019-03-10 11:20:00
14     61.53      -0.13  C      37 2019-03-10 11:20:00
15     61.54      -0.14  D      47 2019-03-10 11:20:00
16     71.55      -0.13  E      38 2019-03-10 11:20:00
17     71.56      -0.13  F     101 2019-03-10 11:20:00
```
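In this sample every id has the same coordinates at every timestamp. If that also holds for the real data (an assumption the question does not state), the neighbours only need to be computed once per unique location rather than once per time group, and the result can then be joined back onto the timestamps. A sketch with a hypothetical helper, `neighbours_by_id`:

```python
import numpy as np
import pandas as pd
from sklearn.neighbors import BallTree

def neighbours_by_id(test_data, k=3):
    """Hypothetical helper: k neighbours per id, computed once, then joined to every time.

    Assumes columns 'id', 'latitude', 'longitude' (degrees) and 'time', and that
    each id keeps the same coordinates at every timestamp.
    """
    locations = (test_data[['id', 'latitude', 'longitude']]
                 .drop_duplicates('id')
                 .reset_index(drop=True))
    coords = np.deg2rad(locations[['latitude', 'longitude']].to_numpy())
    tree = BallTree(coords, metric='haversine')
    distances, indices = tree.query(coords, k=k + 1)

    # Drop column 0 (the location itself) and reshape to long format
    ids = locations['id'].to_numpy()
    pairs = pd.DataFrame({
        'id': np.repeat(ids, k),
        'neighbor_number': np.tile(np.arange(1, k + 1), len(locations)),
        'distance': distances[:, 1:].ravel(),
        'neighbor_id': ids[indices[:, 1:].ravel()],
    })
    # Attach the timestamps back: one row per (time, id, neighbour)
    return test_data[['time', 'id']].merge(pairs, on='id')

# Same layout as the sample data above (target column omitted)
test_data = pd.DataFrame({
    'latitude': [51.51, 51.52, 61.53, 61.54, 71.55, 71.56] * 3,
    'longitude': [-0.13, -0.13, -0.13, -0.14, -0.13, -0.13] * 3,
    'id': list('ABCDEF') * 3,
    'time': pd.to_datetime(['2019-03-10 11:00', '2019-03-10 11:10',
                            '2019-03-10 11:20']).repeat(6),
})
result = neighbours_by_id(test_data, k=3)
```

The tree is built on 6 points instead of 18 here, and in general on the number of distinct locations rather than the number of rows.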
