
Pandas pivot_table() aggfunc aggregation conditional on multiple columns?

I want to aggregate one column with a pandas pivot table, but the custom aggregation should be conditional on a different column in the dataframe.

See the example below. Say that, for each value in the "Newspaper" column, I want to count the "Number_mentions" values that are above a threshold. This is easy to do with a custom aggfunc. But what if, in addition, I only want to count those "Number_mentions" values that are not in the same row as the value "RU" in the "Country" column? The aggfunc seems to receive only one column, in isolation from the others, and I don't know how to get the entire dataframe into the aggfunc so that it can do the conditional subsetting.

import pandas as pd

df = pd.DataFrame({"Number_mentions": [1, 5, 2, 3, 6, 5],
                   "Newspaper": ["Newspaper1", "Newspaper1", "Newspaper2", "Newspaper3", "Newspaper4", "Newspaper5"],
                   "Country": ["US", "US", "CN", "CN", "RU", "RU"]})

def articles_above_thresh_with_condition(input_series, thresh=2):
    series_bool = input_series > thresh
    # ! add some if condition based on additional column in df: 
    # ! only aggregate those values where column "Country" is not "RU". 
    # ? code ? 
    n_articles_above_thresh = sum(series_bool)
    return n_articles_above_thresh

df_piv = pd.pivot_table(df, values=["Number_mentions"],
                        index="Newspaper", columns=None, margins=False,
                        aggfunc=articles_above_thresh_with_condition)

You need a different approach, because the aggfunc passed to pivot_table is applied to each column in values separately, so it cannot look at a second column such as "Country".

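So first replace the values that do not match the condition with missing values using Series.where, and then aggregate this new column with the same function (NaN > thresh evaluates to False, so the masked rows are never counted):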

df["Number_mentions1"] = df["Number_mentions"].where(df["Country"].ne('RU'))
print (df)
   Number_mentions   Newspaper Country  Number_mentions1
0                1  Newspaper1      US               1.0
1                5  Newspaper1      US               5.0
2                2  Newspaper2      CN               2.0
3                3  Newspaper3      CN               3.0
4                6  Newspaper4      RU               NaN
5                5  Newspaper5      RU               NaN

df_piv = pd.pivot_table(df, values=["Number_mentions1"],
                        index="Newspaper", columns=None, margins=False,
                        aggfunc=articles_above_thresh_with_condition)
print (df_piv)
            Number_mentions1
Newspaper                   
Newspaper1               1.0
Newspaper2               0.0
Newspaper3               1.0
Newspaper4               0.0
Newspaper5               0.0
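Two other ways to get the same counts are sketched below; they are not part of the answer above, just variants under the same sample data. The first folds both conditions into a single 0/1 flag per row, so pivot_table's built-in "sum" can count the flagged rows without a custom aggfunc. The second uses groupby().apply(), which, unlike the pivot_table aggfunc, passes the whole sub-DataFrame of each group to the function, so both "Number_mentions" and "Country" are visible inside it. The helper column `above_thresh_not_ru` and the function `count_above_thresh_not_ru` are hypothetical names introduced here, and `thresh = 2` mirrors the default in the question's function.

```python
import pandas as pd

df = pd.DataFrame({"Number_mentions": [1, 5, 2, 3, 6, 5],
                   "Newspaper": ["Newspaper1", "Newspaper1", "Newspaper2",
                                 "Newspaper3", "Newspaper4", "Newspaper5"],
                   "Country": ["US", "US", "CN", "CN", "RU", "RU"]})
thresh = 2

# Variant 1: one 0/1 flag per row combining both conditions
# (above the threshold AND country is not "RU").
# "above_thresh_not_ru" is a hypothetical helper-column name.
df["above_thresh_not_ru"] = ((df["Number_mentions"] > thresh)
                             & df["Country"].ne("RU")).astype(int)

# Summing the 0/1 flags counts the matching rows per newspaper.
df_piv = pd.pivot_table(df, values="above_thresh_not_ru",
                        index="Newspaper", aggfunc="sum")
print(df_piv)


# Variant 2: groupby().apply() receives the full sub-DataFrame of each group,
# so the function can filter on one column and aggregate another.
# count_above_thresh_not_ru is a hypothetical helper name.
def count_above_thresh_not_ru(group, thresh=2):
    mask = (group["Number_mentions"] > thresh) & group["Country"].ne("RU")
    return int(mask.sum())


# Selecting the needed columns keeps the grouping key out of `group`.
counts = (df.groupby("Newspaper")[["Number_mentions", "Country"]]
            .apply(count_above_thresh_not_ru))
print(counts)
```

Both variants give 1, 0, 1, 0, 0 for Newspaper1 through Newspaper5, the same counts as the table above, but as integers rather than floats. The groupby().apply() route is the more general one when the condition involves several columns, at the cost of running a Python function per group.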
