
Remove and extract minimum from Numpy ndarray along 3rd dimension

I have an array of shape (3, 4, 3) which represents an image of 3x4 pixels with 3 color channels (the last index). My goal is to be left with a (3, 4, 2) array where each pixel has its lowest color channel removed. I could iterate pixel by pixel, but this would be very time-consuming. Using np.argmin I can easily extract the index of the minimum value, so I know which color channel contains the minimum value for each pixel. However, I cannot find a clever way of indexing to remove these values so that I am left with a (3, 4, 2) array.

Additionally, I tried to select the minimum values using something like array[:, :, indexMin] but was unable to get the desired array of shape (3, 4) that contains the minimum channel value for each pixel. I know np.amin does this, but working it out with indexing will give me a better understanding of Numpy arrays. A minimal example of my code structure is provided below:

import numpy as np

arr = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12]]
array = np.zeros((3, 4, 3))
array[:,:,0] = arr
array[:,:,1] = np.fliplr(arr)
array[:,:,2] = np.flipud(arr)
indexMin = np.argmin(array, axis=2)
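For reference, a pixel-by-pixel loop of the kind the question wants to avoid might look like the sketch below. remove_min_loop is a hypothetical helper, not part of the original post; it is slow, but it makes a handy baseline for checking the vectorized answers. It assumes ties are resolved like np.argmin, i.e. the first occurrence of the minimum is the one removed.

```python
import numpy as np

def remove_min_loop(array):
    """Hypothetical reference: drop the smallest channel of every pixel, one pixel at a time."""
    rows, cols, channels = array.shape
    out = np.empty((rows, cols, channels - 1), dtype=array.dtype)
    for r in range(rows):
        for c in range(cols):
            pixel = array[r, c]
            k = np.argmin(pixel)             # channel holding the (first) minimum
            out[r, c] = np.delete(pixel, k)  # keep the other channels in their original order
    return out

# The question's 3x4 image with 3 channels.
arr = [[1, 2, 3, 4],
       [5, 6, 7, 8],
       [9, 10, 11, 12]]
array = np.zeros((3, 4, 3))
array[:, :, 0] = arr
array[:, :, 1] = np.fliplr(arr)
array[:, :, 2] = np.flipud(arr)

result = remove_min_loop(array)
print(result.shape)   # (3, 4, 2)
print(result[0, 0])   # [4. 9.]
```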

You need arrays that broadcast correctly to the output shape you want. You can add the missing dimension back using np.expand_dims:

index = np.expand_dims(np.argmin(array, axis=2), axis=2)
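To see what np.expand_dims does here, the sketch below rebuilds the question's array (using np.stack, which is an equivalent construction) and compares the shapes before and after the trailing axis is added. Indexing with np.newaxis is shown as an equivalent spelling.

```python
import numpy as np

# Same 3x4 image with 3 channels as in the question.
arr = np.arange(1, 13).reshape(3, 4)
array = np.stack((arr, np.fliplr(arr), np.flipud(arr)), axis=2).astype(float)

index_min = np.argmin(array, axis=2)        # shape (3, 4): one channel index per pixel
index = np.expand_dims(index_min, axis=2)   # shape (3, 4, 1): trailing axis restored
same = index_min[..., np.newaxis]           # equivalent spelling

print(index_min.shape, index.shape)         # (3, 4) (3, 4, 1)
print(np.array_equal(index, same))          # True
```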

This makes it easy to set or extract the elements that you want to remove:

index = list(np.indices(array.shape, sparse=True))
index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = array[tuple(index)]
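The same broadcasting idea also answers the second part of the question. array[:, :, indexMin] keeps every row and column for each entry of indexMin, so it produces shape (3, 4, 3, 4). Replacing the two slices with broadcastable row and column ranges pairs each pixel with its own channel index and gives the (3, 4) array of minima directly. This sketch uses np.ogrid to build those ranges, which is a swap-in for np.indices(..., sparse=True) and yields ranges of the same shapes.

```python
import numpy as np

arr = np.arange(1, 13).reshape(3, 4)
array = np.stack((arr, np.fliplr(arr), np.flipud(arr)), axis=2).astype(float)
indexMin = np.argmin(array, axis=2)

# Slices select all rows/columns for every entry of indexMin.
print(array[:, :, indexMin].shape)                       # (3, 4, 3, 4)

# Broadcastable ranges: shapes (3, 1) and (1, 4).
rows, cols = np.ogrid[:array.shape[0], :array.shape[1]]
minima = array[rows, cols, indexMin]                     # shape (3, 4)

print(minima.shape)                                      # (3, 4)
print(np.array_equal(minima, np.amin(array, axis=2)))    # True
```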

np.indices with sparse=True returns a set of ranges shaped to broadcast the index correctly in each dimension. A nicer alternative is to use np.take_along_axis:

index = np.expand_dims(np.argmin(array, axis=2), axis=2)
minima = np.take_along_axis(array, index, axis=2)
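As a quick sanity check, the result of np.take_along_axis keeps the shape of the index, (3, 4, 1), and should match np.amin with keepdims=True; dropping the trailing axis gives the (3, 4) array from the question. The array construction is the same assumption as above.

```python
import numpy as np

arr = np.arange(1, 13).reshape(3, 4)
array = np.stack((arr, np.fliplr(arr), np.flipud(arr)), axis=2).astype(float)

index = np.expand_dims(np.argmin(array, axis=2), axis=2)   # (3, 4, 1)
minima = np.take_along_axis(array, index, axis=2)          # (3, 4, 1), same shape as index

# The kept trailing axis matches np.amin with keepdims=True...
print(np.array_equal(minima, np.amin(array, axis=2, keepdims=True)))  # True
# ...and dropping it gives a (3, 4) array.
print(minima[..., 0].shape)                                           # (3, 4)
```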

You can use these results to create a mask, e.g. with np.put_along_axis:

mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, index, 0, axis=2)

Indexing the array with the mask gives you:

result = array[mask].reshape(*array.shape[:2], -1)

The reshape works because boolean indexing returns the selected elements in row-major (C) order, and the mask keeps exactly two of the three channels of every pixel. Each consecutive pair in the flat result therefore comes from the same pixel, in its original channel order. That is not usual with masking operations, where the number of selected elements per row generally varies.
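One way to check the ordering argument: boolean indexing behaves like indexing with mask.nonzero(), which lists positions in row-major order, so a Fortran-ordered (column-major) copy of the same array should give an identical result. The random test data below is an assumption for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
array = rng.integers(10, size=(5, 6, 3))

index = np.expand_dims(np.argmin(array, axis=2), axis=2)
mask = np.ones(array.shape, dtype=bool)
np.put_along_axis(mask, index, False, axis=2)

result_c = array[mask].reshape(*array.shape[:2], -1)

# Same values, stored column-major in memory.
array_f = np.asfortranarray(array)
result_f = array_f[mask].reshape(*array.shape[:2], -1)

print(result_c.shape)                         # (5, 6, 2)
print(np.array_equal(result_c, result_f))     # True: selection order is logical, not physical
print(np.count_nonzero(mask, axis=2).min())   # 2: every pixel keeps exactly two channels
```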

Another alternative is to use np.delete with a raveled array and np.ravel_multi_index:

i = np.indices(array.shape[:2], sparse=True)
index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
result = np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)
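To make the flat indices less opaque: for a C-ordered array of shape (rows, cols, channels), the flat offset of element (r, c, k) is (r * cols + c) * channels + k. The sketch below, on assumed random data, checks that np.ravel_multi_index (with the sparse ranges broadcasting against the argmin array) produces exactly those offsets before deleting them.

```python
import numpy as np

rng = np.random.default_rng(1)
array = rng.integers(10, size=(4, 5, 3))
rows, cols, channels = array.shape

i = np.indices((rows, cols), sparse=True)              # shapes (4, 1) and (1, 5)
amin = np.argmin(array, axis=2)                         # shape (4, 5)
flat = np.ravel_multi_index((*i, amin), array.shape)    # broadcasts to (4, 5)

# Row-major offset formula written out by hand.
manual = (np.arange(rows)[:, None] * cols + np.arange(cols)) * channels + amin
print(np.array_equal(flat, manual))   # True

result = np.delete(array.ravel(), flat).reshape(rows, cols, -1)
print(result.shape)                   # (4, 5, 2)
```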

Just for fun, you can use the fact that you only have three elements per pixel to create a full index of the elements you want to keep. The idea is that the three channel indices sum to 0 + 1 + 2 = 3. Therefore, 3 - np.argmin(array, axis=2) - np.argmax(array, axis=2) is the index of the median element. If you stack the median and the max, you get an index similar to what sort gives you:

amax = np.argmax(array, axis=2)
amin = np.argmin(array, axis=2)
index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
result = np.take_along_axis(array, index, axis=2)

The call to np.clip is necessary to handle the case where all elements are equal, in which case both argmax and argmin return zero.
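The sketch below exercises the median trick on a hand-picked row of pixels covering distinct values, ties and an all-equal pixel. Without np.clip, the all-equal pixel would produce channel index 3 and np.take_along_axis would raise an IndexError. Because the median is taken before the max, the output matches sorting and dropping the first element.

```python
import numpy as np

def remove_min_median(array):
    # Only valid for exactly three channels: indices 0 + 1 + 2 == 3.
    amax = np.argmax(array, axis=2)
    amin = np.argmin(array, axis=2)
    median = np.clip(3 - (amin + amax), 0, 2)   # all-equal pixel: 3 - 0 - 0 = 3, clipped to 2
    index = np.stack((median, amax), axis=2)
    return np.take_along_axis(array, index, axis=2)

# One row of pixels: distinct values, two kinds of ties, all equal.
pixels = np.array([[[3, 1, 2], [1, 1, 2], [2, 1, 2], [5, 5, 5]]])

print(remove_min_median(pixels)[0].tolist())
# [[2, 3], [1, 2], [2, 2], [5, 5]]
print(np.array_equal(remove_min_median(pixels), np.sort(pixels, axis=2)[..., 1:]))  # True
```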

Timing

Comparing the approaches:

def remove_min_indices(array):
    index = list(np.indices(array.shape, sparse=True))
    index[-1] = np.expand_dims(np.argmin(array, axis=2), axis=2)
    mask = np.ones(array.shape, dtype=bool)
    mask[tuple(index)] = False
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_put(array):
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.expand_dims(np.argmin(array, axis=2), axis=2), 0, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)

def remove_min_delete(array):
    i = np.indices(array.shape[:2], sparse=True)
    index = np.ravel_multi_index((*i, np.argmin(array, axis=2)), array.shape)
    return np.delete(array.ravel(), index).reshape(*array.shape[:2], -1)

def remove_min_sort_c(array):
    return np.sort(array, axis=2)[..., 1:]

def remove_min_sort_i(array):
    array.sort(axis=2)  # sorts the caller's array in place
    return array[..., 1:]

def remove_min_median(array):
    amax = np.argmax(array, axis=2)
    amin = np.argmin(array, axis=2)
    index = np.stack((np.clip(3 - (amin + amax), 0, 2), amax), axis=2)
    return np.take_along_axis(array, index, axis=2)

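Tested for arrays made like array = np.random.randint(10, size=(N, N, 3), dtype=np.uint8) for N in {100, 1K, 10K, 100K, 1M} (columns: IND = remove_min_indices, PUT = remove_min_put, DEL = remove_min_delete, S_C = remove_min_sort_c, S_I = remove_min_sort_i, MED = remove_min_median):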
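If you want to reproduce measurements like these, a minimal harness might look like the following. bench is a hypothetical helper, not the one used for the table, and absolute times will depend on the machine and NumPy build. It runs each function once per repetition on a fresh copy made in the untimed setup, so an in-place variant such as remove_min_sort_i never sees already sorted data.

```python
import timeit
import numpy as np

def bench(func, n, repeat=5):
    """Hypothetical helper: best single-call time (seconds) of func on an (n, n, 3) uint8 image."""
    base = np.random.randint(10, size=(n, n, 3), dtype=np.uint8)
    timer = timeit.Timer("func(a)", setup="a = base.copy()",
                         globals={"func": func, "base": base})
    # number=1: the setup (copy) runs before every single timed call.
    return min(timer.repeat(repeat=repeat, number=1))

def remove_min_sort_c(array):
    return np.sort(array, axis=2)[..., 1:]

def remove_min_sort_i(array):
    array.sort(axis=2)  # sorts the caller's array in place
    return array[..., 1:]

for n in (100, 300):
    print(n, bench(remove_min_sort_c, n), bench(remove_min_sort_i, n))
```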

  N  |   IND   |   PUT   |   DEL   |   S_C   |   S_I   |   MED   |
-----+---------+---------+---------+---------+---------+---------+
100  | 648. µs | 658. µs | 765. µs | 497. µs | 246. µs | 905. µs |
  1K | 67.9 ms | 68.1 ms | 85.7 ms | 51.7 ms | 24.0 ms | 123. ms |
 10K | 6.86 s  | 6.86 s  | 8.72 s  | 5.17 s  | 2.39 s  | 13.2 s  |
-----+---------+---------+---------+---------+---------+---------+

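Times scale with N^2, as expected. Sorting returns a different result from the other approaches (the two remaining channels come out in ascending order rather than in their original order), but is clearly the most efficient; note that S_I sorts its input in place. Masking via raw indexing (IND) and via put_along_axis (PUT) are nearly indistinguishable in the table: IND is marginally faster at N = 100 and 1K, and the two tie at 10K. The median trick (MED) is the slowest.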
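The difference between the sorting and masking results is easiest to see on a single pixel. In this small sketch (the pixel values are arbitrary), masking keeps the surviving channels in their original order while sorting returns them in ascending order.

```python
import numpy as np

pixel = np.array([[[3, 1, 2]]])   # a single 1x1 "image" with 3 channels

mask = np.ones(pixel.shape, dtype=bool)
np.put_along_axis(mask, np.expand_dims(np.argmin(pixel, axis=2), axis=2), False, axis=2)
masked = pixel[mask].reshape(*pixel.shape[:2], -1)
sorted_ = np.sort(pixel, axis=2)[..., 1:]

print(masked.ravel().tolist())    # [3, 2]  original channel order kept
print(sorted_.ravel().tolist())   # [2, 3]  ascending order
```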

If you do not care about the output order, then array.partition(1), array.sort() and np.sort(array) (all of which work along the last axis by default) are fastest. On a 10K random matrix, array.sort() wins by a small margin. After the in-place calls, array[..., 1:] contains the results; for np.sort, slice the returned copy instead. This is interesting because one might expect partition() to be somewhat faster than sort().
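Why partition(1) is enough here: after partitioning with kth=1, index 1 holds the value it would have in sorted order, everything before it is no larger and everything after it is no smaller. With only three channels, that puts the minimum at index 0 and the maximum at index 2, so the last two entries coincide with the sorted ones. A quick check on assumed random data:

```python
import numpy as np

rng = np.random.default_rng(2)
array = rng.integers(10, size=(50, 50, 3), dtype=np.uint8)

partitioned = array.copy()
partitioned.partition(1)        # in place, along the last axis by default
sorted_copy = np.sort(array)    # returns a sorted copy, also along the last axis

print(np.array_equal(partitioned[..., 1:], sorted_copy[..., 1:]))   # True
print(np.array_equal(partitioned[..., 0], array.min(axis=-1)))      # True
```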

Python 3.7, Numpy 1.16.5, with MKL
----------------------------------
array.partition : 5.57 s
np.sort         : 5.15 s
array.sort      : 5.02 s

If order is important, credit goes to Mad Physicist's numpy.put_along_axis solution. The solution I previously suggested, shown below, is about twice as slow as that one:

ind = array.argpartition(1)[..., 1:]   # indices of the two non-minimal channels, unordered
ind.sort()                             # restore the original channel order
result = np.take_along_axis(array, ind, -1)
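Wrapped into a function (remove_min_argpartition is a hypothetical name), the argpartition approach can be checked against the put_along_axis mask. The test data uses a random permutation of 0, 1, 2 per pixel so there are no ties; with tied minima, argpartition is not documented to pick the first occurrence the way np.argmin does, so the order of the remaining values could differ between the two approaches.

```python
import numpy as np

def remove_min_argpartition(array):
    # Indices of the two largest channels per pixel, in arbitrary order.
    ind = array.argpartition(1, axis=-1)[..., 1:]
    ind.sort(axis=-1)   # restore the original channel order
    return np.take_along_axis(array, ind, axis=-1)

def remove_min_put(array):
    mask = np.ones(array.shape, dtype=bool)
    np.put_along_axis(mask, np.expand_dims(np.argmin(array, axis=2), axis=2), False, axis=2)
    return array[mask].reshape(*array.shape[:2], -1)

# Distinct channel values per pixel, so both approaches remove the same channel.
rng = np.random.default_rng(3)
array = np.argsort(rng.random((40, 40, 3)), axis=2)

print(np.array_equal(remove_min_argpartition(array), remove_min_put(array)))  # True
```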
