[英]How to get indices from a list/ndarray?
I have a list which looks like: 我有一个看起来像这样的清单:
[[0,1,2], [1,2,3], [2,3,4], [3,4,5]]
I can make it to an array like: 我可以将其设置为如下数组:
array([[0,1,2],
[1,2,3],
[2,3,4],
[3,4,5]])
So all together I have 4 rows and each row has 3 columns. 因此,我总共有4行,每行有3列。 Now I want to find the indices of all the elements which are greater than 2, so for the whole matrix, the indices should be:
现在,我想查找所有大于2的元素的索引,因此对于整个矩阵,索引应为:
((1,2),(2,1),(2,2),(3,1),(3,2),(3,3))
Then for each row, I will randomly picked out a col index which indicates a value greater than 2. Now my code is like: 然后,对于每一行,我将随机选择一个col索引,该索引指示值大于2。现在我的代码如下:
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]]
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in np.nonzero(a>2):
if r == cur_row:
col_list.append(c)
else:
cur_row = r
shuffled_list = shuffle(col_list)
out[r-1] = shuffled_list[0]
col_list = []
col_list.append(c)
I hope to get a out which looks like: 我希望得到一个看起来像这样的东西:
array([-1, 2, 1, 2])
However, now when I run my code, it shows 但是,现在当我运行代码时,它显示
ValueError: too many values to unpack
Anyone knows how I fix this problem? 有人知道我该如何解决这个问题? Or how should I do to achieve my goal?
还是我应该如何实现自己的目标? I just want to run the code as fast as possible, so any other good ideas is also more than welcome.
我只是想尽可能快地运行代码,因此任何其他好主意也非常受欢迎。
Try this. 尝试这个。
import numpy as np
arr = np.array([[0,1,2],
[1,2,3],
[2,3,4],
[3,4,5]])
indices = np.where(arr>2)
for r, c in zip(*indices):
print(r, c)
Prints 打印
1 2
2 1
2 2
3 0
3 1
3 2
So, it should work. 因此,它应该工作。 You can use
itertools.izip
as well, it would even be a better choice in this case. 您也可以使用
itertools.izip
,在这种情况下,它甚至是更好的选择。
A pure numpy
solution (thanks to @AshwiniChaudhary for the proposition): 一个纯粹的
numpy
解决方案(感谢@AshwiniChaudhary的建议):
for r, c in np.vstack(np.where(arr>2)).T:
...
though I'm not sure this will be faster than using izip or zip. 尽管我不确定这会比使用izip或zip更快。
You could just compare the array to your value and use where. 您可以将数组与您的值进行比较,然后在何处使用。
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
np.where(a>2)
(array([1, 2, 2, 3, 3, 3], dtype=int64), array([2, 1, 2, 0, 1, 2], dtype=int64))
(数组([1、2、2、3、3、3],dtype = int64),数组([2、1、2、0、1、2],dtype = int64))
To get your tuples 获取元组
list(zip(*np.where(a>2)))
[(1, 2), (2, 1), (2, 2), (3, 0), (3, 1), (3, 2)]
[(1、2),(2、1),(2、2),(3、0),(3、1),(3、2)]
I have made it out, the code should be: 我已经弄清楚了,代码应该是:
a = np.array([[0,1,2],[1,2,3],[2,3,4],[3,4,5]])
out = np.ones(4)*-1
cur_row = 0
col_list = []
for r,c in zip(*(np.nonzero(a>2))):
if r == cur_row:
col_list.append(c)
else:
cur_row = r
shuffle(col_list)
if len(col_list) == 0:
out[r-1] = -1
else:
out[r-1] = col_list[0]
col_list = []
col_list.append(c)
shuffle(col_list)
if len(col_list) == 0:
out[len(out)-1] = -1
else:
out[len(out)-1] = col_list[0]
The part in the end but outside the forloop is to make sure that the last row will be taken care of. 最后但在forloop之外的部分是确保最后一行将得到处理。
It works in my case. 在我的情况下有效。
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.