簡體   English   中英

Python Numpy。 如果該元素位於一對指定元素之間,則刪除二維數組中的一個(或多個元素)

[英]Python Numpy. Delete an element (or elements) in a 2D array if said element is located between a pair of specified elements

我有一個 2D NumPy 陣列,專門用 1 和 0 填充。

a = [[0 0 0 0 1 0 0 0 1]
     [1 1 1 1 1 1 1 1 1]
     [1 1 1 1 1 1 1 1 1]
     [1 1 1 1 0 0 0 0 1]
     [1 1 1 1 1 1 1 1 1]
     [1 1 1 0 1 1 1 1 1]
     [1 1 1 1 1 1 0 0 1]
     [1 1 1 1 1 1 1 1 1]]

為了獲取 0 的位置,我使用了以下代碼:

new_array = np.transpose(np.nonzero(a==0))

正如預期的那樣,我得到以下結果,顯示了數組中 0 的位置

new_array = [[0 0]
             [0 1]
             [0 2]
             [0 3]
             [0 5]
             [0 6]
             [0 7]
             [3 4]
             [3 5]
             [3 6]
             [3 7]
             [5 3]
             [6 6]
             [6 7]]

現在我的問題來了:如果所說的組大於 2,有沒有辦法在水平組的開始和結束處獲取 0 的位置?

編輯:如果組要在一行的末尾完成並繼續在它下面的一個,它將計為 2 個單獨的組。

我的第一個想法是實現一個過程,如果它們位於 0 之間,將刪除 0,但我無法弄清楚如何做到這一點。

我希望“new_array” output 是:

new_array = [[0 0]
             [0 3]
             [0 5]
             [0 7]
             [3 4]
             [3 7]
             [5 3]
             [6 6]
             [6 7]]

先謝謝了!!

一種更容易遵循的可能解決方案是:

b = np.diff(a, prepend=1)  # prepend a column of 1s and detect
                           # jumps between adjacent columns (left to right)
y, x = np.where(b > 0)  # find positions of the jumps 0->1 (left to right)
# shift positive jumps to the left by 1 position while filling gaps with 0:
b[y, x - 1] = 1
b[y, x] = 0
new_array = list(zip(*np.where(b)))

另一個是:

new_array = list(zip(*np.where(np.diff(a, n=2, prepend=1, append=1) > 0)))

兩種解決方案都基於np.diff計算連續列之間的差異(當axis=-1用於 2D 數組時)。

另一種解決方案的一個缺陷是它報告了所有的零序列,無論它們的長度如何。 您預期的 output 也包含此類組,由 1 個或 2 個零組成,但我認為它不應該。

我的解決方案沒有上述缺陷。

處理相鄰相等元素組的優雅工具是itertools.groupby ,所以從:

import itertools

然后將您的預期結果生成為:

res = []
for rowIdx, row in enumerate(a):
    colIdx = 0  # Start column index
    for k, grp in itertools.groupby(row):
        vals = list(grp)        # Values in the group
        lgth = len(vals)        # Length of the group
        colIdx2 = colIdx + lgth - 1  # End column index
        if k == 0 and lgth > 2: # Record this group
            res.append([rowIdx, colIdx])
            res.append([rowIdx, colIdx2])
        colIdx = colIdx2 + 1    # Advance column index
result = np.array(res)

對於您的源數據,結果是:

array([[0, 0],
       [0, 3],
       [0, 5],
       [0, 7],
       [3, 4],
       [3, 7]])

如您所見,它不包括第 5 行和第 6 行中較短的零序列。

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM