
NumPy: iterating over all dimensions but the last one when the number of dimensions is unknown

Physical Background

I'm working on a function that calculates some metrics for each vertical profile in a temperature field with up to four dimensions (time, longitude, latitude, and pressure as the height measure). I have a working function that takes the pressure and temperature at a single location and returns the metrics (tropopause information). I want to wrap it in a function that applies it to every vertical profile in the data passed in.

Technical Description of the Problem

I want my function to apply another function to every 1D array along the last axis of my N-dimensional array, where N <= 4. So I need an efficient loop over all dimensions except the last one, without knowing the number of dimensions beforehand.
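A minimal sketch of the general pattern, assuming the applied function returns one scalar per profile (the helper name `apply_over_leading_axes` is hypothetical, not from the question): both the output shape and the loop indices are derived from `arr.shape[:-1]`, so nothing depends on `arr.ndim`.

```python
import numpy as np


def apply_over_leading_axes(func, arr):
    """Apply func to every 1D slice arr[i, j, ..., :], whatever arr.ndim is.

    Hypothetical helper; assumes func maps a 1D array to a single scalar.
    """
    out = np.empty(arr.shape[:-1])  # one result per profile (0-d if arr is 1D)
    for idx in np.ndindex(*arr.shape[:-1]):  # idx is a tuple of length arr.ndim - 1
        out[idx] = func(arr[idx])  # arr[idx] is the 1D slice along the last axis
    return out


data = np.arange(2 * 3 * 4 * 5, dtype=float).reshape(2, 3, 4, 5)
sums = apply_over_leading_axes(np.sum, data)  # shape (2, 3, 4)
```

This is still a Python-level loop, so it should not be expected to be faster than a hand-written `np.ndindex` loop; its only advantage is that the same code handles 1D to 4D input.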

Why I Open a New Question

I am aware of several questions (e.g. "iterating over some dimensions of a ndarray", "Iterating over the last dimensions of a numpy array", "Iterating over 3D numpy using one dimension as iterator remaining dimensions in the loop", "Iterating over a numpy matrix with unknown dimension") that ask how to iterate over a specific dimension or how to iterate over an array with an unknown number of dimensions. As far as I know, the combination of these two problems is new. With numpy.nditer, for example, I haven't found a way to exclude just the last dimension regardless of how many dimensions are left.
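One way to make `numpy.nditer` skip the last axis, sketched under the assumption that the applied function returns a scalar (`apply_via_nditer` is a hypothetical name): iterate over the proxy view `arr[..., 0]`, which has exactly the leading axes, and use the iterator's `multi_index` to slice the full profile out of `arr`.

```python
import numpy as np


def apply_via_nditer(func, arr):
    # Hypothetical helper (not from the question); assumes func returns a scalar.
    out = np.empty(arr.shape[:-1])
    # arr[..., 0] is a view with exactly the leading axes of arr, so nditer
    # visits one element per profile and never steps along the last axis.
    it = np.nditer(arr[..., 0], flags=["multi_index"])
    for _ in it:
        idx = it.multi_index  # tuple of length arr.ndim - 1
        out[idx] = func(arr[idx])  # the full 1D profile at that position
    return out


field = np.arange(2 * 3 * 4 * 6, dtype=float).reshape(2, 3, 4, 6)
maxima = apply_via_nditer(np.max, field)  # shape (2, 3, 4)
```

Like any per-profile loop, this still calls the function once per profile from Python, so it solves the bookkeeping problem rather than the speed problem.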

EDIT

Here is my attempt at a minimal, reproducible example:

import numpy as np

def outer_function(array, *args):
    """
    Array can be 1D, 2D, 3D, or 4D. In every case, inner_function
    should be applied to all 1D arrays along the last axis.
    """
    # Unpythonic if-else solution
    if array.ndim == 1:
        return inner_function(array)
    elif array.ndim == 2:
        return [inner_function(array[i,:]) for i in range(array.shape[0])]
    elif array.ndim == 3:
        return [[inner_function(array[i,j,:]) for j in range(array.shape[1])] for i in range(array.shape[0])]
    elif array.ndim == 4:
        return [[[inner_function(array[i,j,k,:]) for k in range(array.shape[2])] for j in range(array.shape[1])] for i in range(array.shape[0])]
    else:
        return -1

def inner_function(array_1d):
    return np.interp(2, np.arange(array_1d.shape[0]), array_1d), np.sum(array_1d)

Please assume that the actual inner_function cannot be modified to work on multi-dimensional arrays; it only accepts 1D arrays.

end of edit
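For this minimal example, NumPy's built-in `np.apply_along_axis` does what the if/else ladder does, for any number of dimensions. A sketch reusing `inner_function` from above: because it returns two values, the result keeps the leading shape of the input and gains a trailing axis of length 2.

```python
import numpy as np


def inner_function(array_1d):
    # Copied from the question: value at index 2 (via interpolation) and the sum
    return np.interp(2, np.arange(array_1d.shape[0]), array_1d), np.sum(array_1d)


def outer_function(array):
    # Calls inner_function on every 1D slice along the last axis (axis=-1).
    # The returned pair becomes a trailing axis of length 2 in the output.
    return np.apply_along_axis(inner_function, -1, array)


cube = np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4)
out = outer_function(cube)  # shape (2, 3, 2)
```

The output is a single array rather than nested lists. `np.apply_along_axis` still loops in Python internally, so it removes the if/else but not the per-profile call overhead.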

In case it helps, here is the structure of the code I have (or want to have):

def tropopause_ds(ds):
    """
    Wrap the tropopause profile calculation. The vertical coordinate has to be the last dimension.
    """
    
    t = ds.t.values # numpy ndarray
    p_profile = ds.plev.values # 1d numpy ndarray

    len_t = ds.time.size
    len_lon = ds.lon.size
    len_lat = ds.lat.size
    nlevs = ds.plev.size

    ttp = np.empty([len_t, len_lon, len_lat])
    ptp = np.empty([len_t, len_lon, len_lat])
    ztp = np.empty([len_t, len_lon, len_lat])
    dztp = np.empty([len_t, len_lon, len_lat, nlevs])

    # Approach 1: use numpy.ndindex - doesn't work in a list comprehension, slow
    for idx in np.ndindex(*t.shape[:-1]):
        ttp[idx], ptp[idx], ztp[idx], dztp[idx] = tropopause_profile(t[idx], p_profile)

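    # Approach 2: use nested list comprehensions - the loop depth is hard-coded, so this
    # doesn't work for a different number of dimensions. Each element of the nested
    # list is one (ttp, ptp, ztp, dztp) tuple, which still has to be unpacked.
    results = [[[tropopause_profile(t[i,j,k,:], p_profile) for k in range(len_lat)]
                for j in range(len_lon)] for i in range(len_t)]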

    return ttp, ptp, ztp, dztp

with the inner function's structure as follows:

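def tropopause_profile(t_profile, p_profile):
    # ... tropopause detection on this single profile (omitted) sets
    # tropopause_found, ttp, ptp, ztp and dztp ...
    if tropopause_found:
        return ttp, ptp, ztp, dztp
    return np.nan, np.nan, np.nan, np.nan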
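If the goal is four separate arrays for any number of leading dimensions, `np.vectorize` with a `signature` is one option. A sketch, assuming the inner function returns three scalars plus one profile as long as the input; `tropopause_profile_stub` is a hypothetical placeholder (it just picks the coldest level), not a real tropopause algorithm. The signature `(n),(n)->(),(),(),(n)` declares the last axis of `t` and `p_profile` as the core (vertical) axis, and every other axis is looped over.

```python
import numpy as np


def tropopause_profile_stub(t_profile, p_profile):
    # HYPOTHETICAL placeholder for the real tropopause_profile: it just picks the
    # coldest level so the wrapper can be tested; it is not a tropopause algorithm.
    k = int(np.argmin(t_profile))
    dztp = np.gradient(t_profile)  # placeholder per-level output, same length as the profile
    return float(t_profile[k]), float(p_profile[k]), float(k), dztp


# "(n),(n)->(),(),(),(n)": the last axis of both inputs is the vertical (core) axis;
# every other axis of t is looped over. Three scalar outputs and one profile output.
tropopause_nd = np.vectorize(
    tropopause_profile_stub,
    signature="(n),(n)->(),(),(),(n)",
    otypes=[np.float64] * 4,
)

rng = np.random.default_rng(0)
t = 200.0 + 80.0 * rng.random((2, 3, 4, 5))  # synthetic (time, lon, lat, plev) field
p = np.linspace(1000.0, 100.0, 5)  # synthetic pressure levels, broadcast against t
ttp, ptp, ztp, dztp = tropopause_nd(t, p)  # shapes (2, 3, 4) x3 and (2, 3, 4, 5)
```

The NumPy documentation states that `np.vectorize` is provided for convenience, not performance (it is essentially a for loop), so it should not be expected to beat the timings below. `otypes` is fixed so the output dtype does not depend on what the first profile happens to return.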

I have already tried several options. The test data in the timed cases had the shape (2, 360, 180, 105):

  • xarray's apply_ufunc, which by default seems to pass the whole array to the function. My inner function, however, expects a 1D array and would be hard to rewrite to work on multi-dimensional data.
  • nested list comprehensions work and seem to be quite fast, but the number of loop levels is hard-coded, so they would give an error when, e.g., the time dimension has only one value (timed: 8.53 s ± 11.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))
  • numpy's nditer works in a standard for loop, which I sped up with a list comprehension. However, with this approach the function doesn't return four ndarrays but a list whose elements are the four return values for each index (timed with the list comprehension: 1 min 4 s ± 740 ms per loop (mean ± std. dev. of 7 runs, 1 loop each))
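The "list of tuples instead of four arrays" problem can be handled after the loop. A sketch, assuming every call returns the same number of values with consistent shapes; `profile_stats` and `apply_and_unpack` are hypothetical names, with `profile_stats` standing in for the real inner function. `zip(*results)` regroups the per-profile tuples into one sequence per output, and each sequence is reshaped to the leading shape of the input.

```python
import numpy as np


def profile_stats(profile):
    # HYPOTHETICAL stand-in: two scalars and one per-level array, like ttp/ptp/dztp
    return profile.min(), profile.max(), np.diff(profile, prepend=profile[0])


def apply_and_unpack(func, arr):
    lead_shape = arr.shape[:-1]
    # One tuple of outputs per profile, in C order of the leading indices
    results = [func(arr[idx]) for idx in np.ndindex(*lead_shape)]
    outputs = []
    for column in zip(*results):  # regroup: all first outputs, all second outputs, ...
        stacked = np.asarray(column)  # shape (n_profiles,) or (n_profiles, k)
        outputs.append(stacked.reshape(lead_shape + stacked.shape[1:]))
    return tuple(outputs)


arr = np.arange(24.0).reshape(2, 3, 4)
mins, maxs, diffs = apply_and_unpack(profile_stats, arr)  # (2, 3), (2, 3), (2, 3, 4)
```

Because `np.ndindex` yields indices in C order, a plain C-order `reshape` puts every result back at the position of its profile.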

An ugly way to solve this would be to check how many dimensions my data has and then use an if/else to select the matching number of nested list comprehensions, but I hope Python has a smoother way. The order of the dimensions can easily be changed if that would help. I ran the code on a JupyterHub server with 2 cores and 10 GB of memory.
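Since the dimension order can be changed, a common first step is to move the vertical axis to the end. A sketch with `np.moveaxis`, assuming (hypothetically) that the data arrives as `(time, plev, lat, lon)`:

```python
import numpy as np

# HYPOTHETICAL input layout (time, plev, lat, lon), filled with dummy values
t_tplatlon = np.arange(2 * 5 * 3 * 4, dtype=float).reshape(2, 5, 3, 4)

# Move the pressure axis (axis 1) to the end; np.moveaxis returns a view
t_last = np.moveaxis(t_tplatlon, 1, -1)  # shape (2, 3, 4, 5)

# Each t_last[i, j, k] is the vertical profile t_tplatlon[i, :, j, k]
profile = t_last[1, 2, 3]

# Optional: a C-contiguous copy, so a later reshape(-1, nlev) can be a view
t_last = np.ascontiguousarray(t_last)
```

If the data stays in xarray, calling `transpose` on the DataArray with `plev` as the last dimension before taking `.values` achieves the same.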

I've used @hpaulj's reshape approach several times: reshaping the array to 2D means a single loop can iterate over the whole array one 1D slice at a time.
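The reshape approach relies on `reshape` returning a view, so that writing into the 2D alias fills the original array. A small sketch of when that holds, with made-up shapes: for a freshly allocated (C-contiguous) array it does, but for a transposed array NumPy has to copy, and writes to the copy would not reach the original.

```python
import numpy as np

# Freshly allocated, C-contiguous: reshape(-1, k) is a view
result = np.zeros((2, 3, 4, 2))
res = result.reshape(-1, result.shape[-1])  # shape (24, 2)
res[5] = [1.0, 2.0]  # row 5 is result[0, 1, 1] in C order
view_ok = np.shares_memory(result, res)

# Transposed, not C-contiguous: the same kind of reshape has to copy
tr = np.zeros((2, 3, 4)).transpose(2, 1, 0)  # shape (4, 3, 2)
flat = tr.reshape(-1, tr.shape[-1])  # shape (12, 2), a copy
copy_made = not np.shares_memory(tr, flat)
```

In the function below, `result` comes straight from `np.zeros`, so `res` is a view; `work` may be a copy if `arr` is not contiguous, which is harmless because it is only read.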

I simplified the function and data so there is something to test.

import numpy as np

arr = np.arange( 2*3*3*2*6 ).reshape( 2,3,3,2,6 )

def inner_function(array_1d):
    return np.array( [ array_1d.sum(), array_1d.mean() ])
    # return np.array( [np.interp(2, np.arange(array_1d.shape[0]), array_1d), np.sum(array_1d) ])

def outer_function( arr, *args ):
    res_shape = list( arr.shape )
    res_shape[ -1 ] = 2

    result = np.zeros( tuple( res_shape ) )  # same leading shape as arr, last axis holds the 2 outputs

    # Reshape arr and result to 2D arrays. result is freshly allocated and C-contiguous,
    # so res is a view and writing into res fills result. work is a view if arr is
    # contiguous (otherwise a copy, which is fine because it is only read).
    work = arr.reshape( -1, arr.shape[-1] )
    res = result.reshape( -1, result.shape[-1] )

    for ix, w1d in enumerate( work ):  # Loop through all 1D slices
        res[ix] = inner_function( w1d )
    return result 

outer_function( arr )

The result (truncated here) is:

array([[[[[  15. ,    2.5],
          [  51. ,    8.5]],

         [[  87. ,   14.5],
          [ 123. ,   20.5]],

         ...

         [[1167. ,  194.5],
          [1203. ,  200.5]],

         [[1239. ,  206.5],
          [1275. ,  212.5]]]]])

I'm sure this can be optimised further, both as it stands and to take account of the actual functions required by the application.
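One possible generalisation, offered as a sketch: infer the size of the output from the first profile instead of hard-coding `res_shape[-1] = 2`. It assumes every call returns the same number of numeric values and that there is at least one profile; `apply_to_profiles` is a hypothetical name.

```python
import numpy as np


def apply_to_profiles(func, arr):
    """Apply func to each 1D slice along the last axis of arr (any ndim >= 1).

    Hypothetical helper; assumes func always returns the same-shaped numeric output.
    """
    profiles = arr.reshape(-1, arr.shape[-1])  # (n_profiles, n_levels); only read
    first = np.asarray(func(profiles[0]), dtype=float)  # output shape per profile
    out = np.empty((profiles.shape[0],) + first.shape)  # float64, C-contiguous
    out[0] = first
    for i in range(1, profiles.shape[0]):
        out[i] = func(profiles[i])
    # Restore the leading dimensions of arr, followed by the per-profile output shape
    return out.reshape(arr.shape[:-1] + first.shape)


def inner_function(array_1d):
    # Same simplified inner function as in the answer above
    return np.array([array_1d.sum(), array_1d.mean()])


arr = np.arange(2 * 3 * 3 * 2 * 6).reshape(2, 3, 3, 2, 6)
res = apply_to_profiles(inner_function, arr)  # shape (2, 3, 3, 2, 2)
```

A 1D input is treated as a single profile, so the result is just that profile's output.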
