
How can I quickly find which array in a very large list of arrays is most similar to an input array?

I have a list containing many lists. I use a DTW (dynamic time warping) approach, via the fastdtw Python package, to calculate the distance between an input list and each list inside the outer list. This gives me a list of distances, from which I take the minimum and treat the corresponding list as the one closest to the input. This works, but it is CPU-intensive and slow when the number of lists and their lengths are large.

from fastdtw import fastdtw
import scipy.spatial.distance as ssd

inputlist = [1,2,3,4,5]
complete_list = [[1,1,3,9,1],[1,2,6,4],[9,8,7,4,2]]
dst = []
# Compute the DTW distance from the input to every candidate list
for arr in complete_list:
   distance, path = fastdtw(arr,inputlist,dist=ssd.euclidean)
   dst.append(distance)
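
For reference, the quantity being minimized above can be written out in plain Python. This is only a minimal sketch of classic exact DTW with an absolute-difference cost, not the approximation that fastdtw computes, and the helper names dtw_distance and closest_index are made up for illustration.

```python
import math

def dtw_distance(a, b):
    """Exact DTW between two 1-D sequences, cost = |x - y|, O(len(a) * len(b))."""
    n, m = len(a), len(b)
    # prev[j] is the best cumulative cost of aligning the rows seen so far with b[:j]
    prev = [0.0] + [math.inf] * m
    for i in range(1, n + 1):
        curr = [math.inf] * (m + 1)
        for j in range(1, m + 1):
            cost = abs(a[i - 1] - b[j - 1])
            # Extend the cheapest of: step in a, step in b, or step in both
            curr[j] = cost + min(prev[j], curr[j - 1], prev[j - 1])
        prev = curr
    return prev[m]

def closest_index(query, candidates):
    """Return (index, distance) of the candidate with the smallest DTW distance."""
    distances = [dtw_distance(c, query) for c in candidates]
    best = min(range(len(distances)), key=distances.__getitem__)
    return best, distances[best]

inputlist = [1, 2, 3, 4, 5]
complete_list = [[1, 1, 3, 9, 1], [1, 2, 6, 4], [9, 8, 7, 4, 2]]
best, best_distance = closest_index(inputlist, complete_list)
print(complete_list[best], best_distance)
```

Because this runs in pure Python it is slower than the library, but it makes the cost explicit: every candidate needs a full len(a) * len(b) table, which is why the loop becomes expensive for long lists.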

If you only need the closest list, and not necessarily all distances, build a tree. Note that this uses the Euclidean distance rather than DTW, so all lists must have the same length. For example:

from sklearn.neighbors import BallTree
import numpy as np

inputlist = [1,2,3,4,5]
complete_list = [[1,1,3,9,1],[1,2,6,4,5],[9,8,7,4,2]]

tree = BallTree(np.array(complete_list), leaf_size=10, metric='euclidean')

Then query it with:

distance, index = tree.query(np.expand_dims(np.array(inputlist),axis=0), k=1, return_distance=True)

This returns the distance to the k=1 nearest neighbor together with its index, e.g.:

print('Most similar to inputlist is')
print( complete_list[ index[0][0] ] )

If speed is important, you can tune leaf_size (10 here) and see what works best for your data size. Building the tree also takes time, so include that in your benchmark if it is relevant to your use case.
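
A rough way to benchmark this is to time the build and the query separately for a few leaf_size values. The sketch below uses randomly generated data as a stand-in for the real lists; the array sizes and the leaf_size values tried are arbitrary assumptions, so substitute your own.

```python
import time
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
data = rng.random((20000, 5))   # synthetic stand-in for the list of lists
query = rng.random((1, 5))      # synthetic stand-in for the input list

for leaf_size in (10, 40, 100):
    t0 = time.perf_counter()
    tree = BallTree(data, leaf_size=leaf_size, metric='euclidean')
    t1 = time.perf_counter()
    distance, index = tree.query(query, k=1, return_distance=True)
    t2 = time.perf_counter()
    print(f"leaf_size={leaf_size}: build {t1 - t0:.4f}s, query {t2 - t1:.6f}s")

# Brute-force baseline for comparison; it should find the same neighbor
brute_index = int(np.argmin(np.linalg.norm(data - query, axis=1)))
print(index[0][0], brute_index)
```

If you only ever run one query, the build time will probably dominate and brute force may win; the tree tends to pay off when the same data is queried many times.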

I would suggest using the dtaidistance library, as it seems to compute DTW faster than other libraries. It is also good to know that a Python for loop is quite slow. As far as I am aware, there is no faster way to compute the DTW between one vector and a list of other vectors. Alternatively, you can use the Euclidean distance, which supports computing all distances between one vector and many vectors without a for loop.

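Here is an example using the dtaidistance library:

import numpy as np
from dtaidistance import dtw

# distance_fast uses the C implementation and expects arrays of doubles
inputList = np.array([1,2,3,4,5], dtype=np.double)
complete_list = [[1,1,3,9,1],[1,2,6,4],[9,8,7,4,2]]
distanceList = []

for sample in complete_list:
       d = dtw.distance_fast(inputList, np.array(sample, dtype=np.double), use_pruning=True)
       distanceList.append(d)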
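
The Euclidean alternative mentioned above can be sketched with NumPy broadcasting. This assumes every list has the same length as the input, and it compares elements position by position, so it does not allow the time warping that DTW does.

```python
import numpy as np

inputlist = np.array([1, 2, 3, 4, 5], dtype=float)
complete_list = np.array([[1, 1, 3, 9, 1],
                          [1, 2, 6, 4, 5],
                          [9, 8, 7, 4, 2]], dtype=float)

# Broadcasting subtracts inputlist from every row at once, with no Python loop
distances = np.linalg.norm(complete_list - inputlist, axis=1)
closest = int(np.argmin(distances))
print(complete_list[closest], distances[closest])
```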
