Filter values in a pandas dataframe based on complex column conditions
I have a dataframe that looks like this:
import pandas as pd

dict = {'trade_date': {1350: 20151201,
                       6175: 20151201,
                       3100: 20151201,
                       5650: 20151201,
                       3575: 20151201,
                       1: 20170301,
                       2: 20170301},
        'comId': {1350: '257762',
                  6175: '1038328',
                  3100: '315476',
                  5650: '658776',
                  3575: '329376',
                  1: '123456',
                  2: '987654'},
        'return': {1350: -0.0018,
                   6175: 0.0023,
                   3100: -0.0413,
                   5650: 0.1266,
                   3575: 0.0221,
                   1: '0.9',
                   2: '0.01'}}
df = pd.DataFrame(dict)
The expected output should look like this:
dict2 = {'trade_date': {5650: 20151201,
                        1: 20170301},
         'comId': {5650: '658776',
                   1: '123456'},
         'return': {5650: 0.1266,
                    1: '0.9'}}
I need to filter it based on the following condition: for each trade_date value, I want to keep only the top 20% of entries, based on the value in column return. So for this example, it would filter out everything but the company with comId value 658776 and return value 0.1266.
Bear in mind there might be trade_dates with more companies associated with them. In that case the count should be rounded up or down to the nearest integer. For example, if there are 9 companies associated with a date, 20% * 9 = 1.8, so it should keep only the first two based on the values in column return.
Any ideas on how best to approach this? I'm a bit lost.
I think this should work:
df\
.groupby("trade_date")\
.apply(lambda x: x[x["return"] >
x["return"].quantile(0.8, interpolation="nearest")])\
.reset_index(drop=True)
You can use groupby().transform to get the threshold for each row. This would be a bit faster than groupby().apply:
thresholds = df.groupby('trade_date')['return'].transform('quantile',q=.8)
df[df['return'] > thresholds]
Output:
trade_date comId return
5650 20151201 658776 0.1266
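A quantile threshold with a strict comparison may not always keep exactly round(20% * n) rows per date. With 7 distinct values, for instance, the default linearly interpolated 80th percentile falls between the 5th and 6th sorted values, so two rows pass, while 20% * 7 = 1.4 rounds to 1. If the row count matters, a rank-based variant is one option. The following is only a sketch: the function name keep_top_share, the frac parameter and the sample data are made up for illustration, and it assumes string entries such as '0.9' should be treated as numbers.

```python
import pandas as pd

def keep_top_share(df, frac=0.2):
    """Keep round(frac * n) rows with the highest return per trade_date."""
    # Assumption: string entries such as '0.9' should be treated as numbers.
    returns = pd.to_numeric(df["return"])
    grouped = returns.groupby(df["trade_date"])
    # Rank 1 is the highest return within each trade_date; ties are broken
    # by order of appearance so the ranks are always distinct.
    rank = grouped.rank(method="first", ascending=False)
    # Number of rows to keep per date, rounded to the nearest integer.
    n_keep = (grouped.transform("size") * frac).round()
    return df[rank <= n_keep]

# Hypothetical sample: nine companies on one date, one return stored as text.
sample = pd.DataFrame({
    "trade_date": [20151201] * 9,
    "comId": [str(i) for i in range(9)],
    "return": [0.05, 0.01, "0.09", 0.03, 0.07, 0.02, 0.08, 0.04, 0.06],
})
top = keep_top_share(sample)
print(top)
```

Here 20% * 9 = 1.8 rounds to 2, so the two rows with the highest returns are kept, in their original order.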
Create a temporary variable storing only the rows with the same trade_date. Then use df.sort_values(by='return', ascending=False) on it and remove the bottom 80%. Loop through all possible dates and, every time you get the top 20%, append those rows to a new dataframe.
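The loop described above could be sketched roughly as follows. The function name top_share_by_loop and the frac parameter are made up for illustration, and the sample reuses the 20151201 rows from the question. It assumes the return column is already numeric.

```python
import pandas as pd

def top_share_by_loop(df, frac=0.2):
    """Keep the top `frac` share of rows per trade_date, ranked by return."""
    kept = []
    for date in df["trade_date"].unique():
        # Temporary frame holding only the rows for this date.
        temp = df[df["trade_date"] == date]
        temp = temp.sort_values(by="return", ascending=False)
        # Rows to keep, rounded to the nearest integer (0.2 * 9 = 1.8 -> 2).
        n_keep = int(round(frac * len(temp)))
        kept.append(temp.head(n_keep))
    # Assumption: an empty input should yield an empty frame.
    return pd.concat(kept) if kept else df.iloc[0:0]

sample = pd.DataFrame(
    {
        "trade_date": [20151201] * 5,
        "comId": ["257762", "1038328", "315476", "658776", "329376"],
        "return": [-0.0018, 0.0023, -0.0413, 0.1266, 0.0221],
    },
    index=[1350, 6175, 3100, 5650, 3575],
)
result = top_share_by_loop(sample)
print(result)
```

This is easy to follow but will usually be slower than the groupby-based answers on large frames, since it filters the whole dataframe once per date.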