How to filter a Dataframe based on a criteria using .shift()

Within each group of a groupby, I am trying to remove every row from the first non-sequential 'Period' onwards. I would rather avoid looping if possible.

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'DE', 'DE', 'US', 'US', 'US', 'US','US'],
    'Product': ['Blue', 'Blue', 'Blue', 'Blue','Blue','Green', 'Green', 'Green', 'Green','Green'],
    'Period': [1, 2, 3,5,6, 1, 2, 4, 5, 6]}

df = pd.DataFrame(data, columns= ['Country','Product', 'Period'])
print(df)


  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
3      DE    Blue       5
4      DE    Blue       6
5      US   Green       1
6      US   Green       2
7      US   Green       4
8      US   Green       5
9      US   Green       6

So for example, the final output I would like is below:

  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
5      US   Green       1
6      US   Green       2

To give you an idea, my attempt is below. It has many mistakes in it, but you can probably see the logic of what I am trying to do.

df = df.groupby(['Country','Product']).apply(lambda x: x[x.Period.shift(x.Period - 1) == 1]).reset_index(drop=True)

The tricky part is that, rather than just using .shift(1) or similar, I am trying to pass a per-row value into .shift(). For example, if a row's Period is 5, I want to say .shift(5-1) so that it shifts up 4 places and checks the Period there. If that equals 1, the run is still sequential. In this case it would go into NaN territory, I guess.

Instead of using shift(), you could use diff() and cumsum():

result = grouped['Period'].apply(
    lambda x: x.loc[(x.diff() > 1).cumsum() == 0])

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'DE', 'DE', 'US', 'US', 'US', 'US','US'],
    'Product': ['Blue', 'Blue', 'Blue', 'Blue','Blue','Green', 'Green', 'Green', 'Green','Green'],
    'Period': [1, 2, 3,5,6, 1, 2, 4, 5, 6]}

df = pd.DataFrame(data, columns= ['Country','Product', 'Period'])
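# Group by the key columns; each group's Period values are handled separately.
grouped = df.groupby(['Country','Product'])
# Keep rows only while no jump greater than 1 has been seen within the group.
result = grouped['Period'].apply(
    lambda x: x.loc[(x.diff() > 1).cumsum() == 0])
result.name = 'Period'
# Move the group keys from the index back into columns.
result = result.reset_index(['Country', 'Product'])
print(result)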


  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
5      US   Green       1
6      US   Green       2

Explanation:

A sequential run of numbers has adjacent diffs of 1. For example, if we treat df['Period'] for the moment as a single group,

In [41]: df['Period'].diff()
0   NaN
1     1
2     1
3     2
4     1
5    -5
6     1
7     2
8     1
9     1
Name: Period, dtype: float64

In [42]: df['Period'].diff() > 1
0    False
1    False
2    False
3     True       <--- We want to cut off before here
4    False
5    False
6    False
7     True
8    False
9    False
Name: Period, dtype: bool

To find the cutoff location (the first True in df['Period'].diff() > 1), we can use cumsum() and select the rows where the running sum equals 0:

In [43]: (df['Period'].diff() > 1).cumsum()
0    0
1    0
2    0
3    1
4    1
5    1
6    1
7    2
8    2
9    2
Name: Period, dtype: int64

In [44]: (df['Period'].diff() > 1).cumsum() == 0
0     True
1     True
2     True
3    False
4    False
5    False
6    False
7    False
8    False
9    False
Name: Period, dtype: bool
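
The walkthrough above treats the whole column as one group. As a minimal sketch, the same diff/cumsum idea can be applied per group and used as a boolean mask on the original frame, which keeps every column without reset_index. The names keys, is_gap, gaps_so_far and filtered are illustrative and not part of the answer.

```python
import pandas as pd

df = pd.DataFrame({
    'Country': ['DE'] * 5 + ['US'] * 5,
    'Product': ['Blue'] * 5 + ['Green'] * 5,
    'Period': [1, 2, 3, 5, 6, 1, 2, 4, 5, 6],
})
keys = ['Country', 'Product']

# Per-group diff: the first row of every group becomes NaN, so the
# jump from one group to the next is never counted as a gap.
is_gap = (df.groupby(keys)['Period'].diff() > 1).astype(int)

# Running count of gaps within each group; 0 means "no gap seen yet".
gaps_so_far = is_gap.groupby([df[k] for k in keys]).cumsum()

# Keep only the rows that come before the first gap of their group.
filtered = df[gaps_so_far == 0]
print(filtered)
```

Because the mask is aligned with df.index, the surviving rows keep their original index labels (0, 1, 2, 5, 6 for the sample data).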

Taking diff() and cumsum() might seem wasteful because these operations may be computing a lot of values that are not needed -- especially if x is very large and the first sequential run is very short.

Despite the wastefulness, the speed gained by calling NumPy or Pandas methods (implemented in C/Cython/C++ or Fortran) usually outweighs the savings of a less wasteful algorithm coded in pure Python.
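
To make the trade-off concrete, here is a rough sketch of the two styles side by side. The helper names leading_run_python and leading_run_vectorized are hypothetical. The sketch only shows that both give the same rows; it makes no claim about which is faster on your data, which you would need to measure with %timeit.

```python
import pandas as pd

def leading_run_python(values):
    """Pure-Python loop: stops at the first gap, so it does no extra work."""
    kept = []
    for v in values:
        if kept and v - kept[-1] > 1:
            break  # first jump larger than 1: everything after is dropped
        kept.append(v)
    return kept

def leading_run_vectorized(x):
    """Vectorized version: computes every diff, then masks the Series."""
    return x[(x.diff() > 1).cumsum() == 0]

x = pd.Series([1, 2, 3, 5, 6])
print(leading_run_python(x.tolist()))
print(leading_run_vectorized(x).tolist())
```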

You could, however, replace the call to cumsum with a call to idxmax. (In pandas versions before 1.0, Series.argmax was an alias for idxmax; since 1.0, Series.argmax returns an integer position instead, so idxmax is the spelling that works with .loc.)

result = grouped['Period'].apply(
    lambda x: x.loc[:(x.diff() > 1).idxmax()].iloc[:-1])

For very large x this might be somewhat quicker:

x = df['Period']
x = pd.concat([x]*1000)
x = x.reset_index(drop=True)

In [68]: %timeit x.loc[:(x.diff() > 1).idxmax()].iloc[:-1]
1000 loops, best of 3: 884 µs per loop

In [69]: %timeit x.loc[(x.diff() > 1).cumsum() == 0]
1000 loops, best of 3: 1.12 ms per loop

Note, however, that idxmax returns an index label, not an ordinal index location. Therefore, using idxmax will not work if x.index contains duplicate values. (That's why I had to set x = x.reset_index(drop=True).) It also returns the first label when a group contains no gap at all, in which case the slice above drops every row of that group.

So while using idxmax is a bit faster in some situations, this alternative is not quite as robust.
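
One way to sidestep the label-versus-position issue noted above is to work with positions throughout and to handle a group without any gap explicitly. The following is a sketch under those assumptions, not the answer's own method; leading_run is a hypothetical helper name.

```python
import pandas as pd

def leading_run(x):
    """Return the rows of x before the first jump greater than 1."""
    # Boolean NumPy array; NumPy's argmax gives the position of the first True.
    gaps = (x.diff() > 1).to_numpy()
    if not gaps.any():
        # No gap anywhere: the whole Series is one sequential run.
        return x
    # Positional slice, so duplicate index labels are not a problem.
    return x.iloc[:gaps.argmax()]

# Works even though every row carries the same index label.
dup = pd.Series([1, 2, 4, 5], index=[0, 0, 0, 0])
print(leading_run(dup).tolist())

# A Series without gaps is returned unchanged.
print(leading_run(pd.Series([1, 2, 3])).tolist())
```

It should be usable in place of the lambda, for example grouped['Period'].apply(leading_run).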

Sorry, I am not familiar with pandas, but in general this can be done in plain Python in a straightforward way.
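
The code for this answer is not present in the text as it stands. A minimal sketch that yields a list of the shape shown next, assuming the answer simply zipped the three columns of the question's data dict together (the name rows is an assumption):

```python
data = {'Country': ['DE', 'DE', 'DE', 'DE', 'DE', 'US', 'US', 'US', 'US', 'US'],
        'Product': ['Blue', 'Blue', 'Blue', 'Blue', 'Blue',
                    'Green', 'Green', 'Green', 'Green', 'Green'],
        'Period': [1, 2, 3, 5, 6, 1, 2, 4, 5, 6]}

# Combine the three parallel lists into one tuple per row.
rows = list(zip(data['Country'], data['Product'], data['Period']))
print(rows)
```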

The result will be a list:
[('DE', 'Blue', 1), ('DE', 'Blue', 2), ('DE', 'Blue', 3), ('DE', 'Blue', 5), 
('DE', 'Blue', 6), ('US', 'Green', 1), ('US', 'Green', 2), ('US', 'Green', 4),
('US', 'Green', 5), ('US', 'Green', 6)]

After this, the result can easily be fed to your function.
