
How to get indices from a list/ndarray?

I have a list that looks like this:

[[0,1,2], [1,2,3], [2,3,4], [3,4,5]]

I can convert it to an array like this:

array([[0,1,2],
       [1,2,3],
       [2,3,4],
       [3,4,5]])

So altogether I have 4 rows, and each row has 3 columns. Now I want to find the indices of all the elements that are greater than 2, so for the whole matrix, the indices should be:

((1,2),(2,1),(2,2),(3,0),(3,1),(3,2))

Then, for each row, I want to randomly pick one column index whose value is greater than 2. My code currently looks like this:

a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in np.nonzero(a>2):
    if r == cur_row:
        col_list.append(c)
    else:
        cur_row = r
        shuffled_list = shuffle(col_list)
        out[r-1] = shuffled_list[0]
        col_list = []
        col_list.append(c) 

I hope to get an out that looks like:

array([-1, 2, 1, 2])

However, when I run my code, it shows:

ValueError: too many values to unpack

Does anyone know how I can fix this problem, or how else I could achieve my goal? I just want the code to run as fast as possible, so any other good ideas are more than welcome.

Try this.

import numpy as np

arr = np.array([[0,1,2],
               [1,2,3],
               [2,3,4],
               [3,4,5]])
indices = np.where(arr>2)

for r, c in zip(*indices):
    print(r, c)

This prints:

1 2
2 1
2 2
3 0
3 1
3 2

So it should work. On Python 2 you can also use itertools.izip, which would be an even better choice here because it does not build an intermediate list; on Python 3 the built-in zip is already lazy.
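For context, here is a small sketch of why the loop in the question raises the ValueError and why zip fixes it. np.nonzero returns a tuple with one index array per dimension, so iterating over it directly yields the whole row-index array and then the whole column-index array, each of which has six entries here and cannot be unpacked into r, c. The variable names below are only for illustration.

```python
import numpy as np

a = np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]])
result = np.nonzero(a > 2)

# One index array per dimension: row indices first, then column indices.
n_arrays = len(result)
rows, cols = result

# Iterating over the tuple directly tries to unpack a 6-element array
# into two names, which reproduces the error from the question.
error_message = None
try:
    for r, c in result:
        pass
except ValueError as exc:
    error_message = str(exc)

# zip(*result) pairs the i-th row index with the i-th column index.
pairs = [(int(r), int(c)) for r, c in zip(*result)]
print(error_message)
print(pairs)
```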

A pure numpy solution (thanks to @AshwiniChaudhary for the suggestion):

for r, c in np.vstack(np.where(arr>2)).T:
    ...

though I'm not sure this will be faster than using izip or zip.
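As a possible alternative to stacking and transposing by hand, np.argwhere should give the same (N, 2) array of index pairs in one call. This is a sketch of that option, not something benchmarked against the zip version.

```python
import numpy as np

arr = np.array([[0, 1, 2],
                [1, 2, 3],
                [2, 3, 4],
                [3, 4, 5]])

# Each row of index_pairs is one (row, column) position where arr > 2.
index_pairs = np.argwhere(arr > 2)

for r, c in index_pairs:
    print(r, c)
```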

You could just compare the array to your value and use np.where.

a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
np.where(a>2)

(array([1, 2, 2, 3, 3, 3], dtype=int64), array([2, 1, 2, 0, 1, 2], dtype=int64))

To get your tuples:

list(zip(*np.where(a>2)))

[(1, 2), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2)]
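If the data stays a plain list of lists and never becomes an ndarray, the same tuples can be built without numpy. A minimal sketch, where data and threshold are illustrative names:

```python
data = [[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]]
threshold = 2

# Walk every row and column, keeping positions whose value exceeds the threshold.
index_tuples = [
    (r, c)
    for r, row in enumerate(data)
    for c, value in enumerate(row)
    if value > threshold
]
print(index_tuples)
```

For large inputs the numpy version is likely the better fit, since the comparison runs in compiled code rather than in a Python loop.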

I figured it out. The code should be:

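import numpy as np
from random import shuffle  # assumption: the post does not show which shuffle was imported

a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])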
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in zip(*(np.nonzero(a>2))):
    if  r == cur_row:
        col_list.append(c)
    else:
        cur_row = r
        shuffle(col_list)
        if len(col_list) == 0:
            out[r-1] = -1
        else:
            out[r-1] = col_list[0]
        col_list = []
        col_list.append(c)

shuffle(col_list)
if len(col_list) == 0:
    out[len(out)-1] = -1
else:
    out[len(out)-1] = col_list[0]

The part at the end, outside the for loop, makes sure that the last row is taken care of.

It works in my case.
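Since speed was part of the question, here is one possible loop-free sketch. It is a different technique from the code above: it draws a random score for every cell, hides the cells that fail the condition, and takes the per-row argmax, which picks uniformly among the qualifying columns. The function name random_col_per_row and its parameters are hypothetical, and it assumes a NumPy version that provides np.random.default_rng.

```python
import numpy as np

def random_col_per_row(a, threshold, rng=None):
    """Pick one random column per row where a > threshold, or -1 if none."""
    rng = np.random.default_rng() if rng is None else rng
    mask = a > threshold

    # Random scores lie in [0, 1); non-qualifying cells get -1 so they never win.
    scores = rng.random(a.shape)
    scores[~mask] = -1.0

    # The largest score in each row marks the randomly chosen column.
    cols = scores.argmax(axis=1)

    # Rows with no qualifying element are reported as -1.
    cols[~mask.any(axis=1)] = -1
    return cols

a = np.array([[0, 1, 2], [1, 2, 3], [2, 3, 4], [3, 4, 5]])
out = random_col_per_row(a, 2)
print(out)
```

Unlike the loop version, this does not rely on the rows with matches being consecutive, and out comes back as an integer array rather than floats.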
