
Split numpy 2D array based on separate label array

I have a 2D numpy array A. For example:

A = np.array([[1, 2],
              [3, 4],
              [5, 6],
              [7, 8],
              [9, 0]])

I have another label array B whose entries correspond to the rows of A. For example:

B = np.array([0, 1, 2, 0, 1])

I want to split A into 3 arrays based on their labels, so the result would be:

[[[1, 2],
  [7, 8]],
 [[3, 4],
  [9, 0]],
 [[5, 6]]]

Are there any numpy built in functions to achieve this?

Right now, my solution is rather ugly: it repeatedly calls numpy.where in a for loop and slices the resulting index tuples to keep only the rows.
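
For reference, the loop-based approach described here might look roughly like the sketch below. The question does not show its code, so this is only a guess at it, and the names `groups` and `rows` are made up for illustration.

```python
import numpy as np

A = np.array([[1, 2],
              [3, 4],
              [5, 6],
              [7, 8],
              [9, 0]])
B = np.array([0, 1, 2, 0, 1])

groups = []
for label in np.unique(B):
    # np.where on a 1D condition returns a 1-tuple of index arrays;
    # element 0 holds the indices of the matching rows
    rows = np.where(B == label)[0]
    groups.append(A[rows])
```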

Here's one way to do it:

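  1. hstack the two arrays together.
  2. Sort the stacked array by its last column (the labels).
  3. Split the array at the index where each unique label first appears.
# append the labels as an extra last column
a = np.hstack((A, B[:, None]))
# a stable sort keeps rows with the same label in their original order
a = a[a[:, -1].argsort(kind='stable')]
# drop the label column, then split where each new label starts
a = np.split(a[:, :-1], np.unique(a[:, -1], return_index=True)[1][1:])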
OUTPUT:
[array([[1, 2],
        [7, 8]]),
 array([[3, 4],
        [9, 0]]),
 array([[5, 6]])]
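
One thing to keep in mind with stacking is that np.hstack forces A and the labels into one common dtype (float labels, for instance, would turn integer data into floats). A minimal sketch of the same sort-and-split idea that leaves A untouched is below. It assumes B is 1D with one label per row, and the function name `split_by_label` is hypothetical.

```python
import numpy as np

def split_by_label(A, B):
    # stable sort so rows sharing a label keep their original order
    order = np.argsort(B, kind="stable")
    sorted_labels = B[order]
    # index in the sorted labels where each distinct label first appears
    _, first = np.unique(sorted_labels, return_index=True)
    # the first start index is always 0, so it is not a split point
    return np.split(A[order], first[1:])

A = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]])
B = np.array([0, 1, 2, 0, 1])
parts = split_by_label(A, B)
```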

You could also use Pandas for this because it's designed for labelled data and has a powerful groupby method.

import pandas as pd
index = pd.Index(B, name='label')
df = pd.DataFrame(A, index=index)
groups = {k: v.values for k, v in df.groupby('label')}
print(groups)

This produces a dictionary of arrays of the grouped values:

{0: array([[1, 2],
        [7, 8]]), 1: array([[3, 4],
        [9, 0]]), 2: array([[5, 6]])}

For a list of the arrays you can do this instead:

groups = [v.values for k, v in df.groupby('label')]

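If every label occurs the same number of times, the result fits in a single 3D array, and you only need to sort the data by label:

idx = B.argsort(kind='stable')  # stable sort keeps the original row order within each label
n = np.flatnonzero(np.diff(B[idx]))[0] + 1  # rows per label: position of the first label change
result = A[idx].reshape(A.shape[0] // n, n, A.shape[1])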
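
As a quick check of the equal-count case, here is a sketch on a made-up six-row input with two rows per label (the question's own B does not qualify, because label 2 occurs only once). In this variant the number of groups is taken from np.unique, and the names `A6`, `B6` and `n_groups` are illustrative only.

```python
import numpy as np

# hypothetical input: three labels, each occurring exactly twice
A6 = np.arange(12).reshape(6, 2)
B6 = np.array([2, 0, 1, 0, 2, 1])

idx = B6.argsort(kind="stable")   # row order grouped by label
n_groups = np.unique(B6).size     # number of distinct labels
# shape becomes (labels, rows per label, columns)
result = A6[idx].reshape(n_groups, -1, A6.shape[1])
```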

If the labels aren't equally distributed, you'll have to make a list in the outer dimension:

_, indices, counts = np.unique(B, return_counts=True, return_inverse=True)
result = np.split(A[indices.argsort(kind='stable')], counts.cumsum()[:-1])
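
To see why this works, it can help to look at the intermediate values for the arrays from the question. The sketch below uses the same np.unique call; the names `labels`, `inverse` and `split_points` are chosen here for readability.

```python
import numpy as np

A = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 0]])
B = np.array([0, 1, 2, 0, 1])

# np.unique always returns (unique values, inverse, counts) in this order
labels, inverse, counts = np.unique(B, return_inverse=True, return_counts=True)
# labels  is [0, 1, 2]
# inverse is [0, 1, 2, 0, 1]  (position of each B entry within labels)
# counts  is [2, 2, 1]        (rows per label)

# cumulative counts give the end of each group in the sorted array;
# the last one is the array length, so it is dropped
split_points = counts.cumsum()[:-1]
result = np.split(A[inverse.argsort(kind="stable")], split_points)
```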

Using the equivalent of np.where is not very efficient, but you can do it without a loop:

b, idx = np.unique(B, return_inverse=True)
mask = idx[:, None] == np.arange(b.size)
result = np.split(A[idx.argsort(kind='stable')], np.count_nonzero(mask, axis=0).cumsum()[:-1])

You can compute the mask simultaneously for all the labels. Counting the matching elements in each column of the mask (np.count_nonzero(mask, axis=0)) gives the number of rows per label, and the cumulative sum of those counts gives the split points for the sorted A (A[idx.argsort()]). The last value is stripped off the cumulative sum because it equals the total number of rows, and np.split already treats the end of the array as the final boundary.
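
For the arrays in the question, the mask and the counts derived from it can be inspected as in the sketch below. The name `per_label` is introduced here only for illustration.

```python
import numpy as np

B = np.array([0, 1, 2, 0, 1])

b, idx = np.unique(B, return_inverse=True)
# one row per entry of B, one column per distinct label;
# mask[i, j] is True when row i carries label b[j]
mask = idx[:, None] == np.arange(b.size)

# column-wise count of True values: rows per label
per_label = np.count_nonzero(mask, axis=0)
# cumulative sum without its last entry: where to cut the sorted rows
split_points = per_label.cumsum()[:-1]
```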

This is probably the simplest way:

groups = [A[B == label, :] for label in np.unique(B)]
print(groups)

Output:

[array([[1, 2],
       [7, 8]]), array([[3, 4],
       [9, 0]]), array([[5, 6]])]
