
Use pandas.shift() within a group

I have a dataframe with panel data, let's say it's time series for 100 different objects:

object  period  value 
1       1       24
1       2       67
...
1       1000    56
2       1       59
2       2       46
...
2       1000    64
3       1       54
...
100     1       451
100     2       153
...
100     1000    21

I want to add a new column prev_value that will store the previous value for each object:

object  period  value  prev_value
1       1       24     nan
1       2       67     24
...
1       999     445    1243
1       1000    56     445
2       1       59     nan
2       2       46     59
...
2       1000    64     784
3       1       54     nan
...
100     1       451    nan
100     2       153    451
...
100     1000    21     1121

Can I use .shift() and .groupby() somehow to do that?

Pandas groupby objects have a shift method (DataFrameGroupBy.shift, or SeriesGroupBy.shift once a single column is selected), which shifts the values within each group by n periods (1 by default), just like the regular DataFrame shift method:

df['prev_value'] = df.groupby('object')['value'].shift()

For the following example dataframe:

print(df)

   object  period  value
0       1       1     24
1       1       2     67
2       1       4     89
3       2       4      5
4       2      23     23

The result would be:

   object  period  value  prev_value
0       1       1     24         NaN
1       1       2     67        24.0
2       1       4     89        67.0
3       2       4      5         NaN
4       2      23     23         5.0
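
Note that shift follows the row order inside each group, so the result is the previous period's value only when the rows are ordered by period within each object. A minimal self-contained sketch, reusing the example frame above with its rows deliberately shuffled (the shuffling and the sort_values step are illustrative assumptions, not part of the original answer):

```python
import pandas as pd

# Example frame from above, with rows deliberately out of order.
df = pd.DataFrame({
    "object": [2, 1, 1, 2, 1],
    "period": [23, 2, 1, 4, 4],
    "value":  [23, 67, 24, 5, 89],
})

# shift() follows the row order inside each group, so sort by period first.
df = df.sort_values(["object", "period"]).reset_index(drop=True)

# Previous value within each object; the first row of each group gets NaN.
df["prev_value"] = df.groupby("object")["value"].shift()

print(df)
```

Because the first row of every group becomes NaN, an integer value column is upcast to float in prev_value, as in the output shown above.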

If (and only if) your DataFrame is already sorted by the grouping keys, you can use a single shift on the entire DataFrame and then use where to set to NaN the rows whose shifted value came from the previous group. For larger DataFrames with many groups this can be a bit faster.

df['prev_value'] = df['value'].shift().where(df.object.eq(df.object.shift()))

   object  period  value  prev_value
0       1       1     24         NaN
1       1       2     67        24.0
2       1       4     89        67.0
3       2       4      5         NaN
4       2      23     23         5.0
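
The sortedness precondition can be checked before taking this shortcut. The following is only a sketch under stated assumptions: the helper name prev_within_group is hypothetical, the key column is assumed to contain no missing values, and the contiguity test is just one possible way to confirm that each group's rows sit next to each other. When the test fails, it falls back to groupby.

```python
import pandas as pd

def prev_within_group(df, key, col):
    """Previous row's `col` within each `key` group (hypothetical helper)."""
    keys = df[key]
    # Count the places where the key changes. This equals the number of
    # distinct keys only when each group's rows are contiguous.
    n_runs = (keys != keys.shift()).sum()
    if n_runs == keys.nunique():
        # Fast path: one global shift, then blank out rows whose
        # predecessor belongs to a different group.
        return df[col].shift().where(keys.eq(keys.shift()))
    # General path: let groupby handle interleaved groups.
    return df.groupby(key)[col].shift()

df = pd.DataFrame({
    "object": [1, 1, 1, 2, 2],
    "period": [1, 2, 4, 4, 23],
    "value":  [24, 67, 89, 5, 23],
})
fast = prev_within_group(df, "object", "value")
slow = df.groupby("object")["value"].shift()
print(fast.equals(slow))
```

The check costs one extra pass over the key column, so it gives up a little of the speed advantage in exchange for not silently mixing values across groups.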

Some performance-related timings:

import perfplot
import pandas as pd
import numpy as np

perfplot.show(
    setup=lambda N: pd.DataFrame({'object': np.repeat(range(N), 5), 
                                  'value': np.random.randint(1, 1000, 5*N)}), 
    kernels=[
        lambda df: df.groupby('object')['value'].shift(),
        lambda df: df['value'].shift().where(df.object.eq(df.object.shift())),
    ],
    labels=["GroupBy", "Where"],
    n_range=[2 ** k for k in range(1, 22)],
    equality_check=lambda x,y: np.allclose(x, y, equal_nan=True),
    xlabel="# of Groups"
)


Just create a new column from the existing column.

data["prev_value"] = data["value"] 
