
How to Efficiently Find the Indices of Max Values in a Multidimensional Array of Matrices using Pytorch and/or Numpy

Background

It is common in machine learning to deal with high-dimensional data. For example, in a Convolutional Neural Network (CNN) each input image may be 256x256 pixels with 3 color channels (Red, Green, and Blue). If the model takes in a batch of 16 images at a time, the input going into our CNN has dimensions [16, 3, 256, 256]. Each convolutional layer expects data in the form [batch_size, in_channels, in_y, in_x], and all of these quantities except batch_size often change from layer to layer. The matrix made up of the [in_y, in_x] values is called a feature map, and this question is concerned with finding the maximum value, and its index, in every feature map at a given layer.
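To make the layout concrete, here is a minimal NumPy sketch. The sizes are arbitrary assumptions, much smaller than the [16, 3, 256, 256] example, and the maps are deliberately non-square so rows and columns are easy to tell apart.

```python
import numpy as np

# Stand-in for a conv-layer input laid out as [batch_size, in_channels, in_y, in_x].
# Sizes are arbitrary and much smaller than the [16, 3, 256, 256] example.
batch_size, in_channels, in_y, in_x = 4, 3, 8, 10
x = np.random.rand(batch_size, in_channels, in_y, in_x).astype(np.float32)

# x[k, j] is the feature map of channel j for batch item k: an [in_y, in_x] matrix.
feature_map = x[0, 1]
print(feature_map.shape)        # (8, 10)

# One feature map per (batch item, channel) pair, so the wanted result holds
# one (row, col) pair per feature map: shape [batch_size, in_channels, 2].
print(x.shape[:2] + (2,))       # (4, 3, 2)
```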

Why do I want to do this? I want to apply a mask to every feature map, centered at the max value of that feature map, and to do that I need to know where each max value is located. The mask is applied during both training and testing of the model, so efficiency is vitally important to keep computation times down. There are many Pytorch and Numpy solutions for finding a single max value and its index, and for finding the maximum values or indices along a single dimension, but I could not find any dedicated, efficient built-in function for finding the indices of maximum values along 2 or more dimensions at a time. Yes, we can nest functions that operate on a single dimension, but these are some of the least efficient approaches.

What I've Tried

  • I've looked at this Stackoverflow question, but the author is dealing with a special-case 4D array which is trivially squeezed to a 3D array. The accepted answer is specialized for that case, and the answer pointing to TopK (torch.topk) is misguided: topk operates on a single dimension, and the question asked would require k=1, which devolves to a regular torch.max call.
  • I've looked at this Stackoverflow question, but this question, and its answer, focus on searching along a single dimension.
  • I have looked at this Stackoverflow question, but I already knew the answer's approach, as I formulated it independently in my own answer here (where I noted that the approach is very inefficient).
  • I have looked at this Stackoverflow question, but it does not satisfy the key requirement of this question, which is efficiency.
  • I have read many other Stackoverflow questions and answers, as well as the Numpy documentation, Pytorch documentation, and posts on the Pytorch forums.
  • I've tried implementing a LOT of different approaches to this problem, enough that I created this question so that I could answer it and give back to the community, and to anyone who goes looking for a solution to this problem in the future.

Standard of Performance

If I am asking a question about efficiency, I need to state my expectations clearly. I am trying to find a time-efficient solution (space is secondary) for the problem above without writing C code/extensions, and one that is reasonably flexible (hyper-specialized approaches aren't what I'm after). The approach must accept an [a,b,c,d] Torch tensor of datatype float32 or float64 as input, and output an array or tensor of shape [a,b,2] of datatype int32 or int64 (because we are using the output as indices). Solutions should be benchmarked against the following typical solution:

import torch
max_indices = torch.stack([torch.stack([(x[k][j]==torch.max(x[k][j])).nonzero()[0] for j in range(x.size()[1])]) for k in range(x.size()[0])])
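Because every solution has to produce the same [a,b,2] integer output, a small checker for that contract is handy while benchmarking. The sketch below is a hypothetical helper (check_max_indices is not from the original post); it works on NumPy arrays, and CPU tensors can be passed too since np.asarray converts them.

```python
import numpy as np

def check_max_indices(x, max_indices):
    """Hypothetical helper: True if max_indices[k, j] = (row, col) locates
    the maximum of feature map x[k, j]. x has shape [a, b, c, d]."""
    x = np.asarray(x)
    idx = np.asarray(max_indices)
    a, b = x.shape[:2]
    # Output contract from the question: shape [a, b, 2], int32 or int64.
    if idx.shape != (a, b, 2) or idx.dtype not in (np.int32, np.int64):
        return False
    # Broadcastable (k, j) grids so every feature map is looked up at once.
    k, j = np.meshgrid(np.arange(a), np.arange(b), indexing="ij")
    # Value stored at each reported (row, col) ...
    found = x[k, j, idx[..., 0], idx[..., 1]]
    # ... must equal the maximum of that feature map.
    return bool(np.all(found == x.max(axis=(2, 3))))
```

The check accepts any location that holds the maximum, so it does not penalize a solution that reports a different one of several tied maxima.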

The Approach

We are going to take advantage of the Numpy community and libraries, as well as the fact that Pytorch tensors on the CPU and Numpy arrays can be converted to/from one another without copying or moving the underlying data in memory (so conversions are low cost). From the Pytorch documentation:

Converting a torch Tensor to a Numpy array and vice versa is a breeze. The torch Tensor and Numpy array will share their underlying memory locations, and changing one will change the other.
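A minimal sketch of that shared memory, assuming a CPU tensor: a CUDA tensor has to be copied to the host with .cpu() first, and a tensor that requires grad needs .detach() before .numpy().

```python
import torch

t = torch.zeros(3, dtype=torch.float32)  # CPU tensor
a = t.numpy()                             # NumPy view of the same memory, no copy
a[0] = 5.0                                # write through the NumPy array ...
print(t)                                  # tensor([5., 0., 0.])

b = torch.from_numpy(a)                   # back to torch, still no copy
b[1] = 7.0                                # write through the new tensor ...
print(a)                                  # [5. 7. 0.]
```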

Solution One

We are first going to use the Numba library to write a function that is just-in-time (JIT) compiled on its first use, meaning we can get C speeds without having to write C code ourselves. Of course, there are caveats to what can be JIT-ed, and one of them is that the function must work with Numpy arrays and functions rather than torch tensors. This isn't too bad because, remember, converting a torch tensor to a Numpy array is low cost. The function we create is:

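import numpy as np
from numba import njit

@njit(cache=True)
def indexFunc(array, item):
    # Walk the array in row-major order and return the index tuple
    # of the first element equal to item
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx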

This function is from another Stackoverflow answer located here (this was the answer which introduced me to Numba). The function takes an N-dimensional Numpy array, looks for the first occurrence of a given item, and immediately returns that item's index on a successful match. The @njit decorator is short for @jit(nopython=True), and tells the compiler that we want it to compile the function without using any Python objects, and to throw an error if it is not able to do so (Numba is fastest when no Python objects are used, and speed is what we are after).
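To see exactly what the search returns, here is a plain-Python equivalent of indexFunc without the @njit decorator. index_func_py is a hypothetical name used only for illustration; the logic is the same loop.

```python
import numpy as np

def index_func_py(array, item):
    """Plain-Python equivalent of indexFunc (no @njit), for illustration only."""
    # np.ndenumerate walks the array in C (row-major) order, yielding (index_tuple, value)
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    return None  # no match: falls off the end, like the undecorated loop

fmap = np.array([[1.0, 9.0, 3.0],
                 [9.0, 2.0, 0.0]])
print(index_func_py(fmap, fmap.max()))  # (0, 1): the first 9.0 in row-major order
print(index_func_py(fmap, 42.0))        # None
```

Note that with tied maxima the first one in row-major order wins, which matches what np.argmax reports.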

With this speedy function backing us, we can get the indices of the max values in a tensor as follows:

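import numpy as np
import torch

x = x.numpy()  # x must be a CPU tensor; the Numpy array shares its memory
maxVals = np.amax(x, axis=(2,3))  # max of every [c,d] feature map, shape [a,b]
max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)  # preallocate the output
for index in np.ndindex(x.shape[0],x.shape[1]):
    max_indices[index] = np.asarray(indexFunc(x[index], maxVals[index]),dtype=np.int64)
max_indices = torch.from_numpy(max_indices)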

We use np.amax because it accepts a tuple for its axis argument, allowing it to return the max value of each 2D feature map in the 4D input. We preallocate max_indices with np.zeros because appending to Numpy arrays is expensive. This approach is much faster than the Typical Solution in the question (by an order of magnitude), but it still uses a Python for loop outside the JIT-ed function, so we can improve...
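A small sketch of the two NumPy points above, on an arbitrary [2, 3, 4, 5] input: a tuple axis reduces both spatial axes in one call, and the output buffer is allocated once up front.

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)   # [a, b, c, d]

# np.amax with a tuple axis reduces both spatial axes at once:
# one maximum per [c, d] feature map.
max_vals = np.amax(x, axis=(2, 3))
print(max_vals.shape)                                 # (2, 3)

# Same result by flattening each feature map and reducing a single axis.
max_vals_flat = x.reshape(x.shape[0], x.shape[1], -1).max(axis=-1)
print(np.array_equal(max_vals, max_vals_flat))        # True

# Allocate the [a, b, 2] output once rather than appending to it.
max_indices = np.zeros((x.shape[0], x.shape[1], 2), dtype=np.int64)
```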

Solution Two

We will use the following solution:

import numpy as np
import torch
from numba import njit, prange

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError  # never reached for our inputs; fixes the inferred return type

@njit(cache=True, parallel=True)
def indexFunc2(x,maxVals):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):  # nested prange: Numba runs this inner loop serially
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

x = x.numpy()
maxVals = np.amax(x, axis=(2,3))
max_indices = torch.from_numpy(indexFunc2(x,maxVals))

Instead of iterating through our feature maps one at a time with a Python for loop, we can take advantage of parallelization using Numba's prange function (which behaves like range but tells the compiler we want the loop to be parallelized) and the parallel=True decorator argument. Numba also parallelizes the np.zeros call. Because our function is compiled and uses no Python objects, Numba can use all the threads available on our system. It is worth noting that there is now a raise RuntimeError at the end of indexFunc. We need to include this, otherwise the Numba compiler infers that the function returns either a tuple of indices or None. That doesn't jibe with our usage in indexFunc2, so the compiler would throw an error. From our setup we know that indexFunc will always find a match and return a tuple, so we can simply raise an error in the other logical branch.

This approach is functionally identical to Solution One, but replaces the Python loop over np.ndindex with two prange loops inside the JIT-ed function. This approach is about 4x faster than Solution One.
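According to the Numba documentation, when prange loops are nested only the outermost one runs in parallel, so in Solution Two only the batch loop is split across threads. A possible variant, sketched here under that assumption and not benchmarked, is to reshape the input to a stack of a*b feature maps and use a single prange over all of them. The names index_func and index_func_flat are hypothetical, and the fallback decorator only exists so the sketch runs (serially) where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Fallback so the sketch still runs (serially, as plain Python) without Numba.
    def njit(*args, **kwargs):
        return lambda func: func
    prange = range

@njit()
def index_func(array, item):
    # Same search as indexFunc: index of the first value equal to item.
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError  # unreachable when item is the array's max

@njit(parallel=True)
def index_func_flat(maps, max_vals):
    # maps: [a*b, c, d] stack of feature maps; max_vals: [a*b] their maxima.
    out = np.zeros((maps.shape[0], 2), dtype=np.int64)
    for n in prange(maps.shape[0]):      # one parallel loop over every map
        idx = index_func(maps[n], max_vals[n])
        out[n, 0] = idx[0]
        out[n, 1] = idx[1]
    return out

x = np.random.rand(2, 3, 4, 5)
a, b, c, d = x.shape
maps = x.reshape(a * b, c, d)            # a view, since x is C-contiguous
max_vals = np.amax(maps, axis=(1, 2))
max_indices = index_func_flat(maps, max_vals).reshape(a, b, 2)
```

Reshaping outside the compiled function avoids any index arithmetic on the loop variable inside it; whether this beats the nested version depends on how a compares with the number of cores.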

Solution Three

Solution Two is fast, but it still computes the max values with a separate np.amax call outside the JIT-ed function. Can we speed this up by moving that work into the JIT-ed function as well?

import numpy as np
import torch
from numba import njit, prange

@njit(cache=True)
def indexFunc(array, item):
    for idx, val in np.ndenumerate(array):
        if val == item:
            return idx
    raise RuntimeError

@njit(cache=True, parallel=True)
def indexFunc3(x):
    # Keep the maxima in the input's dtype so the equality test in indexFunc matches
    maxVals = np.zeros((x.shape[0],x.shape[1]),dtype=x.dtype)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxVals[i,j] = np.max(x[i,j])
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            max_indices[i,j] = np.asarray(indexFunc(x[i,j], maxVals[i,j]),dtype=np.int64)
    return max_indices

max_indices = torch.from_numpy(indexFunc3(x))  # x is the Numpy array from x.numpy()

It might look like there is a lot more going on in this solution, but the only change is that instead of calculating the maximum value of each feature map with np.amax, we now compute them in a parallelized loop inside the JIT-ed function. This approach is marginally faster than Solution Two.
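The equality search in indexFunc only works if the stored maximum has exactly the same dtype as the input. The question allows float64 input, and a float64 value rounded into a float32 buffer generally no longer compares equal to the original. A small NumPy demonstration, with an arbitrary example map:

```python
import numpy as np

fmap = np.array([[0.1, 0.3], [0.2, 0.0]], dtype=np.float64)
true_max = fmap.max()                 # 0.3 as float64
stored = np.float32(true_max)         # what a float32 buffer would hold

print(np.any(fmap == true_max))       # True
print(np.any(fmap == stored))         # False: 0.3 is not exactly representable in float32
```

With a mismatched buffer, every search would fall through to the raise RuntimeError branch.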

Solution Four

This solution is the best I've been able to come up with:

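import numpy as np
import torch
from numba import njit, prange

@njit(cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i,j])  # flat (row-major) index within the [c,d] map
            max_indices[i,j,0] = maxTemp // x.shape[3]  # row = flat index // number of columns
            max_indices[i,j,1] = maxTemp % x.shape[3]   # column = flat index % number of columns
    return max_indices

max_indices = torch.from_numpy(indexFunc4(x))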

This approach is more condensed and also the fastest: 33% faster than Solution Three and 50x faster than the Typical Solution. We use np.argmax to get the index of the max value of each feature map, but on a 2D array np.argmax returns the index as if the feature map were flattened. That is, we get a single integer telling us the element's position in the flattened feature map, not the row and column we need to access it. Dividing that integer by the number of columns, x.shape[3], gives the [row, column] pair we need: the row is maxTemp // x.shape[3] and the column is maxTemp % x.shape[3]. (Dividing by x.shape[2], the number of rows, only gives the right answer for square feature maps such as the 64x64 maps used in the benchmarks below.)
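A worked example of the flat-index arithmetic on a non-square map (values chosen arbitrarily), checked against NumPy's own np.unravel_index:

```python
import numpy as np

fmap = np.arange(12.0).reshape(3, 4)   # non-square feature map: 3 rows, 4 columns
fmap[2, 1] = 100.0                      # place the maximum at row 2, column 1

flat = int(np.argmax(fmap))             # 9, i.e. 2 * 4 + 1 in row-major order
n_cols = fmap.shape[1]                  # plays the role of x.shape[3]
row, col = flat // n_cols, flat % n_cols
print(row, col)                         # 2 1

# Dividing by the number of rows (3) instead would give (3, 0): out of bounds.
assert (row, col) == tuple(int(v) for v in np.unravel_index(flat, fmap.shape))
```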

Benchmarking

All approaches were benchmarked together against a random input of shape [32,d,64,64] , where d was incremented from 5 to 245. For each d, 15 samples were gathered and the times were averaged. An equality test ensured that all solutions provided identical values. An example of the benchmark output is:

[Image: solution benchmarks]
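The benchmarking procedure described above (average of 15 timed runs per input, plus an equality check across solutions) can be sketched roughly as follows. This is not the harness that produced the figures: the helper names are hypothetical, and the two solutions timed here are a plain per-map loop and the same flatten-then-divide idea as Solution Four written with plain NumPy (no Numba).

```python
import time
import numpy as np

def mean_runtime(fn, x, repeats=15):
    """Average wall-clock seconds of fn(x) over `repeats` calls (hypothetical helper)."""
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(x)
        times.append(time.perf_counter() - start)
    return sum(times) / len(times)

def loop_reference(x):
    # Straightforward per-feature-map loop, used as a baseline here.
    out = np.zeros((x.shape[0], x.shape[1], 2), dtype=np.int64)
    for i, j in np.ndindex(x.shape[0], x.shape[1]):
        out[i, j] = np.unravel_index(np.argmax(x[i, j]), x[i, j].shape)
    return out

def argmax_divmod(x):
    # Flatten each [c, d] map, take argmax, then split into (row, col).
    a, b, c, d = x.shape
    flat = x.reshape(a, b, c * d).argmax(axis=-1)
    return np.stack((flat // d, flat % d), axis=-1)

x = np.random.rand(32, 5, 64, 64).astype(np.float32)   # [32, d, 64, 64] with d = 5
assert np.array_equal(loop_reference(x), argmax_divmod(x))  # equality check first
for fn in (loop_reference, argmax_divmod):
    print(fn.__name__, mean_runtime(fn, x))
```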

A plot of the benchmarking times as d increased is (leaving out the Typical Solution so the graph isn't squashed):

[Image: benchmark plot]

Whoa! What is going on at the start with those spikes?

Solution Five

Numba produces Just-In-Time compiled functions, but it doesn't compile them until the first time we call them; it then caches the compiled result for subsequent calls. This means the very first call to each JIT-ed function shows a spike in compute time while the function is compiled. Luckily, there is a way around this: if we specify the function's return type and argument types ahead of time, the function is compiled eagerly, when it is defined, instead of on its first call. Applying this to Solution Four we get:

import numpy as np
import torch
from numba import njit, prange

# Explicit signature (4D float32 array in, 3D int64 array out): compiled eagerly
@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)
def indexFunc4(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i,j])
            max_indices[i,j,0] = maxTemp // x.shape[3]  # row
            max_indices[i,j,1] = maxTemp % x.shape[3]   # column
    return max_indices

max_indices6 = torch.from_numpy(indexFunc4(x))

If we restart our kernel and rerun the benchmark, we can compare the first result (d==5) with the second (d==10). All of the JIT-ed solutions were slower at d==5 because they had to be compiled, except for Solution Five (the eagerly compiled version of Solution Four), because we explicitly provided its function signature ahead of time:

[Image: benchmark results for d==5 and d==10]

There we go! That's the best solution I have so far for this problem.


EDIT #1

Solution Six

An improved solution has been developed which is 33% faster than the previously posted best solution. This solution only works if the input array is C-contiguous, but this isn't a big restriction: Numpy arrays and torch tensors are contiguous unless they have been, for example, transposed, permuted, or sliced with a step, and both libraries have functions to make an array/tensor contiguous if needed (np.ascontiguousarray and Tensor.contiguous()).

This solution is the same as the previous best, but the function decorator that specifies the input and return types is changed from

@njit('i8[:,:,:](f4[:,:,:,:])',cache=True, parallel=True)

to

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)

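The only difference is that the last : in each array type becomes ::1, which tells the Numba njit compiler that the arrays are C-contiguous, allowing it to optimize better. Because the signature is explicit, passing a non-contiguous array raises a TypeError instead of compiling a slower version.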
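A quick way to see which arrays satisfy the ::1 requirement is NumPy's flags attribute; the sketch below (arbitrary shapes) shows two common ways of ending up with a non-contiguous view and how np.ascontiguousarray fixes it. On the torch side, Tensor.is_contiguous() and Tensor.contiguous() play the same roles, and x.permute(...) is the usual way to get a non-contiguous tensor.

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
print(x.flags["C_CONTIGUOUS"])           # True: freshly allocated
xt = x.transpose(0, 1, 3, 2)             # swap the last two axes (a view)
print(xt.flags["C_CONTIGUOUS"])          # False
xs = x[:, :, :, ::2]                     # strided slice (also a view)
print(xs.flags["C_CONTIGUOUS"])          # False

xc = np.ascontiguousarray(xt)            # copies into C order
print(xc.flags["C_CONTIGUOUS"], np.array_equal(xc, xt))  # True True
```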

The full solution six is then:

import numpy as np
import torch
from numba import njit, prange

@njit('i8[:,:,::1](f4[:,:,:,::1])',cache=True, parallel=True)
def indexFunc5(x):
    max_indices = np.zeros((x.shape[0],x.shape[1],2),dtype=np.int64)
    for i in prange(x.shape[0]):
        for j in prange(x.shape[1]):
            maxTemp = np.argmax(x[i,j])  # flat index within the [c,d] feature map
            max_indices[i,j,0] = maxTemp // x.shape[3]  # row
            max_indices[i,j,1] = maxTemp % x.shape[3]   # column
    return max_indices

max_indices7 = torch.from_numpy(indexFunc5(x))  # x must be a C-contiguous float32 array

The benchmark including this new solution confirms the speedup:

[Image: benchmark including Solution Six]

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
Guangdong ICP filing No. 18138465  © 2020-2024 STACKOOM.COM