[英]Numpy - How to remove trailing N*8 zeros
我有1d數組,我需要刪除所有8個零的尾隨塊。
[0,1,1,0,1,0,0,0, 0,0,0,0,0,0,0,0, 0,0,0,0,0,0,0,0]
->
[0,1,1,0,1,0,0,0]
a.shape[0] % 8 == 0
總是這樣,所以不用擔心。
有更好的方法嗎?
import numpy as np
P = 8
arr1 = np.random.randint(2,size=np.random.randint(5,10) * P)
arr2 = np.random.randint(1,size=np.random.randint(5,10) * P)
arr = np.concatenate((arr1, arr2))
indexes = []
arr = np.flip(arr).reshape(arr.shape[0] // P, P)
for i, f in enumerate(arr):
if (f == 0).all():
indexes.append(i)
else:
break
arr = np.delete(arr, indexes, axis=0)
arr = np.flip(arr.reshape(arr.shape[0] * P))
您可以使用視圖和np.argmax
來獲取最后一個非零元素,而無需分配更多空間來做到這一點:
index = arr.size - np.argmax(arr[::-1])
四舍五入到最接近的八位數很容易:
index = np.ceil(index / 8) * 8
現在砍掉其余部分:
arr = arr[:index]
還是單線:
arr = arr[:(arr.size - np.argmax(arr[::-1])) / 8) * 8]
此版本的時間為O(n)
,空間為O(1)
,因為它對所有內容(包括輸出)都使用相同的緩沖區。
這具有額外的優點,即使沒有尾隨零也可以正常工作。 雖然使用argmax
確實依賴於所有相同的元素。 如果不是這種情況,則需要首先計算一個掩碼,例如使用arr.astype(bool)
。
如果要使用原始方法,也可以向量化,盡管會增加一些開銷:
view = arr.reshape(-1, 8)
mask = view.any(axis = 1)
index = view.shape[0] - np.argmax(mask[::-1])
arr = arr[:index * 8]
有一個numpy函數幾乎可以完成您想要的np.trim_zeros
。 我們可以使用:
import numpy as np
def trim_mod(a, m=8):
t = np.trim_zeros(a, 'b')
return a[:len(a)-(len(a)-len(t))//m*m]
def test(a, t, m=8):
assert (len(a) - len(t)) % m == 0
assert len(t) < m or np.any(t[-m:])
assert not np.any(a[len(t):])
for _ in range(1000):
a = (np.random.random(np.random.randint(10, 100000))<0.002).astype(int)
m = np.random.randint(4, 20)
t = trim_mod(a, m)
test(a, t, m)
print("Looks correct")
打印:
Looks correct
它似乎在尾隨零的數量上呈線性比例:
但是絕對值感覺很慢(每個試驗的單位是毫秒),因此np.trim_zeros
可能只是一個python循環。
圖片代碼:
from timeit import timeit
A = (np.random.random(1000000)<0.02).astype(int)
m = 8
T = []
for last in range(1, 1000, 9):
A[-last:] = 0
A[-last] = 1
T.append(timeit(lambda: trim_mod(A, m), number=100)*10)
import pylab
pylab.plot(range(1, 1000, 9), T)
pylab.show()
低級方法:
import numba
@numba.njit
def trim8(a):
n=a.size-1
while n>=0 and a[n]==0 : n-=1
c= (n//8+1)*8
return a[:c]
一些測試:
In [194]: A[-1]=1 # best case
In [196]: %timeit trim_mod(A,8)
5.7 µs ± 323 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [197]: %timeit trim8(A)
714 ns ± 33.7 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
In [198]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
4.83 ms ± 479 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [202]: A[:]=0 #worst case
In [203]: %timeit trim_mod(A,8)
2.5 s ± 49.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
In [204]: %timeit trim8(A)
1.14 ms ± 71.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
In [205]: %timeit A[:(A.size - np.argmax(A[::-1]) // 8) * 8]
5.5 ms ± 950 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
它具有類似於trim_zeros
短路機制,但速度要快得多。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.