
Selecting rows from a sparse matrix based on the index of a pandas DataFrame

Suppose I have a pandas DataFrame of shape (2,500,000, M) and a SciPy CSR sparse matrix of shape (2,500,000, N).

Each row of the DataFrame and of the sparse matrix describes one entity, and the two are already aligned: row 1 of the DataFrame describes the same entity as row 1 of the sparse matrix. The DataFrame has a fast mechanism for filtering (catalogue.where(catalogue.some_column != '')), but given the filtered DataFrame, how do I find the corresponding rows in the sparse matrix?

Assume the DataFrame is called catalogue and the sparse matrix is called collection.

import functools
import pickle
from multiprocessing.pool import ThreadPool

import scipy.sparse

def collection_filter_row(catalogue_filtered, catalogue_index_full, collection):
    # Fetch the row for each filtered document concurrently, then stack them.
    return scipy.sparse.vstack(ThreadPool(100).map(
        functools.partial(collection_get_row,
                          catalogue_index=tuple(catalogue_index_full),
                          collection=collection),
        tuple(catalogue_filtered.index.values)))

def collection_get_row(document_id, catalogue_index, collection):
    # tuple.index() scans from the start, so every lookup is O(number of rows).
    return collection.getrow(catalogue_index.index(document_id))

with open('collection-tfidf', 'rb') as f:
    collection = pickle.load(f)

collection_partial = functools.partial(
    collection_filter_row,
    catalogue_index_full=catalogue.index.values,
    collection=collection)
criteria = catalogue['criteria'].where(catalogue.criteria != '')
collection_state = collection_partial(criteria)

But even with concurrency (gevent, a thread pool), picking the corresponding rows is still slow. Am I doing something wrong, or is there a faster way to do this?

I found a much faster way to solve this. Start by creating a dictionary that maps each catalogue index label to its row number in collection:

index_dict = dict(zip(
    catalogue.index.values.tolist(),
    range(collection.shape[0])))
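A small sketch of what index_dict contains, using a hypothetical three-row catalogue whose labels and values are made up for illustration (n_rows stands in for collection.shape[0]). It also shows pandas' Index.get_indexer, which returns the same positions for a list of labels without building a Python dict; that is an alternative, not part of the original answer, and it assumes the index labels are unique.

```python
import pandas as pd

# Hypothetical catalogue whose index labels are document ids.
catalogue = pd.DataFrame(
    {"some_column": ["a", "", "c"]},
    index=["doc-x", "doc-y", "doc-z"],
)
n_rows = len(catalogue)  # stands in for collection.shape[0]

# Map each catalogue label to its row position in collection.
index_dict = dict(zip(catalogue.index.values.tolist(), range(n_rows)))
print(index_dict)  # {'doc-x': 0, 'doc-y': 1, 'doc-z': 2}

# Alternative: let pandas compute positions for a set of labels.
positions = catalogue.index.get_indexer(["doc-z", "doc-x"])
print(positions.tolist())  # [2, 0]
```

Each dict lookup is constant time on average, which is what replaces the linear tuple.index() scan from the question.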

Then collection_filter_row becomes:

def collection_filter_row(catalogue_filtered, index_dict, collection):
    # Translate each label to its row number, then select all rows in one indexing call.
    return collection[[index_dict[document_id]
                       for document_id
                       in catalogue_filtered.index.values.tolist()]]
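To see the function end to end, here is a self-contained sketch on hypothetical toy data: four documents with integer labels and a 4 x 3 CSR matrix whose rows are aligned with the DataFrame. The data are invented; the function body is the one from the answer. CSR matrices support selecting several rows at once with a list of integer row numbers.

```python
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

def collection_filter_row(catalogue_filtered, index_dict, collection):
    # Translate labels to row numbers, then select all rows in one indexing call.
    return collection[[index_dict[document_id]
                       for document_id
                       in catalogue_filtered.index.values.tolist()]]

# Hypothetical toy data: 4 documents, 3 sparse features, rows aligned.
catalogue = pd.DataFrame(
    {"some_column": ["keep", "", "keep", ""]},
    index=[101, 102, 103, 104],
)
collection = csr_matrix(np.array([
    [1, 0, 0],
    [0, 2, 0],
    [0, 0, 3],
    [4, 0, 0],
]))
index_dict = dict(zip(catalogue.index.values.tolist(),
                      range(collection.shape[0])))

collection_sub = collection_filter_row(
    catalogue.loc[catalogue.some_column != ""], index_dict, collection)
print(collection_sub.toarray())
# [[1 0 0]
#  [0 0 3]]
```

The result stays a sparse matrix, so it can be fed to the same downstream code as the full collection.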

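To return only the matching subset of collection, I should use boolean indexing, catalogue.loc[catalogue.some_column != ''], instead of catalogue.where(). where() keeps every row and only replaces the non-matching values with NaN, so the "filtered" index still contains every document. The proper call to collection_filter_row is then: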
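The difference between where() and boolean indexing is easy to check on a hypothetical three-row catalogue (labels and values invented for illustration). The sketch uses Series.where, matching the call in the question, and DataFrame.loc with the same mask.

```python
import pandas as pd

catalogue = pd.DataFrame({"some_column": ["a", "", "c"]},
                         index=["doc-x", "doc-y", "doc-z"])
mask = catalogue.some_column != ""

# where() keeps every row and replaces non-matching values with NaN.
kept_by_where = catalogue["some_column"].where(mask)
print(kept_by_where.index.tolist())  # ['doc-x', 'doc-y', 'doc-z']

# Boolean indexing with loc drops the non-matching rows entirely.
kept_by_loc = catalogue.loc[mask]
print(kept_by_loc.index.tolist())  # ['doc-x', 'doc-z']
```

This is also why the original code was doing more work than intended: its index still listed every document, so it fetched every row.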

collection_sub = collection_filter_row(
    catalogue.loc[catalogue.some_column != ''],
    index_dict,
    collection)

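This is much, much faster than the original method shown in the question: each label lookup is a constant-time dictionary access instead of a linear tuple.index() scan, and all rows are extracted with a single indexing operation instead of one getrow() call per document followed by vstack().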
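Because the question states that row i of catalogue describes the same entity as row i of collection, a further option (not from the original answer, offered only as a sketch) is to skip label lookups altogether: convert the boolean mask to integer positions with numpy.flatnonzero and index the CSR matrix with them. This assumes the alignment is strictly positional and that the mask is computed on the unfiltered catalogue; collection_filter_by_mask and the toy data are hypothetical.

```python
import numpy as np
import pandas as pd
from scipy.sparse import csr_matrix

def collection_filter_by_mask(mask, collection):
    # Hypothetical helper: rows are aligned by position, so the True
    # positions of the mask are exactly the rows of collection to keep.
    positions = np.flatnonzero(np.asarray(mask))
    return collection[positions]

catalogue = pd.DataFrame({"some_column": ["keep", "", "keep", ""]},
                         index=[101, 102, 103, 104])
collection = csr_matrix(np.array([[1, 0], [0, 2], [3, 0], [0, 4]]))

mask = catalogue.some_column != ""
collection_sub = collection_filter_by_mask(mask, collection)
print(collection_sub.toarray())
# [[1 0]
#  [3 0]]
```

The trade-off: this avoids building index_dict, but it silently gives wrong rows if the DataFrame is ever re-sorted without reordering the matrix, whereas the label-based dictionary stays correct as long as labels map to the right row numbers.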
