
pandas: how to apply a function to groupby objects with an argument

I have a df:

cluster_id    memo
   1          m
   1          n
   2          m
   2          m
   2          n
   3          m
   3          m
   3          m
   3          n
   4          m
   4          n
   4          n
   4          n
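To reproduce the examples, here is a minimal sketch that rebuilds this sample frame (values copied from the table above) and shows the 'm'/'n' counts per cluster:

```python
import pandas as pd

# Rebuild the sample frame from the question, one row per table line
df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': ['m', 'n', 'm', 'm', 'n', 'm', 'm', 'm', 'n', 'm', 'n', 'n', 'n'],
})

# Count how many 'm' and 'n' rows each cluster has (missing combinations become 0)
counts = df.groupby('cluster_id')['memo'].value_counts().unstack(fill_value=0)
print(counts)
```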

I want to group by cluster_id and apply the following function,

def valid_row_dup(df):
    num_real_invs = df[df['memo'] == 'm'].shape[0]
    num_reversals_invs = df[df['memo'] == 'n'].shape[0]

    if num_real_invs == df.shape[0]:
        return True
    elif num_reversals_invs == df.shape[0]:
        return False
    elif abs(num_real_invs - num_reversals_invs) > 0:
        # even diff
        if abs(num_real_invs - num_reversals_invs) % 2 == 0:
            return True
        else:
            if abs(num_real_invs - num_reversals_invs) == 1:
                return False
            # odd diff
            else:
                return True
    elif num_real_invs - num_reversals_invs == 0:
        return False 

passing each group to the function as df and assigning the boolean result back to df:

cluster_id    memo     valid
   1          m        False
   1          n        False
   2          m        False
   2          m        False
   2          n        False
   3          m        True
   3          m        True
   3          m        True
   3          n        True
   4          m        True
   4          n        True
   4          n        True   
   4          n        True
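Taking valid_row_dup exactly as written, its branches collapse. Apart from the all-'m' and all-'n' cases, the only False outcomes are differences of 0 or 1, so a mixed group is valid exactly when |#m - #n| >= 2. Below is a sketch of that condensed rule (valid_row_dup_compact is a hypothetical name) applied to the sample data:

```python
import pandas as pd

def valid_row_dup_compact(group):
    """Hypothetical condensed form of valid_row_dup: same decisions, fewer branches."""
    num_m = int((group['memo'] == 'm').sum())  # rows with memo == 'm'
    num_n = int((group['memo'] == 'n').sum())  # rows with memo == 'n'
    size = len(group)
    if num_m == size:   # only 'm'
        return True
    if num_n == size:   # only 'n'
        return False
    # Mixed group: the original returns False only for a difference of 0 or 1
    return abs(num_m - num_n) >= 2

df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': list('mnmmnmmmnmnnn'),
})

# One boolean per cluster_id
result = df.groupby('cluster_id')[['memo']].apply(valid_row_dup_compact)
print(result)
```

Selecting `[['memo']]` before `apply` passes only that column to the function. This avoids the warning that recent pandas versions emit when `apply` operates on the grouping column.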

Apply your function, then merge:

df.merge(df.groupby('cluster_id').apply(valid_row_dup).to_frame(), on='cluster_id')

    cluster_id memo      0
0            1    m  False
1            1    n  False
2            2    m  False
3            2    m  False
4            2    n  False
5            3    m   True
6            3    m   True
7            3    m   True
8            3    n   True
9            4    m   True
10           4    n   True
11           4    n   True
12           4    n   True
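A merge-free variant of the same idea swaps `merge` for `Series.map`, which is not what this answer uses. It computes one boolean per cluster, then looks up each row's cluster_id in that Series. The function `is_valid` is a hypothetical condensed stand-in for valid_row_dup that takes each group's memo column, so the sketch runs on its own:

```python
import pandas as pd

def is_valid(memo):
    """Hypothetical stand-in for valid_row_dup, taking the group's memo Series."""
    num_m = int((memo == 'm').sum())
    num_n = int((memo == 'n').sum())
    if num_m == len(memo):   # only 'm'
        return True
    if num_n == len(memo):   # only 'n'
        return False
    return abs(num_m - num_n) >= 2   # mixed: counts must differ by 2 or more

df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': list('mnmmnmmmnmnnn'),
})

# One boolean per cluster, indexed by cluster_id
per_cluster = df.groupby('cluster_id')['memo'].apply(is_valid)

# Look up each row's cluster in that Series; no merge or column renaming needed
df['valid'] = df['cluster_id'].map(per_cluster)
print(df)
```

Unlike `merge`, `map` keeps the original index and row order of df.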

I agree with Chris's answer. I just wanted to provide a more polished solution:

df.merge(df.groupby('cluster_id').apply(valid_row_dup)
           .to_frame()
           .reset_index()
           .rename(columns={0: 'valid'}),
         on='cluster_id', how='inner')
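All of the approaches here call a Python function once per cluster. If that gets slow with many clusters, a vectorized sketch of the same rule is possible. This is a different technique from the answers, and it assumes the rule reduces to: all 'm' is valid, all 'n' is not, and otherwise the counts must differ by at least 2. It builds per-row group counts with `transform('sum')`:

```python
import pandas as pd

df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': list('mnmmnmmmnmnnn'),
})

groups = df['cluster_id']
# For every row: number of 'm' rows, 'n' rows and all rows in that row's cluster
num_m = df['memo'].eq('m').astype(int).groupby(groups).transform('sum')
num_n = df['memo'].eq('n').astype(int).groupby(groups).transform('sum')
size = pd.Series(1, index=df.index).groupby(groups).transform('sum')

# all 'm' -> True; all 'n' -> False; mixed -> counts differ by at least 2
df['valid'] = (num_m == size) | ((num_n != size) & ((num_m - num_n).abs() >= 2))
print(df)
```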

If you define the function slightly differently:

def valid_row_dup2(ser):
    num_real_invs = ser[ser == 'm'].size        # Number of 'm'
    num_reversals_invs = ser[ser == 'n'].size   # Number of 'n'
    siz = ser.size                  # Total size
    diff = abs(num_real_invs - num_reversals_invs)
    if num_real_invs == siz:        # Only 'm'
        return True
    elif num_reversals_invs == siz: # Only 'n'
        return False
    elif diff > 0:          # Different number of 'm' and 'n'
        if diff % 2 == 0:   # Even diff
            return True
        elif diff == 1:     # Difference by one
            return False
        else:               # Odd diff, > 1
            return True
    else:                   # Equal number of 'm' and 'n'
        return False

you can add the new column as follows:

df['valid'] = df.groupby('cluster_id').memo.transform(valid_row_dup2)

IMHO this is a simpler solution (no merge needed; you just add a new column).
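The question title asks about passing an argument, which none of the answers show. The groupby `transform` and `apply` methods forward extra positional and keyword arguments to the function for every group. Here is a sketch that assumes the 'm'/'n' labels should be configurable (`valid_by_labels`, `real` and `reversal` are hypothetical names):

```python
import pandas as pd

def valid_by_labels(ser, real='m', reversal='n'):
    """Hypothetical condensed valid_row_dup2 with the two memo labels as parameters."""
    num_real = int((ser == real).sum())
    num_rev = int((ser == reversal).sum())
    if num_real == ser.size:   # only `real` labels
        return True
    if num_rev == ser.size:    # only `reversal` labels
        return False
    return abs(num_real - num_rev) >= 2   # mixed: counts must differ by 2 or more

df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': list('mnmmnmmmnmnnn'),
})

# Keyword arguments after the function are passed on to it for each group;
# transform broadcasts the scalar result to every row of the group
df['valid'] = df.groupby('cluster_id')['memo'].transform(
    valid_by_labels, real='m', reversal='n')

# apply forwards them the same way, but returns one value per cluster
per_cluster = df.groupby('cluster_id')['memo'].apply(
    valid_by_labels, real='m', reversal='n')
print(df)
print(per_cluster)
```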


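Statement: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.

Guangdong ICP No. 18138465  © 2020-2024 STACKOOM.COM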