Correctly applying a custom function to a grouped pandas dataframe

I have a pandas dataframe like:

  ticket        date   close
0    AAA  2018-01-12  176.16
1    AAA  2018-01-13  176.49
3    AAA  2018-01-14  176.00
4    BBB  2018-01-12   78.19
5    BBB  2018-01-13   79.90
6    BBB  2018-01-14   78.10

I have a function:

def rsi(dataframe, period, column = 'close'):
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe

What I need is to apply this function to my dataframe separately for each group produced by groupby('ticket'). I tried the following, but it doesn't work. Any advice would be appreciated.

print(dataframe.groupby('ticket').apply(rsi, 2))

I get an error:

cannot reindex from a duplicate axis

The whole source code is:

# -*- coding: utf-8 -*-

import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column = 'close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] =  rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

print(all_prices.groupby('ticket').apply(rsi, 2))

There is a problem in the source code. The line

delta = all_prices[column].diff()

should be

delta = dataframe[column].diff()

After fixing it, the code runs without problems. Reassigning the result adds the rsi column to all_prices, i.e.

all_prices = all_prices.groupby('ticket').apply(rsi, 2)

So the final code and result are shown below:

In [20]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = dataframe[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     dataframe['rsi'] = rsi
    ...:     return dataframe
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: all_prices = all_prices.groupby('ticket').apply(rsi, 2)
    ...: print(all_prices.head())
    ...: 
    ...: 
  ticket        date   close        rsi
0   AAPL  2018-01-12  177.09        NaN
1   AAPL  2018-01-16  176.19   0.000000
2   AAPL  2018-01-17  179.10  76.377953
3   AAPL  2018-01-18  179.26  78.208232
4   AAPL  2018-01-19  178.46  44.065484
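
For readers who want to try the per-group calculation without the network call, here is a minimal sketch using the sample rows from the question. It is not the answerer's code: the helper name rsi_series is made up for illustration, and it swaps apply for groupby(...)['close'].transform(...), on the assumption that only the rsi column needs to be added. transform returns values aligned to the original rows, so the result can be assigned directly as a new column.

```python
import pandas as pd


def rsi_series(close, period):
    """RSI of a single price series (hypothetical helper, same smoothing as rsi above)."""
    delta = close.diff()
    up = delta.clip(lower=0)      # keep gains, set losses to 0
    down = delta.clip(upper=0)    # keep losses, set gains to 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    return 100 - 100 / (1 + rolling_up / rolling_down)


# Sample rows from the question, with a fresh unique index.
prices = pd.DataFrame({
    'ticket': ['AAA', 'AAA', 'AAA', 'BBB', 'BBB', 'BBB'],
    'date': ['2018-01-12', '2018-01-13', '2018-01-14'] * 2,
    'close': [176.16, 176.49, 176.00, 78.19, 79.90, 78.10],
})

# Each ticket's closes are passed to rsi_series on their own,
# so the diff never crosses from one ticket into the next.
prices['rsi'] = prices.groupby('ticket')['close'].transform(
    lambda s: rsi_series(s, 2)
)
print(prices)
```

The first row of every ticket is NaN because diff has no previous value inside that group, which is a quick way to check that the groups were processed independently.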

The issue here is related to the lines

dataframe['rsi'] =  rsi
return dataframe

The problem is that rsi does not have the same index as dataframe; furthermore, rsi has a different length (delta is still taken from all_prices rather than from the group passed in).

I changed the lines above to

return rsi

and the code ran without issues.

So the final code and result are shown below:

In [12]: # -*- coding: utf-8 -*-
    ...: 
    ...: import json
    ...: import pandas
    ...: import requests
    ...: import datetime
    ...: 
    ...: def get_historical_prices(tickets, range):
    ...:     request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    ...:     json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params = request_params).json()
    ...:     united_dataframe = pandas.DataFrame()
    ...:     for symbol in json:
    ...:         ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
    ...:         ticket_dataframe.insert(0, 'ticket', symbol)
    ...:         united_dataframe = united_dataframe.append(ticket_dataframe)
    ...:     return united_dataframe[['ticket', 'date', 'close']]
    ...: 
    ...: def rsi(dataframe, period, column = 'close'):
    ...:     delta = all_prices[column].diff()
    ...:     up, down = delta.copy(), delta.copy()
    ...:     up[up < 0] = 0
    ...:     down[down > 0] = 0
    ...:     rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    ...:     rolling_down = down.ewm(com= period -1, adjust=False).mean().abs()
    ...:     rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    ...:     
    ...:     return rsi
    ...: 
    ...: # Get the data
    ...: tickets = ['AAPL', 'FB', 'TSLA']
    ...: all_prices = get_historical_prices(tickets, '1m')
    ...: 
    ...: print(all_prices.groupby('ticket').apply(rsi, 2))
    ...: 
    ...: 
close   0    1          2          3          4          5          6   \
ticket                                                                   
AAPL   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
FB     NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   
TSLA   NaN  0.0  76.377953  78.208232  44.065484  16.991057  19.694656   

close         7         8        9     ...             11         12  \
ticket                                 ...                             
AAPL    3.521704  1.252711  15.2917    ...      76.523444  48.103572   
FB      3.521704  1.252711  15.2917    ...      76.523444  48.103572   
TSLA    3.521704  1.252711  15.2917    ...      76.523444  48.103572   

close          13         14         15       16         17         18  \
ticket                                                                   
AAPL    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
FB      80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   
TSLA    80.777497  46.146025  23.884925  8.34273  16.897125  75.919398   

close          19         20  
ticket                        
AAPL    15.705838  12.501725  
FB      15.705838  12.501725  
TSLA    15.705838  12.501725  

[3 rows x 61 columns]
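
The index mismatch described in this answer can be reproduced offline. The sketch below is an illustration under assumptions, not the original code: it uses two small hand-made frames in place of the API response, and pandas.concat in place of DataFrame.append (which was removed in pandas 2.0). Stacking frames without ignore_index=True repeats the labels 0, 1, 2, and assigning a Series with repeated labels to a frame with a different index forces a reindex that pandas refuses. The exact error message varies between pandas versions, so only the exception type is caught here.

```python
import pandas as pd

a = pd.DataFrame({'ticket': 'AAA', 'close': [176.16, 176.49, 176.00]})
b = pd.DataFrame({'ticket': 'BBB', 'close': [78.19, 79.90, 78.10]})

# Stacking without ignore_index keeps each frame's labels: 0, 1, 2, 0, 1, 2.
stacked = pd.concat([a, b])
print(stacked.index.is_unique)  # False

# One group (labels 0, 1, 2) versus a Series built from the whole frame.
group = stacked[stacked['ticket'] == 'AAA'].copy()
whole = stacked['close'].diff()

try:
    group['rsi'] = whole          # needs to align duplicate labels
    raised = False
except ValueError as err:
    raised = True
    print('assignment failed:', err)

# Rebuilding the frame with a fresh 0..n-1 index removes the duplicates.
clean = pd.concat([a, b], ignore_index=True)
print(clean.index.is_unique)  # True
```

A unique index does not by itself make the calculation correct: the function still has to read from the group it receives rather than from the global frame.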
