
Can I pass a pretrained pdf function into seaborn.distplot?

I know that seaborn.distplot can plot data as a histogram and superimpose a fitted distribution on top of it, and that its fit parameter is what draws that distribution. In the source code, distplot appears to call fit() on that argument internally to estimate the parameters from the data. Is there a way to fit the model beforehand, with my own settings, and have distplot just use it?

I tried passing a lambda that evaluates my pre-fitted distribution, and I also tried passing the fitting parameters I wanted directly to seaborn.distplot. Both attempts raise errors.

Method 1 - Using a lambda for the pretrained model:

import seaborn as sns
from scipy import stats

params = stats.exponweib.fit(data, floc=0, f0=1)
custom_weib = lambda x: stats.exponweib.pdf(x, *params)
sns.distplot(data, bins=bin_count, fit=custom_weib, norm_hist=True, kde=False, hist_kws={'log':True})

This fails with AttributeError: 'function' object has no attribute 'fit', so distplot apparently can't take a pre-fitted function.

Method 2 - Passing the fit parameters alongside the fit argument (I'm not sure I'm doing this correctly):

import seaborn as sns
from scipy import stats

sns.distplot(data, bins=bin_count, norm_hist=True, kde=False, hist_kws=hist_kws, fit=stats.exponweib, floc=0, f0=1)

This raises TypeError: distplot() got an unexpected keyword argument 'floc'. Clearly I'm not passing the arguments correctly, but I don't know how to get them through to the fit.
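Because distplot only ever calls fit.fit(a) with the data, fixed arguments such as floc and f0 have to be bound to the fit method before the object is handed to distplot. A minimal sketch, assuming distplot only needs an object with fit and pdf attributes (the name weib_fixed and the synthetic data are hypothetical), uses functools.partial and types.SimpleNamespace instead of a full class:

```python
from functools import partial
from types import SimpleNamespace

import numpy as np
from scipy import stats

# Hypothetical helper object: fit() has floc=0 and f0=1 pre-bound,
# pdf() is the unchanged scipy method. seaborn only looks up .fit and .pdf.
weib_fixed = SimpleNamespace(
    fit=partial(stats.exponweib.fit, floc=0, f0=1),
    pdf=stats.exponweib.pdf,
)

# Synthetic data, for illustration only.
data = stats.exponweib.rvs(1, 2, loc=0, scale=1.5, size=2000, random_state=0)

# Same as stats.exponweib.fit(data, floc=0, f0=1): returns (a, c, loc, scale)
# with the fixed values a=1 and loc=0 passed through unchanged.
params = weib_fixed.fit(data)

# The call distplot makes afterwards: pdf(x, *params).
x = np.linspace(0.1, 4.0, 5)
y = weib_fixed.pdf(x, *params)
```

It would then replace stats.exponweib in the Method 2 call as fit=weib_fixed, with floc and f0 removed from the distplot keyword arguments.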

Here's a link to the Seaborn source code if you need it: https://github.com/mwaskom/seaborn/blob/master/seaborn/distributions.py

In principle, it's not possible to pass extra parameters to seaborn's fit, because the source code calls it as params = fit.fit(a), with the data as the only argument.
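To see why any object with a fit() and a pdf() method is accepted, here is a simplified sketch of what the fit branch of distplot does with its fit argument. It is not seaborn's actual code: the grid is a plain linspace over the data range rather than seaborn's KDE-based support, and the function name fit_curve is hypothetical.

```python
import numpy as np
from scipy import stats


def fit_curve(fit, a, gridsize=200):
    """Simplified stand-in for distplot's fit handling (hypothetical helper)."""
    a = np.asarray(a, dtype=float)
    # The two calls made on `fit`; no extra keyword arguments are forwarded.
    params = fit.fit(a)
    x = np.linspace(a.min(), a.max(), gridsize)
    y = fit.pdf(x, *params)
    return x, y


# A scipy.stats distribution works, because it has .fit and .pdf ...
x, y = fit_curve(stats.norm, stats.norm.rvs(size=500, random_state=1))

# ... but a plain function does not, which is the error from Method 1.
try:
    fit_curve(lambda v: stats.norm.pdf(v), [0.0, 1.0, 2.0])
except AttributeError as err:
    print(err)  # 'function' object has no attribute 'fit'
```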

However, it looks like you can trick seaborn by supplying your own object that provides fit() and pdf() methods, and storing the extra arguments on that object so that its fit() forwards them.

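import numpy as np
from scipy.stats import exponweib
import matplotlib.pyplot as plt
import seaborn as sns

fig, ax = plt.subplots(1, 1)

class MyDist:
    """Duck-typed stand-in for a scipy.stats distribution.

    seaborn only calls fit(data) and pdf(x, *params), so the extra
    keyword arguments for fit() can be stored on the object itself.
    """
    def __init__(self, **kw):
        self.dist = exponweib
        self.kw = kw  # e.g. floc=..., fscale=..., forwarded to exponweib.fit

    def fit(self, data):
        return self.dist.fit(data, **self.kw)

    def pdf(self, x, *args, **kw):
        return self.dist.pdf(x, *args, **kw)


# Sample data drawn with known parameters
r = exponweib.rvs(3, 2, loc=0.3, scale=1.3, size=100000)

# Histogram plus the pdf fitted with loc and scale held fixed
sns.distplot(r, fit=MyDist(floc=0.3, fscale=1.3), norm_hist=True, kde=False, ax=ax)

# For comparison: the same fit done by hand, overlaid in red
params = exponweib.fit(r, floc=0.3, fscale=1.3)
x = np.linspace(0.1, 4.1, 100)
ax.plot(x, exponweib.pdf(x, *params),
        'r-', lw=3, alpha=0.6)

plt.show()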

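The same trick covers the original "pre-train once, then reuse" question: if fit() simply returns parameters estimated beforehand, seaborn never refits the data. A minimal sketch, assuming distplot only calls .fit and .pdf on the object as quoted above; the class name PretrainedDist is hypothetical:

```python
import numpy as np
from scipy.stats import exponweib


class PretrainedDist:
    """Hypothetical wrapper: a scipy.stats distribution plus parameters fitted in advance."""

    def __init__(self, dist, params):
        self.dist = dist
        self.params = tuple(params)

    def fit(self, data):
        # Ignore whatever data seaborn passes in; return the stored parameters.
        return self.params

    def pdf(self, x, *args):
        # seaborn calls pdf(x, *params), where params is what fit() returned.
        return self.dist.pdf(x, *args)


# Fit once (here on a synthetic reference sample) with any fixed arguments needed.
r = exponweib.rvs(3, 2, loc=0.3, scale=1.3, size=10000, random_state=0)
params = exponweib.fit(r, floc=0.3, fscale=1.3)
pretrained = PretrainedDist(exponweib, params)

# What distplot would evaluate for the overlay curve.
x = np.linspace(0.4, 4.0, 50)
y = pretrained.pdf(x, *pretrained.fit(r))
```

It would be passed as sns.distplot(data, fit=pretrained, norm_hist=True, kde=False). Note that distplot has been deprecated since seaborn 0.11; with sns.histplot(data, stat="density") the stored params can instead be drawn directly with ax.plot(x, exponweib.pdf(x, *params)), which does not rely on the fit hook at all.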

Technical posts on this site are licensed under CC BY-SA 4.0. If you reprint, please cite the site URL or the original address. For questions, contact: yoyou2525@163.com.

 
© 2020-2024 STACKOOM.COM