
How do I write this nested for loop as a list comprehension?

I am working with a 4D dataset and have a nested for loop (four loops). The loop works, but it takes a while to run: about 5 minutes. I am trying to rewrite it as a list comprehension instead, but I am getting confused about how to do this with my nested loops:

import numpy as np

data = np.random.rand(12, 27, 282, 375)

stdev_data = np.std(data, axis=1)

## nested for loop 

count = []

for i in range(data.shape[0]):
    for j in range(data.shape[1]):
        for lat in range(data.shape[2]):
            for lon in range(data.shape[3]):
                count.append((data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]).sum(axis=0))

reshape_counts = np.reshape(count, data.shape)

This is my attempt at the list comprehension:

i, j, lat, lon = data.shape[0], data.shape[1], data.shape[2], data.shape[3]
print(i, j, lat, lon)

test_list = [[(data < -1.282 * stdev_data).sum(axis=0) for lon in lat] for j in i]

I get an error saying 'int' object is not iterable. How do I rewrite my nested for loop as a list comprehension to speed up the process?

Given that you are using NumPy, I suggest you take advantage of the fact that its loops are written in C and are often optimized. You will still end up stepping through the data, but much faster. This approach is called vectorization.
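To make this concrete, here is a minimal sketch on a small random array (the shape and seed are assumptions chosen only to keep the demo fast) showing that the element-by-element Python loop and a single vectorized comparison produce the same boolean mask:

```python
import numpy as np

# Small stand-in for the (12, 27, 282, 375) array; shape and seed are arbitrary
rng = np.random.default_rng(0)
small = rng.standard_normal((2, 3, 4, 5))
small_std = np.std(small, axis=1)  # shape (2, 4, 5)

# Python-level loop: one element per iteration
looped = np.empty(small.shape, dtype=bool)
for i in range(small.shape[0]):
    for j in range(small.shape[1]):
        for lat in range(small.shape[2]):
            for lon in range(small.shape[3]):
                looped[i, j, lat, lon] = small[i, j, lat, lon] < -1.282 * small_std[i, lat, lon]

# Vectorized: one comparison over the whole array; the looping happens in C
vectorized = small < -1.282 * small_std[:, None, :, :]

print(np.array_equal(looped, vectorized))  # True
```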

In this case, you want to build a boolean mask, which arguably simplifies the operation. Keep in mind that the .sum() call in your expression is a red herring: you are summing a single boolean scalar, which always gives you zero or one.
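A quick illustration of the red herring, using made-up scalar values: comparing two scalars yields a 0-d NumPy boolean, and calling .sum() on it only converts it to 0 or 1.

```python
import numpy as np

sigma = np.float64(1.0)           # hypothetical standard deviation
flag = np.float64(-2.0) < -1.282 * sigma   # a single NumPy boolean, not an array
total = flag.sum()                # "summing" one boolean just casts it to an integer

print(bool(flag), int(total))     # True 1
print(int((np.float64(0.5) < -1.282 * sigma).sum()))  # 0
```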

Here is how you would find the points that are smaller than -1.282 times the standard deviation taken along the second dimension:

result = data < -1.282 * stdev_data[:, None, ...]

Alternatively, you could do

result = data < -1.282 * stdev_data.reshape(stdev_data.shape[0], 1, *stdev_data.shape[1:])

or

result = data < -1.282 * np.reshape(stdev_data, stdev_data.shape[:1] + (1,) + stdev_data.shape[1:])
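The three spellings above all turn stdev_data from shape (12, 282, 375) into (12, 1, 282, 375) so that it broadcasts against data. A small sketch (reduced shape assumed for speed) confirms they agree:

```python
import numpy as np

rng = np.random.default_rng(1)
data = rng.standard_normal((3, 4, 5, 6))   # small stand-in for (12, 27, 282, 375)
stdev_data = np.std(data, axis=1)          # shape (3, 5, 6)

a = stdev_data[:, None, ...]
b = stdev_data.reshape(stdev_data.shape[0], 1, *stdev_data.shape[1:])
c = np.reshape(stdev_data, stdev_data.shape[:1] + (1,) + stdev_data.shape[1:])

print(a.shape, b.shape, c.shape)  # (3, 1, 5, 6) (3, 1, 5, 6) (3, 1, 5, 6)
print(np.array_equal(a, b) and np.array_equal(b, c))  # True

# (3, 4, 5, 6) broadcast against (3, 1, 5, 6): the size-1 axis is stretched over j
result = data < -1.282 * a
print(result.shape)  # (3, 4, 5, 6)
```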

An even easier solution would be to pass keepdims=True to np.std from the very beginning:

result = data < -1.282 * np.std(data, axis=1, keepdims=True)

keepdims=True ensures that the output of std has shape (12, 1, 282, 375) instead of just (12, 282, 375), so you don't need to re-insert the dimension yourself.
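A short check of the keepdims behavior, on a reduced shape assumed only for illustration: the kept axis has length 1, and the resulting mask matches the one built with manual None indexing.

```python
import numpy as np

rng = np.random.default_rng(2)
data = rng.standard_normal((3, 4, 5, 6))

std_keep = np.std(data, axis=1, keepdims=True)
std_flat = np.std(data, axis=1)
print(std_keep.shape, std_flat.shape)  # (3, 1, 5, 6) (3, 5, 6)

mask_keep = data < -1.282 * std_keep
mask_manual = data < -1.282 * std_flat[:, None, ...]
print(np.array_equal(mask_keep, mask_manual))  # True
```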

Now, if you actually want to compute the counts, as your question seems to imply, you can just sum the mask along the second dimension:

counts = result.sum(axis=1)
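A minimal sketch of the counting step, assuming standard-normal test data (so the condition is actually met sometimes) and smaller spatial dimensions: summing booleans along axis 1 counts, for each (i, lat, lon), how many of the 27 slices fall below the threshold.

```python
import numpy as np

rng = np.random.default_rng(3)
data = rng.standard_normal((3, 27, 5, 6))   # hypothetical test data

result = data < -1.282 * np.std(data, axis=1, keepdims=True)
counts = result.sum(axis=1)   # True counts as 1, False as 0

print(counts.shape)  # (3, 5, 6)

# Same value as counting by hand over j at one grid point
manual = sum(bool(result[0, j, 0, 0]) for j in range(data.shape[1]))
print(counts[0, 0, 0] == manual)  # True
```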

Finally, to answer your actual question exactly as stated: for loops translate directly into list comprehensions. In your case, that means four for clauses in the comprehension, in exactly the order you originally had them:

[data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]
    for i in range(data.shape[0])
        for j in range(data.shape[1])
            for lat in range(data.shape[2])
                for lon in range(data.shape[3])]

Since comprehensions are surrounded by brackets, you are free to write their contents on separate lines as I've done, although this is of course not required. Notice that the only real differences are that the contents of the append come first and there are no colons. Also, that red herring sum is gone.
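The comprehension returns a flat list in row-major order, so it can be reshaped back to data.shape. A sketch on a small array (shape assumed so the pure-Python version finishes instantly) shows it reproduces the vectorized mask; note that it still iterates in Python, so it should not be expected to approach the speed of the vectorized version.

```python
import numpy as np

rng = np.random.default_rng(4)
data = rng.standard_normal((2, 3, 4, 5))
stdev_data = np.std(data, axis=1)

flat = [data[i, j, lat, lon] < -1.282 * stdev_data[i, lat, lon]
        for i in range(data.shape[0])
        for j in range(data.shape[1])
        for lat in range(data.shape[2])
        for lon in range(data.shape[3])]

# Loop order i, j, lat, lon matches C order, so reshape restores the 4D layout
mask_from_list = np.reshape(flat, data.shape)
mask_vectorized = data < -1.282 * stdev_data[:, None, ...]
print(np.array_equal(mask_from_list, mask_vectorized))  # True
```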

I don't think the sum does anything except convert False to 0 and True to 1, since you are just comparing two numbers to each other. I think this does the same thing. I couldn't find a way to get rid of the last loop; if you really need it to be faster, joblib or numba might help, but I don't use them much, so I'm not sure:

count = np.empty(data.shape)
for j in range(data.shape[1]):
    count[:,j,...] = (data[:,j,...] < -1.282*stdev_data).astype(np.int32)
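For what it's worth, the remaining loop over j can also be removed by broadcasting, as described in the first answer. This sketch (small random array assumed) checks that the single-loop version and the fully broadcast version agree:

```python
import numpy as np

rng = np.random.default_rng(5)
data = rng.standard_normal((3, 4, 5, 6))
stdev_data = np.std(data, axis=1)

# Single loop over j, as above
count = np.empty(data.shape)
for j in range(data.shape[1]):
    count[:, j, ...] = (data[:, j, ...] < -1.282 * stdev_data).astype(np.int32)

# No loop: insert a length-1 axis so stdev_data broadcasts over j
full = (data < -1.282 * stdev_data[:, None, ...]).astype(np.int32)
print(np.array_equal(count, full))  # True
```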

But also, the standard deviation can't be negative, so -1.282 times it is never positive. Since all your data is between 0 and 1, nothing will satisfy the condition above and every count will be zero. I'd recommend double-checking everything.
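A small sketch of that observation, with test arrays assumed for illustration: uniform data in [0, 1) (like np.random.rand) never falls below a threshold that is at most zero, while zero-centred normal data does produce hits.

```python
import numpy as np

rng = np.random.default_rng(6)
uniform = rng.random((3, 27, 5, 6))            # values in [0, 1), like np.random.rand
normal = rng.standard_normal((3, 27, 5, 6))    # values centred on 0

def below_threshold(a):
    # Same test as in the question: a < -1.282 * std taken along axis 1
    return a < -1.282 * np.std(a, axis=1, keepdims=True)

# Threshold is <= 0 and every uniform value is >= 0, so nothing qualifies
print(below_threshold(uniform).sum())  # 0
print(below_threshold(normal).sum())
```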
