
Numpy - How to remove trailing N*8 zeros

I have a 1D array and need to remove all trailing blocks of 8 zeros.

[0,1,1,0,1,0,0,0, 0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0]
->
[0,1,1,0,1,0,0,0]

arr.shape[0] % 8 == 0 always holds, so there is no need to worry about partial blocks.

Is there a better way to do it?

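import numpy as np
P = 8
arr1 = np.random.randint(2, size=np.random.randint(5, 10) * P)  # random 0/1 blocks
arr2 = np.random.randint(1, size=np.random.randint(5, 10) * P)  # all-zero blocks (randint(1) is always 0)
arr = np.concatenate((arr1, arr2))

indexes = []
# reverse the array and view it as rows of P, so the trailing blocks come first
arr = np.flip(arr).reshape(arr.shape[0] // P, P)

# collect the leading all-zero rows (the trailing blocks of the original array)
for i, f in enumerate(arr):
    if (f == 0).all():
        indexes.append(i)
    else:
        break

arr = np.delete(arr, indexes, axis=0)
# flatten again and undo the reversal
arr = np.flip(arr.reshape(arr.shape[0] * P))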

You can do it without allocating more space by using views and np.argmax to get the last nonzero element:

index = arr.size - np.argmax(arr[::-1])

Rounding up to the nearest multiple of eight is easy:

index = int(np.ceil(index / 8)) * 8  # np.ceil returns a float, which cannot be used as a slice index

Now chop off the rest:

arr = arr[:index]

Or as a one-liner:

arr = arr[:(arr.size - np.argmax(arr[::-1]) + 7) // 8 * 8]

This version is O(n) in time and O(1) in space: arr[::-1] and arr[:index] are views, so it reuses the same buffer for everything (including the output).

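This has the additional advantage that it works correctly even if there are no trailing zeros. Using argmax does rely on all the nonzero elements having the same positive value (e.g. a 0/1 array), though. If that is not the case, you will need to compute a mask first, e.g. with arr.astype(bool). It also does not handle an array that is entirely zero: argmax then returns 0, so nothing is trimmed.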
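Putting these steps together, here is a minimal sketch (the function name trim_trailing_zero_blocks and the block parameter are illustrative, not from the answer). It applies the boolean mask suggested above and special-cases the all-zero array. Note that arr.astype(bool) allocates an O(n) copy, so for an array that is already 0/1 you could pass arr straight to argmax and keep the O(1) extra space.

```python
import numpy as np


def trim_trailing_zero_blocks(arr, block=8):
    """Drop trailing all-zero blocks of length `block`.

    Assumes arr.size is a multiple of `block`, as in the question.
    """
    mask = arr.astype(bool)        # True wherever arr is nonzero (copies arr)
    if not mask.any():             # all zeros: argmax would return 0, so handle it here
        return arr[:0]
    # argmax on the reversed view returns the first True counted from the end,
    # so `end` is one past the last nonzero element
    end = arr.size - np.argmax(mask[::-1])
    # round up to a multiple of `block` with integer arithmetic
    end = (end + block - 1) // block * block
    return arr[:end]               # a view into arr, no copy


a = np.array([0, 1, 1, 0, 1, 0, 0, 0] + [0] * 16)
print(trim_trailing_zero_blocks(a))  # [0 1 1 0 1 0 0 0]
```

The returned slice shares memory with the input, which you can confirm with np.shares_memory.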

If you want to use your original approach, you could vectorize that too, although there will be a bit more overhead:

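view = arr.reshape(-1, 8)      # one row per block of 8 (a view, no copy)
mask = view.any(axis=1)        # True for blocks that contain a nonzero
index = view.shape[0] - np.argmax(mask[::-1])  # number of blocks to keep
arr = arr[:index * 8]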
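The block-wise version can be wrapped the same way. This sketch (the name trim_zero_blocks_vectorized and the block parameter are hypothetical) adds the same all-zero guard, since argmax of an all-False mask is also 0. No rounding step is needed because the mask already works in whole blocks.

```python
import numpy as np


def trim_zero_blocks_vectorized(arr, block=8):
    """Drop trailing all-zero rows of arr.reshape(-1, block)."""
    view = arr.reshape(-1, block)      # no copy for a contiguous 1-D array
    mask = view.any(axis=1)            # True for each block holding a nonzero
    if not mask.any():                 # every block is zero
        return arr[:0]
    n_blocks = view.shape[0] - np.argmax(mask[::-1])  # blocks to keep
    return arr[:n_blocks * block]


a = np.array([0, 1, 1, 0, 1, 0, 0, 0] + [0] * 16)
print(trim_zero_blocks_vectorized(a))  # [0 1 1 0 1 0 0 0]
```

Here the mask has only arr.size // block entries, so the argmax itself is cheap; the cost is dominated by view.any.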

There is a NumPy function that does almost what you want, np.trim_zeros. We can use that:

import numpy as np

def trim_mod(a, m=8):
    t = np.trim_zeros(a, 'b')
    return a[:len(a)-(len(a)-len(t))//m*m]
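To see the arithmetic in trim_mod, take the question's example: np.trim_zeros(a, 'b') leaves 5 elements, so 24 - 5 = 19 zeros were stripped; 19 // 8 * 8 = 16 of them form whole blocks, so the slice keeps 24 - 16 = 8 elements. A self-contained check, repeating trim_mod with the steps given names (n_trailing is just an illustrative variable):

```python
import numpy as np


def trim_mod(a, m=8):
    t = np.trim_zeros(a, 'b')                # strip *all* trailing zeros ('b' = back)
    n_trailing = len(a) - len(t)             # how many zeros were stripped
    return a[:len(a) - n_trailing // m * m]  # give back the partial block


a = np.array([0, 1, 1, 0, 1, 0, 0, 0] + [0] * 16)
print(len(np.trim_zeros(a, 'b')))  # 5
print(trim_mod(a))                 # [0 1 1 0 1 0 0 0]
```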

def test(a, t, m=8):
    assert (len(a) - len(t)) % m == 0
    assert len(t) < m or np.any(t[-m:])
    assert not np.any(a[len(t):])

for _ in range(1000):
    a = (np.random.random(np.random.randint(10, 100000))<0.002).astype(int)
    m = np.random.randint(4, 20)
    t = trim_mod(a, m)
    test(a, t, m)

print("Looks correct")

Prints:

Looks correct

It seems to scale linearly in the number of trailing zeros:

[Plot: time per trial (ms) vs. number of trailing zeros]

But it feels rather slow in absolute terms (the units are ms per trial), so perhaps np.trim_zeros is just a Python loop.

Code for the picture:

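from timeit import timeit

import matplotlib.pyplot as plt
import numpy as np

A = (np.random.random(1000000) < 0.02).astype(int)
m = 8
T = []
for last in range(1, 1000, 9):
    A[-last:] = 0  # zero the tail...
    A[-last] = 1   # ...then set one element, leaving last - 1 trailing zeros
    # 100 calls: total seconds * 10 = ms per call
    T.append(timeit(lambda: trim_mod(A, m), number=100) * 10)

plt.plot(range(1, 1000, 9), T)
plt.show()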

A low-level approach using Numba:

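import numba

@numba.njit
def trim8(a):
    n = a.size - 1
    # walk back from the end to the last nonzero element
    while n >= 0 and a[n] == 0:
        n -= 1
    # keep every block up to and including the one containing index n
    c = (n // 8 + 1) * 8
    return a[:c]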
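Since trim8 only uses basic Python constructs, its logic can be sanity-checked without Numba. The sketch below (trim8_py and its block parameter are hypothetical) is the same scan in plain Python; it is slow on long zero tails, which is exactly what @numba.njit addresses. Python's floor division is what makes the all-zero case work: the loop ends with n = -1, and -1 // 8 + 1 == 0.

```python
import numpy as np


def trim8_py(a, block=8):
    """Same backward scan as trim8, without Numba (slow for long zero tails)."""
    n = a.size - 1
    while n >= 0 and a[n] == 0:   # walk back to the last nonzero element
        n -= 1
    # n // block + 1 blocks reach index n; an all-zero array gives n = -1,
    # and -1 // block + 1 == 0, so nothing is kept
    return a[:(n // block + 1) * block]


print(trim8_py(np.array([0, 1, 1, 0, 1, 0, 0, 0] + [0] * 16)))  # [0 1 1 0 1 0 0 0]
print(trim8_py(np.zeros(16, dtype=int)).size)                   # 0
```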

Some timings:

In [194]: A[-1]=1  # best case

In [196]: %timeit trim_mod(A,8)
5.7 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

In [197]: %timeit trim8(A)
714 ns ± 33.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)

In [198]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
4.83 ms ± 479 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

In [202]: A[:]=0 #worst case

In [203]: %timeit trim_mod(A,8)
2.5 s ± 49.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [204]: %timeit trim8(A)
1.14 ms ± 71.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

In [205]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
5.5 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Like trim_zeros, it short-circuits (it stops scanning at the last nonzero element), but it is much faster.
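For comparison, a pure-NumPy variant that is not from any of the answers above swaps in np.flatnonzero to locate the last nonzero element. It does not short-circuit (it scans the whole array and allocates one index per nonzero element), so no speed claim is made here, but it needs no Numba, handles the all-zero case, and works for arbitrary nonzero values, including negatives. The name trim_blocks_flatnonzero is hypothetical.

```python
import numpy as np


def trim_blocks_flatnonzero(a, block=8):
    """Trim trailing all-zero blocks using the index of the last nonzero."""
    nz = np.flatnonzero(a)              # indices of all nonzero elements
    end = nz[-1] + 1 if nz.size else 0  # one past the last nonzero (0 if none)
    return a[:(end + block - 1) // block * block]  # round up to a whole block


print(trim_blocks_flatnonzero(np.array([0, 1, 1, 0, 1, 0, 0, 0] + [0] * 16)))  # [0 1 1 0 1 0 0 0]
```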
