
Checking condition in negative rolling window within GroupBy in Pandas

Here is what my dataframe looks like. The Expected_Output column is my desired (target) column.

   Group  Value  Expected_Output
0      1      2                1
1      1      3                1
2      1      6                1
3      1     11                0
4      1      7                0
5      2      3                1
6      2     13                1
7      2     14                0

For a given Group, as of a given row, I want to look at the next 5 rows and check whether any Value is > 10. If so, Expected_Output should be 1; otherwise it should be 0.

For example, in Group 1, as of the very first row, a Value of 11 (greater than 10) appears 3 rows later. That falls within the "next 5 rows" window, so the condition is satisfied and Expected_Output is 1. Similarly, as of row 6 in Group 2, a Value of 14 (greater than 10) appears 1 row later, inside the window, so Expected_Output is 1 there as well.
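To pin down the intended semantics, here is a plain-Python reference sketch (not from the question; the helper name `next_n_exceeds` is hypothetical). It rebuilds the sample frame from the table above and, for each row, checks the next 5 rows of the same group, excluding the row itself:

```python
import pandas as pd

# Sample frame from the question (Expected_Output omitted; we recompute it)
df = pd.DataFrame({
    "Group": [1, 1, 1, 1, 1, 2, 2, 2],
    "Value": [2, 3, 6, 11, 7, 3, 13, 14],
})

def next_n_exceeds(s, n=5, threshold=10):
    """Hypothetical helper: 1 if any of the next n values in s exceeds threshold."""
    vals = s.tolist()
    flags = [
        # vals[i + 1:i + 1 + n] is the look-ahead window, current row excluded
        int(any(v > threshold for v in vals[i + 1:i + 1 + n]))
        for i in range(len(vals))
    ]
    # Return a Series with the group's own index so transform can align it
    return pd.Series(flags, index=s.index)

df["Reference"] = df.groupby("Group")["Value"].transform(next_n_exceeds)
print(df)
```

It loops over every row, so it is only meant as a correctness check for the vectorised answers, not as a production approach.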

I tried df.groupby('Group')['Value'].rolling(-5).max() > 10 to no avail.
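The attempt fails because `rolling` only accepts a non-negative window and raises a `ValueError` for `-5`. One hedged alternative, not used in any of the answers here, is the forward-looking indexer that recent pandas versions provide as `pandas.api.indexers.FixedForwardWindowIndexer`: its window at row i covers rows i to i+4, so shifting each group up by one row first makes it cover the next 5 rows. The helper name `next5_max` is hypothetical:

```python
import pandas as pd
from pandas.api.indexers import FixedForwardWindowIndexer

df = pd.DataFrame({
    "Group": [1, 1, 1, 1, 1, 2, 2, 2],
    "Value": [2, 3, 6, 11, 7, 3, 13, 14],
})

# A negative window is rejected when the Rolling object is created
try:
    df["Value"].rolling(-5)
except ValueError as exc:
    print("rejected:", exc)

# Forward-looking window: at row i it spans rows i..i+4 of its input
forward5 = FixedForwardWindowIndexer(window_size=5)

def next5_max(s):
    # shift(-1) within the group drops the current row, so the forward
    # window now spans the original rows i+1..i+5 (NaN past the group end)
    return s.shift(-1).rolling(forward5, min_periods=1).max()

# NaN (no rows ahead) compares as False, giving 0
df["Result"] = df.groupby("Group")["Value"].transform(next5_max).gt(10).astype(int)
print(df)
```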

pd.Series.rolling looks backwards by default. To look forwards, you can reverse the dataframe, roll over each group, and then reverse the GroupBy result back. You need a shift because you want the next 5 values, excluding the current row.

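def roller(x):
    # x is one group in reversed row order: rolling(5) covers this row and the
    # 4 rows after it (in original order); shift() turns that into the next 5 rows
    return x.rolling(window=5, min_periods=1)['Value'].max().shift().gt(10).astype(int)

# Reverse, roll per group, reverse back; .values assigns the result by position
df['Result'] = df.iloc[::-1].groupby('Group', sort=False).apply(roller).iloc[::-1].values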

print(df)

   Group  Value  Result
0      1      2       1
1      1      3       1
2      1      6       1
3      1     11       0
4      1      7       0
5      2      3       1
6      2     13       1
7      2     14       0
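Assigning with `.values` matches rows by position, so the answer above implicitly assumes each group's rows are contiguous (as in the sample); with interleaved groups the reversed concatenation would no longer line up with `df`. A sketch of the same reverse-then-shift idea using `transform`, which returns a result indexed like its input, so the assignment aligns on the index instead:

```python
import pandas as pd

df = pd.DataFrame({
    "Group": [1, 1, 1, 1, 1, 2, 2, 2],
    "Value": [2, 3, 6, 11, 7, 3, 13, 14],
})

def roller(s):
    # s arrives in reversed row order, so rolling(5) spans the current row and
    # the 4 rows after it in the original order; shift() moves that window one
    # step so it covers the next 5 rows and excludes the current one
    return s.rolling(window=5, min_periods=1).max().shift().gt(10).astype(int)

rev = df.iloc[::-1]
# transform returns a Series indexed like rev; assignment aligns on the index
df["Result"] = rev.groupby("Group")["Value"].transform(roller)
print(df)
```

Because alignment is by index label, this also works when groups are interleaved, provided the index has no duplicate labels.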

Alternatively, you can group the dataframe and use the dataframe index to fetch the next (up to) 5 values of each row, then check whether any of them is greater than 10. This assumes the default integer index (row labels 0, 1, 2, ...).

import numpy as np  # np.arange builds the candidate row labels x+1..x+5
df['Expected_Output'] = df.groupby(['Group'])['Value'].transform(lambda y: list(map(lambda x: 1 if any(y.loc[sorted(set(np.arange(x + 1, x + 6)).intersection(y.index))] > 10) else 0, y.index)))

Out:

   Group  Value  Expected_Output
0      1      2                1
1      1      3                1
2      1      6                1
3      1     11                0
4      1      7                0
5      2      3                1
6      2     13                1
7      2     14                0
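A loop-free sketch of the same "look at the next 5 rows" idea that does not depend on the row labels being consecutive integers: `groupby(...).shift(-k)` fetches the value k rows ahead within each group (NaN past the end of the group), one column per offset, and a row is flagged if any of them exceeds 10. The variable names are illustrative:

```python
import pandas as pd

df = pd.DataFrame({
    "Group": [1, 1, 1, 1, 1, 2, 2, 2],
    "Value": [2, 3, 6, 11, 7, 3, 13, 14],
})

window = 5
grouped = df.groupby("Group")["Value"]
# Column k holds the value k rows ahead within the same group (NaN past the end)
ahead = pd.concat({k: grouped.shift(-k) for k in range(1, window + 1)}, axis=1)
# NaN > 10 evaluates to False, so missing look-ahead rows never produce a 1
df["Expected_Output"] = ahead.gt(10).any(axis=1).astype(int)
print(df)
```

It materialises one column per offset, so it suits small windows like this one.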

There is a way to do this without any extra hacks, but it requires a sorting dimension. As with most time-series data, you should have a time variable available. Then the solution is very simple:

  1. Sort backwards
  2. Use standard .rolling(window) functionality
  3. (optional) sort again
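A minimal sketch of these three steps on the question's frame. That frame has no time column, so this assumes the existing row order is the sort dimension and uses `df.iloc[::-1]` as the "sort backwards" step; it also relies on `closed='left'` being accepted for fixed windows, which recent pandas versions support:

```python
import pandas as pd

df = pd.DataFrame({
    "Group": [1, 1, 1, 1, 1, 2, 2, 2],
    "Value": [2, 3, 6, 11, 7, 3, 13, 14],
})

# 1) "Sort backwards": reversing the rows stands in for sorting by time descending
rev = df.iloc[::-1]

# 2) Backward-looking rolling max over the 5 previous rows of the reversed frame,
#    i.e. the next 5 rows of the original; closed='left' excludes the current row
next5_max = (
    rev.groupby("Group")["Value"]
    .rolling(window=5, min_periods=1, closed="left")
    .max()
    .droplevel(level=0)  # drop the Group level so the index matches df again
)

# 3) Assignment aligns on the index, so no explicit re-sort is needed here
df["Result"] = next5_max.gt(10).astype(int)
print(df)
```

Re-sorting only matters when the backwards-sorted frame itself is kept, as in the in-place sorting of the example that follows.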

Example: Sleep Study

from pydataset import data
sleep_study = data('sleepstudy')
print(sleep_study.head(5))
   Reaction  Days  Subject
1  249.5600     0      308
2  258.7047     1      308
3  250.8006     2      308
4  321.4398     3      308
5  356.8519     4      308

1) Sort Backwards

sleep_study.sort_values(by=['Subject', 'Days'], ascending=False, inplace=True)

2) Use .rolling(window)

sleep_study['max_react_next_3_days'] = sleep_study\
    .groupby('Subject')['Reaction']\
    .rolling(window=3, min_periods=1, closed='left').max()\
    .droplevel(level=0)
sleep_study['expected_output'] = sleep_study['max_react_next_3_days'] > 400

Explanation:

  • We want to look 3 days ahead, hence window=3.
  • If only 1 or 2 days remain in the study, that still counts, so min_periods=1 (adjust this to your own assumptions).
  • We want the next 3 days, not the current day, so we exclude the current row with closed='left'. This makes rolling use a half-open interval that is closed on the "left" and open on the "right".
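A toy illustration of what `closed='left'` changes (the series is made up for this sketch). With `window=3`, the default `closed='right'` covers the current row and the two before it, while `closed='left'` covers the three rows before it and skips the current one; in a frame sorted backwards, those previous rows are the next days:

```python
import pandas as pd

s = pd.Series([1, 2, 3, 4, 5])

# Default closed='right': the window at row i is rows i-2..i (current row included)
right = s.rolling(window=3, min_periods=1).max()
# closed='left': the window at row i is rows i-3..i-1 (current row excluded)
left = s.rolling(window=3, min_periods=1, closed="left").max()

# closed_left is NaN, 1, 2, 3, 4: row 0 has no earlier rows to look at
print(pd.DataFrame({"value": s, "closed_right": right, "closed_left": left}))
```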

3) Sort again, so that there are no surprises

sleep_study.sort_values(by=['Subject', 'Days'], ascending=True, inplace=True)

Result:

print(sleep_study.head(20))
    Reaction  Days  Subject  max_react_next_3_days  expected_output
1   249.5600     0      308               321.4398            False
2   258.7047     1      308               356.8519            False
3   250.8006     2      308               414.6901             True
4   321.4398     3      308               414.6901             True
5   356.8519     4      308               414.6901             True
6   414.6901     5      308               430.5853             True
7   382.2038     6      308               466.3535             True
8   290.1486     7      308               466.3535             True
9   430.5853     8      308               466.3535             True
10  466.3535     9      308                    NaN            False
11  222.7339     0      309               205.2658            False
12  205.2658     1      309               207.7161            False
13  202.9778     2      309               215.9618            False
14  204.7070     3      309               215.9618            False
15  207.7161     4      309               217.7272            False
16  215.9618     5      309               224.2957            False
17  213.6303     6      309               237.3142            False
18  217.7272     7      309               237.3142            False
19  224.2957     8      309               237.3142            False
20  237.3142     9      309                    NaN            False

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.