Filter Numpy array of tuples with vectorized access

I'm working with a Pandas MultiIndex that I build with the from_product method. Taking the MultiIndex's values property gives me a NumPy ndarray:

import pandas as pd

d = {'col1': [1, 2], 'col2': [3, 4], 'col3': [5, 6]}
df1 = pd.DataFrame(data=d)
df2 = pd.DataFrame(data=d)

multi_index = pd.MultiIndex.from_product((df1.index, df2.index), names=['idx1', 'idx2']).values

It returns an ndarray of tuples: [(0, 0) (0, 1) (1, 0) (1, 1)]. The problem is that I want to keep only the tuples whose two elements are equal. But because they're tuples, I can't use vectorized indexing like this:

equals = multi_index[multi_index[:, 0] == multi_index[:, 1]]

That would be possible if they were lists instead of tuples. Is there a way to filter by the tuples' elements (possibly with a more complex condition than the one above)?

If there isn't, what could I do? Cast every tuple to a list? I could iterate over all the elements, but that would be far too slow compared with a vectorized solution.

Any kind of help would be much appreciated. Thanks in advance.
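For context, here is a minimal sketch of why the indexing attempt in the question fails. It assumes the same two-level index built from `range(2)` (a stand-in for `df1.index` and `df2.index`); the names `mi`, `values` and `failed` are illustrative only. The `.values` property yields a one-dimensional object array whose items are Python tuples, so there is no second axis for `[:, 0]` to select from.

```python
import pandas as pd

# Stand-in for the question's index: two levels, each 0..1.
mi = pd.MultiIndex.from_product((range(2), range(2)), names=['idx1', 'idx2'])
values = mi.values

# A 1-D object array: each item is a whole tuple, not a row of numbers.
print(values.shape, values.dtype)

# Indexing along a second axis therefore raises IndexError.
try:
    values[:, 0]
    failed = False
except IndexError as exc:
    failed = True
    print('IndexError:', exc)
```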

Do not add .values at the end, so that you can call get_level_values:

multi_index = pd.MultiIndex.from_product((df1.index, df2.index), names=['idx1', 'idx2'])
equals = multi_index[multi_index.get_level_values(0) == multi_index.get_level_values(1)]
equals
Out[487]: 
MultiIndex([(0, 0),
            (1, 1)],
           names=['idx1', 'idx2'])
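The same approach should extend to the more complex conditions the question mentions, since each level comes back as a flat array that supports ordinary element-wise operations. The sketch below is only an illustration: the 3x3 index and the condition itself (second element greater than the first, and an even sum) are made up for the example and are not from the original post.

```python
import pandas as pd

# Hypothetical, slightly larger index so the condition has something to filter.
mi = pd.MultiIndex.from_product((range(3), range(3)), names=['idx1', 'idx2'])

# Pull each level out as a plain NumPy array.
first = mi.get_level_values('idx1').to_numpy()
second = mi.get_level_values('idx2').to_numpy()

# Combine element-wise conditions into a single boolean mask.
mask = (second > first) & ((first + second) % 2 == 0)

# Boolean indexing keeps the result as a MultiIndex.
filtered = mi[mask]
print(filtered)
```

Building the mask from the level arrays avoids touching the tuples at all, so no Python-level loop is involved.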

For a NumPy array:

idx = np.array(pd.MultiIndex.from_product((df1.index, df2.index), names=['idx1', 'idx2']).tolist())
multi_index = pd.MultiIndex.from_product((df1.index, df2.index), names=['idx1', 'idx2']).values
equals = multi_index[idx[:, 0] == idx[:, 1]]
equals
Out[497]: array([(0, 0), (1, 1)], dtype=object)
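If a regular two-dimensional array is acceptable instead of an object array of tuples, one alternative (not part of the original answer) is to stack the level values column by column. This is a sketch that assumes all levels hold values of a compatible dtype, as they do here; `pairs` is an illustrative name, and `range(2)` again stands in for `df1.index` and `df2.index`.

```python
import numpy as np
import pandas as pd

mi = pd.MultiIndex.from_product((range(2), range(2)), names=['idx1', 'idx2'])

# One column per level, giving shape (n_rows, n_levels).
pairs = np.column_stack(
    [mi.get_level_values(i).to_numpy() for i in range(mi.nlevels)]
)

# With a real second axis, the indexing from the question works as written.
equals = pairs[pairs[:, 0] == pairs[:, 1]]
print(equals)
```

The rows of `equals` are arrays rather than tuples, so this fits best when the downstream code only needs the numbers.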
