简体   繁体   English

根据复杂列条件过滤 pandas dataframe 中的值

[英]Filter values in pandas dataframe based on complex columns conditions

I have a dataframe that looks like this我有一个看起来像这样的 dataframe

dict = {'trade_date': {1350: 20151201,
  6175: 20151201,
  3100: 20151201,
  5650: 20151201,
  3575: 20151201,
     1: 20170301,
     2: 20170301},
 'comId': {1350: '257762',
  6175: '1038328',
  3100: '315476',
  5650: '658776',
  3575: '329376',
     1: '123456',
     2: '987654'},
 'return': {1350: -0.0018,
  6175: 0.0023,
  3100: -0.0413,
  5650: 0.1266,
  3575: 0.0221,
  1: '0.9',
  2: '0.01'}}

df = pd.DataFrame(dict)

the expected output should be like this:
dict2 = {'trade_date': {5650: 20151201,
     1: 20170301},
 'comId': {5650: '658776',
     1: '123456'},
 'return': {5650: 0.1266,
  1: '0.9'}}

I need to filter it based on the following condition: for each trade_date value, I want to keep only the top 20% entries, based on the value in column return .我需要根据以下条件对其进行过滤:对于每个trade_date值,我只想根据return列中的值保留前 20% 的条目。 So for this example, it would filter out everything but the company with comId value 658776 and return value 0.1266 .因此,对于此示例,它将过滤掉comId值为658776return值为0.1266的公司之外的所有内容。

Bear in mind there might be trade_dates with more companies associated to them.请记住,可能有与更多公司相关的trade_dates In that case it should round that up or down to the nearest integer.在这种情况下,它应该向上或向下舍入到最接近的 integer。 For example, if there are 9 companies associated with a date, 20% * 9 = 1.8, so it should only keep the first two based on the values in column return .例如,如果有 9 家公司与某个日期相关联,则 20% * 9 = 1.8,因此它应该仅根据列return中的值保留前两个。

Any ideas how to best approach this, I'm a bit lost?任何想法如何最好地解决这个问题,我有点迷茫?

I think this should work:我认为这应该有效:

df\
.groupby("trade_date")\
.apply(lambda x: x[x["return"] >
    x["return"].quantile(0.8, interpolation="nearest")])\
.reset_index(drop=True)

You can use groupby().transform to get the threshold for each row.您可以使用groupby().transform来获取每行的阈值。 This would be a bit faster than groupby().apply :这会比groupby().apply快一点:

thresholds = df.groupby('trade_date')['return'].transform('quantile',q=.8)
df[df['return'] > thresholds]

Output: Output:

      trade_date   comId  return
5650    20151201  658776  0.1266

Create a temporary variable storing only the rows with the same trade_date.创建一个临时变量,仅存储具有相同 trade_date 的行。 Then use this: df.sort_values(by='return', ascending=False) and then remove the bottom 80%.然后使用这个: df.sort_values(by='return', ascending=False) 然后删除底部的 80%。 Loop through all possible dates and everytime you get the 20%, append them to a new dataframe.循环遍历所有可能的日期,每次获得 20%,append 到新的 dataframe。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM