
How to filter a DataFrame based on a criterion using .shift()

I am trying to remove, within each group of a groupby, all rows from the first non-sequential 'Period' onwards. I would rather avoid looping if possible.

import pandas as pd


data = {'Country': ['DE', 'DE', 'DE', 'DE', 'DE', 'US', 'US', 'US', 'US','US'],
    'Product': ['Blue', 'Blue', 'Blue', 'Blue','Blue','Green', 'Green', 'Green', 'Green','Green'],
    'Period': [1, 2, 3,5,6, 1, 2, 4, 5, 6]}

df = pd.DataFrame(data, columns= ['Country','Product', 'Period'])
print(df)

OUTPUT:

  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
3      DE    Blue       5
4      DE    Blue       6
5      US   Green       1
6      US   Green       2
7      US   Green       4
8      US   Green       5
9      US   Green       6

So for example, the final output I would like is below:

  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
5      US   Green       1
6      US   Green       2

My attempt is shown below to give you an idea. It has plenty of mistakes, but you can probably see the logic of what I am trying to do.

df = df.groupby(['Country','Product']).apply(lambda x: x[x.Period.shift(x.Period - 1) == 1]).reset_index(drop=True)

The tricky part is that rather than just using .shift(1), I am trying to pass a computed value into .shift(): if a row's Period is 5, I want to call .shift(5 - 1) so it shifts up 4 places and checks that Period's value. If it equals 1, the run is still sequential; for a row after a gap, I guess the shift would run off the top of the group into NaN territory.
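For what it's worth, here is a minimal working sketch of that "compare against the first value" idea (just an illustration, not the attempt above): a row is still part of the run if its Period equals the group's first Period plus its position within the group.

import numpy as np

def sequential_prefix(g):
    # Expected Period at each position if the run were unbroken.
    expected = g['Period'].iloc[0] + np.arange(len(g))
    # Keep rows only while every Period so far matches its expected value.
    ok = (g['Period'].to_numpy() == expected).cumprod().astype(bool)
    return g[ok]

print(df.groupby(['Country', 'Product'], group_keys=False).apply(sequential_prefix))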

Instead of using shift(), you could use diff() and cumsum():

result = grouped['Period'].apply(
    lambda x: x.loc[(x.diff() > 1).cumsum() == 0])

import pandas as pd

data = {'Country': ['DE', 'DE', 'DE', 'DE', 'DE', 'US', 'US', 'US', 'US','US'],
    'Product': ['Blue', 'Blue', 'Blue', 'Blue','Blue','Green', 'Green', 'Green', 'Green','Green'],
    'Period': [1, 2, 3,5,6, 1, 2, 4, 5, 6]}

df = pd.DataFrame(data, columns= ['Country','Product', 'Period'])
print(df)
grouped = df.groupby(['Country','Product'])
result = grouped['Period'].apply(
    lambda x: x.loc[(x.diff() > 1).cumsum() == 0])
result.name = 'Period'
result = result.reset_index(['Country', 'Product'])
print(result)

yields

  Country Product  Period
0      DE    Blue       1
1      DE    Blue       2
2      DE    Blue       3
5      US   Green       1
6      US   Green       2
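As an aside, the same selection can also be expressed as a single boolean mask with groupby().transform(), which returns a result aligned to df's index. A sketch of this equivalent formulation:

mask = df.groupby(['Country', 'Product'])['Period'].transform(
    lambda x: (x.diff() > 1).cumsum() == 0)
print(df[mask])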

Explanation:

A sequential run of numbers has adjacent diffs of 1. For example, treating df['Period'] for the moment as one single group:

In [41]: df['Period'].diff()
Out[41]: 
0   NaN
1     1
2     1
3     2
4     1
5    -5
6     1
7     2
8     1
9     1
Name: Period, dtype: float64

In [42]: df['Period'].diff() > 1
Out[42]: 
0    False
1    False
2    False
3     True       <--- We want to cut off before here
4    False
5    False
6    False
7     True
8    False
9    False
Name: Period, dtype: bool

To find the cutoff location -- the first True in df['Period'].diff() > 1 -- we can use cumsum() and select the rows where the cumulative sum equals 0:

In [43]: (df['Period'].diff() > 1).cumsum()
Out[43]: 
0    0
1    0
2    0
3    1
4    1
5    1
6    1
7    2
8    2
9    2
Name: Period, dtype: int64

In [44]: (df['Period'].diff() > 1).cumsum() == 0
Out[44]: 
0     True
1     True
2     True
3    False
4    False
5    False
6    False
7    False
8    False
9    False
Name: Period, dtype: bool
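Putting the two steps together on the ungrouped Series selects exactly the leading sequential run (a continuation of the same session, added for illustration):

In [45]: df['Period'].loc[(df['Period'].diff() > 1).cumsum() == 0]
Out[45]: 
0    1
1    2
2    3
Name: Period, dtype: int64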

Taking diff() and cumsum() might seem wasteful because these operations may be computing a lot of values that are not needed -- especially if x is very large and the first sequential run is very short.

Despite the wastefulness, the speed gained by calling NumPy or pandas methods (implemented in C/Cython/C++ or Fortran) usually beats a less wasteful algorithm coded in pure Python.
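To illustrate, a less wasteful pure-Python version might look like this sketch: it stops scanning at the first gap, yet for large Series it typically still loses to the vectorized code above.

def sequential_prefix_len(values):
    # Walk the values and stop at the first jump greater than 1.
    if not values:
        return 0
    n = 1
    for prev, cur in zip(values, values[1:]):
        if cur - prev > 1:
            break
        n += 1
    return n

print(sequential_prefix_len([1, 2, 3, 5, 6]))  # 3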

You could, however, replace the call to cumsum() with a call to argmax():

result = grouped['Period'].apply(
    lambda x: x.loc[:(x.diff() > 1).argmax()].iloc[:-1])

For very large x this might be somewhat quicker:

x = df['Period']
x = pd.concat([x]*1000)
x = x.reset_index(drop=True)

In [68]: %timeit x.loc[:(x.diff() > 1).argmax()].iloc[:-1]
1000 loops, best of 3: 884 µs per loop

In [69]: %timeit x.loc[(x.diff() > 1).cumsum() == 0]
1000 loops, best of 3: 1.12 ms per loop

Note, however, that argmax here returns an index label, not an ordinal position. Therefore, using argmax will not work if x.index contains duplicate values. (That's why I had to set x = x.reset_index(drop=True).) Be aware that this behaviour is version-dependent: in recent pandas releases, Series.argmax returns a positional index instead, so the label-based slice above would behave differently there.

So while using argmax is a bit faster in some situations, this alternative is not quite as robust.
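If you want position-based behaviour that does not depend on the pandas version, one option is to compute the gap positions in NumPy explicitly. A sketch (numpy's argmax returns the position of the first True, or 0 when there is none, hence the explicit check):

import numpy as np

def leading_run(x):
    gaps = (x.diff() > 1).to_numpy()
    if not gaps.any():                 # no gap: keep the whole group
        return x
    return x.iloc[:gaps.argmax()]      # first True's position, excluded by iloc

result = grouped['Period'].apply(leading_run)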

Sorry, I am not familiar with pandas, but generally this can be achieved in straightforward Python.

list(zip(data['Country'], data['Product'], data['Period']))
and the result will be a list:
[('DE', 'Blue', 1), ('DE', 'Blue', 2), ('DE', 'Blue', 3), ('DE', 'Blue', 5), 
('DE', 'Blue', 6), ('US', 'Green', 1), ('US', 'Green', 2), ('US', 'Green', 4),
('US', 'Green', 5), ('US', 'Green', 6)]

After this, the result can easily be fed into your function.
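For completeness, one way such a function might look (a sketch that assumes the rows are already sorted by Country and Product): group the tuples with itertools.groupby and keep each group's leading sequential run.

from itertools import groupby
from operator import itemgetter

rows = list(zip(data['Country'], data['Product'], data['Period']))

filtered = []
for _, group in groupby(rows, key=itemgetter(0, 1)):
    group = list(group)
    expected = group[0][2]          # the first Period in this group
    for row in group:
        if row[2] != expected:      # first gap: stop keeping rows
            break
        filtered.append(row)
        expected += 1

print(filtered)
# [('DE', 'Blue', 1), ('DE', 'Blue', 2), ('DE', 'Blue', 3),
#  ('US', 'Green', 1), ('US', 'Green', 2)]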
