Correctly apply own function to grouped pandas dataframe
I have a Pandas dataframe like:
  ticket        date   close
0    AAA  2018-01-12  176.16
1    AAA  2018-01-13  176.49
3    AAA  2018-01-14  176.00
4    BBB  2018-01-12   78.19
5    BBB  2018-01-13   79.90
6    BBB  2018-01-14   78.10
I have a function:
def rsi(dataframe, period, column='close'):
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe
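For reference, a sketch of the arithmetic this function performs, assuming pandas' documented ewm definitions for adjust=False (where com = period - 1 gives a smoothing factor alpha = 1/period). Here c_t is the close on row t, and the recursions start from the first non-missing row, so U_1 = u_1 and D_1 = |d_1|:

```latex
\Delta_t = c_t - c_{t-1}, \qquad u_t = \max(\Delta_t, 0), \qquad d_t = \min(\Delta_t, 0)

\alpha = \frac{1}{1 + \text{com}} = \frac{1}{\text{period}}

U_t = (1 - \alpha)\,U_{t-1} + \alpha\,u_t, \qquad D_t = (1 - \alpha)\,D_{t-1} + \alpha\,\lvert d_t \rvert

\text{RSI}_t = 100 - \frac{100}{1 + U_t / D_t}
```

Because every d_t is zero or negative, taking .abs() after the ewm of down is the same as smoothing |d_t| directly. Worked on the AAA rows above with period = 2 (alpha = 1/2): the deltas are (NaN, 0.33, -0.49), so U_1 = 0.33 and D_1 = 0 (RSI_1 = 100, since U_1/D_1 evaluates to infinity in floating point), then U_2 = 0.165 and D_2 = 0.245, giving RSI_2 = 100 - 100/(1 + 0.165/0.245), which is about 40.24.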
What I need is to apply this function to each group of groupby('ticket'). I tried the following, but it doesn't work. Please give me some advice.
print(dataframe.groupby('ticket').apply(rsi, 2))
I get an error:
cannot reindex from a duplicate axis
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params=request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column='close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')
print(all_prices.groupby('ticket').apply(rsi, 2))
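The error message is about the index, not the RSI arithmetic. get_historical_prices appends one frame per symbol, and each of those frames is numbered from 0, so all_prices ends up with repeated index labels. Inside rsi, delta is computed from all_prices, so assigning it to a column of the group forces pandas to align a Series that has duplicate labels. A minimal offline sketch of that mechanism, assuming pd.concat as a stand-in for the append loop and using the sample prices above instead of the web API (the exact error text differs between pandas versions, so only the exception type is relied on):

```python
import pandas as pd

# Two per-ticker frames, each numbered from 0, standing in for the API response.
aaa = pd.DataFrame({"ticket": "AAA", "close": [176.16, 176.49, 176.00]})
bbb = pd.DataFrame({"ticket": "BBB", "close": [78.19, 79.90, 78.10]})

# Like DataFrame.append, pd.concat keeps the original labels: 0, 1, 2, 0, 1, 2.
all_prices = pd.concat([aaa, bbb])
print(all_prices.index.tolist())

# One group has a unique index, but a Series built from all_prices does not.
group = all_prices[all_prices["ticket"] == "AAA"].copy()
delta = all_prices["close"].diff()

try:
    # pandas must reindex delta onto group's index, which fails on duplicate labels.
    group["rsi"] = delta
except ValueError as err:
    print("ValueError:", err)
```

Computing delta from the group itself (the dataframe argument) avoids the cross-ticker Series entirely.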
There is a problem in the source code. The line
delta = all_prices[column].diff()
should be
delta = dataframe[column].diff()
With that line fixed, the code runs without problems. Reassigning the result adds the rsi column to all_prices, i.e.:
all_prices = all_prices.groupby('ticket').apply(rsi, 2)
So the final code and results are shown below:
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params=request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column='close'):
    # Use the group passed in by groupby, not the global all_prices frame
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)
    dataframe['rsi'] = rsi
    return dataframe

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

all_prices = all_prices.groupby('ticket').apply(rsi, 2)
print(all_prices.head())

  ticket        date   close        rsi
0   AAPL  2018-01-12  177.09        NaN
1   AAPL  2018-01-16  176.19   0.000000
2   AAPL  2018-01-17  179.10  76.377953
3   AAPL  2018-01-18  179.26  78.208232
4   AAPL  2018-01-19  178.46  44.065484
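The corrected function can be checked without calling the web API. A minimal offline sketch, assuming the question's sample prices and pd.concat in place of get_historical_prices (so the index repeats per ticker, as in all_prices). group_keys=False is an addition not in the answer: recent pandas versions otherwise prepend the ticket as an extra index level to the result. Depending on the pandas version, apply may also warn about, or exclude, the grouping column; the rsi values are unaffected.

```python
import pandas as pd

def rsi(dataframe, period, column="close"):
    # Work on the group passed in by groupby, not on a global frame.
    delta = dataframe[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0      # keep gains only
    down[down > 0] = 0  # keep losses only (as negative numbers)
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    dataframe["rsi"] = 100 - 100 / (1 + rolling_up / rolling_down)
    return dataframe

# Sample data from the question; each ticker is numbered from 0, as in all_prices.
all_prices = pd.concat([
    pd.DataFrame({"ticket": "AAA", "close": [176.16, 176.49, 176.00]}),
    pd.DataFrame({"ticket": "BBB", "close": [78.19, 79.90, 78.10]}),
])

# Extra positional arguments to apply (here 2) are forwarded to rsi as `period`.
all_prices = all_prices.groupby("ticket", group_keys=False).apply(rsi, 2)
print(all_prices)
```

Each ticker's first row is NaN because diff has no previous close within that group, which is the expected behavior when the groups are computed independently.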
The issue here is related to the lines
dataframe['rsi'] = rsi
return dataframe
The problem is that rsi does not have the same index as dataframe (it is computed from all_prices, not from the group); furthermore, rsi has a different length.
I changed the lines above to
return rsi
and the code ran without errors.
So the final code and results are shown below:
# -*- coding: utf-8 -*-
import json
import pandas
import requests
import datetime

def get_historical_prices(tickets, range):
    request_params = {'symbols': ','.join(tickets), 'types': 'chart', 'range': range}
    json = requests.get('https://api.iextrading.com/1.0/stock/market/batch', params=request_params).json()
    united_dataframe = pandas.DataFrame()
    for symbol in json:
        ticket_dataframe = pandas.DataFrame(json[symbol]['chart'])
        ticket_dataframe.insert(0, 'ticket', symbol)
        united_dataframe = united_dataframe.append(ticket_dataframe)
    return united_dataframe[['ticket', 'date', 'close']]

def rsi(dataframe, period, column='close'):
    delta = all_prices[column].diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0
    down[down > 0] = 0
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    rsi = 100 - 100 / (1 + rolling_up / rolling_down)

    return rsi

# Get the data
tickets = ['AAPL', 'FB', 'TSLA']
all_prices = get_historical_prices(tickets, '1m')

print(all_prices.groupby('ticket').apply(rsi, 2))
close 0 1 2 3 4 5 6 \
ticket
AAPL NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
FB NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
TSLA NaN 0.0 76.377953 78.208232 44.065484 16.991057 19.694656
close 7 8 9 ... 11 12 \
ticket ...
AAPL 3.521704 1.252711 15.2917 ... 76.523444 48.103572
FB 3.521704 1.252711 15.2917 ... 76.523444 48.103572
TSLA 3.521704 1.252711 15.2917 ... 76.523444 48.103572
close 13 14 15 16 17 18 \
ticket
AAPL 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
FB 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
TSLA 80.777497 46.146025 23.884925 8.34273 16.897125 75.919398
close 19 20
ticket
AAPL 15.705838 12.501725
FB 15.705838 12.501725
TSLA 15.705838 12.501725
[3 rows x 61 columns]
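In the output above, all three tickers get identical values. That follows from the code: rsi still reads all_prices[column] rather than its dataframe argument, so every group computes the same series over all tickers combined. A sketch of an alternative, assuming the question's sample data: write the indicator as a function of one ticker's price Series (rsi_series is a hypothetical name) and use groupby(...)['close'].transform, which returns a Series with the same index and row order as all_prices, so it can be assigned directly as a column.

```python
import pandas as pd

def rsi_series(close, period):
    """Hypothetical helper: RSI of a single ticker's closing prices."""
    delta = close.diff()
    up, down = delta.copy(), delta.copy()
    up[up < 0] = 0      # gains only
    down[down > 0] = 0  # losses only (negative)
    rolling_up = up.ewm(com=period - 1, adjust=False).mean()
    rolling_down = down.ewm(com=period - 1, adjust=False).mean().abs()
    return 100 - 100 / (1 + rolling_up / rolling_down)

# Sample data from the question, with the index repeating per ticker.
all_prices = pd.concat([
    pd.DataFrame({"ticket": "AAA", "close": [176.16, 176.49, 176.00]}),
    pd.DataFrame({"ticket": "BBB", "close": [78.19, 79.90, 78.10]}),
])

# transform calls rsi_series once per ticker (2 is forwarded as `period`)
# and returns a result aligned row-for-row with all_prices.
all_prices["rsi"] = all_prices.groupby("ticket")["close"].transform(rsi_series, 2)
print(all_prices)
```

Returning a Series from a per-group function also avoids mutating the group frame inside apply, which is what led to the alignment problem in the first place.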
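Notice: The technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original address. For any questions, contact: yoyou2525@163.com.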