[英]Get Indices To Split NumPy Array
假設我有一個 NumPy 數組:
x = np.array([3, 9, 2, 1, 5, 4, 7, 7, 8, 6])
如果我總結這個數組,我得到52
。 我需要的是一種將它從左到右分成大約n
塊的方法,其中n
由用戶選擇。 本質上,分裂以貪婪的方式發生。 因此,對於某些數量的塊n
,前n - 1
塊必須每個總和至少為52/n
並且它們必須是從左到右的連續索引。
因此,如果n = 2
,那么第一個塊將由前 7 個元素組成:
chunk[0] = x[:7] # [3, 9, 2, 1, 5, 4, 7], sum = 31
chunk[1] = x[7:] # [7, 8, 6], sum = 21
請注意,第一個塊不會僅包含前 6 個元素,因為總和為24
,小於52/2 = 26
。 此外,請注意,只要滿足總和標准,每個塊中的元素數量就可以變化。 最后,最后一個塊不接近52/2 = 26
是完全可以的,因為其他塊可能需要更多。
但是,我需要的 output 是一個兩列數組,其中包含第一列中的開始索引和第二列中的(獨占)停止索引:
[[0, 7],
[7, 10]]
如果n = 4
,那么前 3 個塊需要每個總和至少為52/4 = 13
並且看起來像這樣:
chunk[0] = x[:3] # [3, 9, 2], sum = 14
chunk[1] = x[3:7] # [1, 5, 4], sum = 17
chunk[2] = x[7:9] # [7, 8], sum = 15
chunk[3] = x[9:] # [6], sum = 6
我需要的 output 是:
[[0, 3],
[3, 7],
[7, 9],
[9, 10]
因此,使用 for 循環的一種天真的方法可能是:
ranges = np.zeros((n_chunks, 2), np.int64)
ranges_idx = 0
range_start_idx = start
sum = 0
for i in range(x.shape[0]):
sum += x[i]
if sum > x.sum() / n_chunks:
ranges[ranges_idx, 0] = range_start_idx
ranges[ranges_idx, 1] = min(
i + 1, x.shape[0]
) # Exclusive stop index
# Reset and Update
range_start_idx = i + 1
ranges_idx += 1
sum = 0
# Handle final range outside of for loop
ranges[ranges_idx, 0] = range_start_idx
ranges[ranges_idx, 1] = x.shape[0]
if ranges_idx < n_chunks - 1:
left[ranges_idx:] = x.shape[0]
return ranges
我正在尋找更好的矢量化解決方案。
我在回答的類似問題中找到了靈感:
def func(x, n):
out = np.zeros((n, 2), np.int64)
cum_arr = x.cumsum() / x.sum()
idx = 1 + np.searchsorted(cum_arr, np.linspace(0, 1, n, endpoint=False)[1:])
out[1:, 0] = idx # Fill the first column with start indices
out[:-1, 1] = idx # Fill the second column with exclusive stop indices
out[-1, 1] = x.shape[0] # Handle the stop index for the final chunk
return out
更新
為了涵蓋病理情況,我們需要更精確一些,並執行以下操作:
def func(x, n, truncate=False):
out = np.zeros((n_chunks, 2), np.int64)
cum_arr = x.cumsum() / x.sum()
idx = 1 + np.searchsorted(cum_arr, np.linspace(0, 1, n, endpoint=False)[1:])
out[1:, 0] = idx # Fill the first column with start indices
out[:-1, 1] = idx # Fill the second column with exclusive stop indices
out[-1, 1] = x.shape[0] # Handle the stop index for the final chunk
# Handle pathological case
diff_idx = np.diff(idx)
if np.any(diff_idx == 0):
row_truncation_idx = np.argmin(diff_idx) + 2
out[row_truncation_idx:, 0] = x.shape[0]
out[row_truncation_idx-1:, 1] = x.shape[0]
if truncate:
out = out[:row_truncation_idx]
return out
這是一個不會遍歷所有元素的解決方案:
def fun2(array, n):
min_sum = np.sum(array) / n
cumsum = np.cumsum(array)
i = -1
count = min_sum
out = []
while i < len(array)-1:
j = np.searchsorted(cumsum, count)
out.append([i+1, j+1])
i = j
if i < len(array):
count = cumsum[i] + min_sum
out[-1][1] -= 1
return np.array(out)
對於這兩個測試用例,它會產生您預期的結果。 HTH
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.