[英]How to add error bars to interaction plot (statsmodels)?
我有以下代碼:
import numpy as np
import matplotlib.pyplot as plt
from statsmodels.graphics.factorplots import interaction_plot
a = np.array( [ item for item in [ 'a1', 'a2', 'a3' ] for _ in range(30) ] )
b = np.array( [ item for _ in range(45) for item in [ 'b1', 'b2' ] ] )
np.random.seed(123)
mse = np.ravel( np.column_stack( (np.random.normal(-1, 1, size=45 ), np.random.normal(2, 0.5, size=45 ) )) )
f = interaction_plot( a, b, mse )
這使:
是否有一種簡單的方法可以將誤差線直接添加到每個點?
f.axes.errorbar()?
還是直接使用matplotlib制作圖更好?
好吧,似乎尚未直接支持該功能,因此我決定直接修改源代碼並創建一個新功能。 我將其張貼在這里,也許對某人有用。
def int_plot(x, trace, response, func=np.mean, ax=None, plottype='b',
xlabel=None, ylabel=None, colors=[], markers=[],
linestyles=[], legendloc='best', legendtitle=None,
# - - - My changes !!
errorbars=False, errorbartyp='std',
# - - - -
**kwargs):
data = DataFrame(dict(x=x, trace=trace, response=response))
plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()
# - - - My changes !!
if errorbars:
if errorbartyp == 'std':
yerr = data.groupby(['trace', 'x']).aggregate( lambda xx: np.std(xx,ddof=1) ).reset_index()
elif errorbartyp == 'ci95':
yerr = data.groupby(['trace', 'x']).aggregate( t_ci ).reset_index()
else:
raise ValueError("Type of error bars %s not understood" % errorbartyp)
# - - - - - - -
n_trace = len(plot_data['trace'].unique())
if plottype == 'both' or plottype == 'b':
for i, (values, group) in enumerate(plot_data.groupby(['trace'])):
# trace label
label = str(group['trace'].values[0])
# - - - My changes !!
if errorbars:
ax.errorbar(group['x'], group['response'],
yerr=yerr.loc[ yerr['trace']==values ]['response'].values,
color=colors[i], ecolor='black',
marker=markers[i], label='',
linestyle=linestyles[i], **kwargs)
# - - - - - - - - - -
ax.plot(group['x'], group['response'], color=colors[i],
marker=markers[i], label=label,
linestyle=linestyles[i], **kwargs)
這樣,我可以得到這個情節:
f = int_plot( a, b, mse, errorbars=True, errorbartyp='std' )
注意:該代碼還可以使用函數t_ci()
匯總誤差線。 我定義了這樣的功能:
def t_ci( x, C=0.95 ):
from scipy.stats import t
x = np.array( x )
n = len( x )
tstat = t.ppf( (1-C)/2, n )
return np.std( x, ddof=1 ) * tstat / np.sqrt( n )
同樣,我只是對該功能進行了一些微調,以適應當前的需求。 原始功能可以在這里找到:)
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.