
Is there a way to speed up looping over numpy.where?

Imagine you have a segmentation map in which each object is identified by a unique index, e.g. looking similar to this:

[Image: example segmentation map with uniquely labeled objects]

For each object, I would like to save which pixels it covers, but so far I could only come up with a standard for loop. Unfortunately, for larger images with thousands of individual objects this is very slow, at least on my real data. Can I somehow speed this up?

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from skimage.draw import random_shapes


# please ignore that this does not always produce 20 objects each with a
# unique color. it is simply a quick way to produce data that is similar to
# my problem that can also be visualized.
segmap, labels = random_shapes(
    # on scikit-image >= 0.19, pass channel_axis=None instead of multichannel=False
    (100, 100), 20, min_size=6, max_size=20, multichannel=False,
    intensity_range=(0, 20), num_trials=100,
)
segmap = np.ma.masked_where(segmap == 255, segmap)

object_idxs = np.unique(segmap)[:-1]
objects = np.empty(object_idxs.size, dtype=[('idx', 'i4'), ('pixels', 'O')])

# important bit here:

# this I can vectorize
objects['idx'] = object_idxs
# but this I cannot. and it takes forever.
for i in range(object_idxs.size):
    # compare against the object's label, not the loop counter
    objects[i]['pixels'] = np.where(segmap == object_idxs[i])

# just plotting here
fig, ax = plt.subplots(constrained_layout=True)
image = ax.imshow(
    segmap, cmap='tab20', norm=mpl.colors.Normalize(vmin=0, vmax=20)
)
fig.colorbar(image)
fig.show()

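Using np.where in a loop is algorithmically inefficient: each call scans the whole image, so the total time complexity is O(s·n·m), where s = object_idxs.size and (n, m) = segmap.shape. Since every pixel only needs to be visited once, this operation can be done in O(n·m) (the sort-based solution below adds at most a log factor for the sort, which is still far cheaper than s full scans of the image).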
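To make the complexity argument concrete, here is a minimal sketch of the single-pass idea: visit every pixel exactly once and append its coordinates to a per-label list, instead of rescanning the whole image once per label. The helper name group_pixels_single_pass and its background parameter are hypothetical, and the inner loop is plain Python, so this illustrates the O(n·m) behaviour rather than being fast in practice.

```python
from collections import defaultdict

import numpy as np


def group_pixels_single_pass(segmap, background):
    """Group pixel coordinates by label in one O(n*m) pass (hypothetical helper)."""
    rows = defaultdict(list)
    cols = defaultdict(list)
    # np.ndenumerate visits each pixel exactly once, in row-major order
    for (r, c), label in np.ndenumerate(segmap):
        if label == background:
            continue
        rows[label].append(r)
        cols[label].append(c)
    # Same (row_indices, col_indices) tuple layout that np.where returns
    return {label: (np.array(rows[label]), np.array(cols[label]))
            for label in sorted(rows)}


segmap = np.array([[0, 0, 255],
                   [1, 0, 255],
                   [1, 1, 2]])
groups = group_pixels_single_pass(segmap, background=255)
for label, (r, c) in groups.items():
    print(label, r.tolist(), c.tolist())
```

Because pixels are visited in row-major order, each object's coordinates come out in the same order as np.where(segmap == label).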

One NumPy solution is to first select all object pixel locations, then sort them by the object they belong to in segmap, and finally split them according to the number of pixels in each object. Here is the code:

# Work on the plain ndarray: on a masked array, np.max skips the masked
# background pixels (255) and would return the largest object label instead.
raw = np.ma.getdata(segmap)
background = np.max(raw)
mask = raw != background
pixelLabels = raw[mask]
uniqueObjects, counts = np.unique(pixelLabels, return_counts=True)
# A stable sort keeps each object's pixels in the row-major order np.where returns
ordering = np.argsort(pixelLabels, kind='stable')
rows, cols = np.where(mask)
indices = np.vstack([rows[ordering], cols[ordering]])
indicesPerObject = np.split(indices, counts.cumsum()[:-1], axis=1)

objects = np.empty(uniqueObjects.size, dtype=[('idx', 'i4'), ('pixels', 'O')])
objects['idx'] = uniqueObjects
for k in range(uniqueObjects.size):
    # Use `tuple(...)` to get the exact same type as the initial code here
    objects[k]['pixels'] = tuple(indicesPerObject[k])
# In case the conversion to tuple is not required, the loop can also be accelerated:
# objects['pixels'] = indicesPerObject
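The same sort-and-split idea can also be written on flattened indices, which avoids building the stacked two-row array: sort the flat positions of the object pixels by label, split them by the label counts, and convert each chunk back to (row, col) with np.unravel_index. This is a sketch on a small hand-made array (no scikit-image needed); split_by_label is a hypothetical name, and kind='stable' is what keeps each object's pixels in the same order np.where would return them.

```python
import numpy as np


def split_by_label(segmap, background):
    """Return {label: (rows, cols)} for every non-background label (hypothetical helper)."""
    flat = segmap.ravel()
    # Flat positions of all object pixels, in row-major order
    positions = np.flatnonzero(flat != background)
    labels = flat[positions]
    # Stable sort groups pixels by label without reordering pixels within a label
    order = np.argsort(labels, kind='stable')
    unique_labels, counts = np.unique(labels, return_counts=True)
    chunks = np.split(positions[order], np.cumsum(counts)[:-1])
    # Convert each chunk of flat positions back to (row_indices, col_indices)
    return {label: np.unravel_index(chunk, segmap.shape)
            for label, chunk in zip(unique_labels.tolist(), chunks)}


segmap = np.array([[0, 0, 255],
                   [1, 0, 255],
                   [1, 1, 2]])
pixels = split_by_label(segmap, background=255)

# Cross-check against the straightforward per-label np.where loop
for label, (rows, cols) in pixels.items():
    expected_rows, expected_cols = np.where(segmap == label)
    assert np.array_equal(rows, expected_rows)
    assert np.array_equal(cols, expected_cols)
print(sorted(pixels))  # [0, 1, 2]
```

The cost is one pass to find the object pixels plus one sort, independent of how many objects there are.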

It sounds like you want to see where any object is located. If all shapes are in one array (empty space is 0, object 1 consists of 1s, object 2 of 2s, and so on), you can create a mask showing which pixels (values in the matrix) are non-zero like this:

my_array != 0
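For example, on a small made-up label image where 0 is empty space (as this answer assumes), the mask marks every pixel that belongs to some object, and np.argwhere turns it into coordinates. Note that this gives the union of all objects, not a per-object breakdown; the labels of the selected pixels are still available via my_array[mask].

```python
import numpy as np

# Made-up label image: 0 = empty, 1/2/3 = object labels
my_array = np.array([[0, 1, 1],
                     [0, 0, 2],
                     [3, 0, 2]])

mask = my_array != 0          # True wherever any object is present
coords = np.argwhere(mask)    # one (row, col) pair per object pixel, row-major order
print(mask.astype(int))
print(coords.tolist())        # [[0, 1], [0, 2], [1, 2], [2, 0], [2, 2]]
print(my_array[mask])         # labels of those pixels: [1 1 2 3 2]
```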

The technical posts on this site are licensed under CC BY-SA 4.0. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.

 