I need several continuous distributions with gaps in them to fit some data, and am subclassing scipy.stats.rv_continuous for that purpose. Below is an example of a uniform distribution with a gap in it. The distribution is flat between s0 and l, and between h and s1.
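In formula form, the density described above is the following. The normalizing constant is the reciprocal of the combined length of the two flat pieces, which is the 1 / (s1 - h + l - s0) used in the code:

```latex
f(x; s_0, s_1, l, h) =
\begin{cases}
\dfrac{1}{(l - s_0) + (s_1 - h)}, & s_0 \le x \le l \ \text{or}\ h \le x \le s_1,\\[4pt]
0, & \text{otherwise,}
\end{cases}
\qquad s_0 < l < h < s_1 .
```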
from scipy.stats import *

class gapF_gen(rv_continuous):
    ''' Class for a flat distribution with a gap in it
        s0, s1: bounds of support
        l, h: gap
        s0 < l < h < s1
    '''
    def _argcheck(self, s0, s1, l, h): return (s0 < l < h < s1)
    def _get_support(self, s0, s1, l, h): return s0, s1
    def _pdf(self, x, s0, s1, l, h):
        if (s0 <= x <= l) or (h <= x <= s1): return 1 / (s1 - h + l - s0)
        else: return 0

gapF = gapF_gen(name='gapF')
bf = gapF(s0=-2.6, s1=4.77, l=-1.3, h=3.5)
print(bf.pdf(-2.8))                           # OK
print(bf.pdf([-23.8, 3.8, 2.6, 6.9, 77.9]))   # Not OK
I have defined _pdf to check whether x lies in one of the two flat pieces, where the density is nonzero. This works when passing scalar values to the automatically generated pdf, but when lists are passed to pdf, the range check fails:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
On the other hand, if I rename my function to pdf so that it overrides pdf directly, then even for scalars I get the error:
TypeError: _parse_args() got an unexpected keyword argument 's0'
Are there any suggestions on how to solve this?
The error is simply that you cannot write s0 <= x <= l when x is a numpy array (try it!). Instead, use the bitwise and operator: (s0 <= x) & (x <= l). Or, if you prefer something more verbose, use np.logical_and.
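A quick way to see the difference, using plain NumPy (the sample values are arbitrary, chosen only for illustration):

```python
import numpy as np

x = np.array([-2.0, 0.0, 3.8])
s0, l = -2.6, -1.3

# s0 <= x <= l is evaluated as (s0 <= x) and (x <= l); the implicit `and`
# calls bool() on a multi-element boolean array, which raises ValueError.
chained_failed = False
try:
    s0 <= x <= l
except ValueError as err:
    chained_failed = True
    print("chained comparison failed:", err)

# Elementwise alternatives. The parentheses are required because &
# binds more tightly than <=.
mask_amp = (s0 <= x) & (x <= l)
mask_fn = np.logical_and(s0 <= x, x <= l)
print(mask_amp, np.array_equal(mask_amp, mask_fn))
```

Both masks are ordinary boolean arrays, so they can be combined further with | (or np.logical_or) to cover the second flat piece.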
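And BTW, you should not override pdf. Subclasses only implement the underscored methods: _pdf, _cdf, etc. rv_continuous infers the shape parameter names from the signatures of _pdf and _cdf, so once _pdf is renamed to pdf it no longer knows about s0, s1, l and h, which is why _parse_args rejects s0.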
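Along the same lines, when only _pdf is supplied, rv_continuous falls back to numerical integration for cdf and numerical root finding for ppf (which rvs then uses). For a piecewise-flat density both have simple closed forms, so a sketch that also overrides _cdf and _ppf might look like this. It is an optional extension, assuming the same parameterization as above:

```python
import numpy as np
from scipy.stats import rv_continuous


class gapF_gen(rv_continuous):
    """Flat distribution on [s0, l] U [h, s1], with s0 < l < h < s1."""

    def _argcheck(self, s0, s1, l, h):
        return (s0 < l) & (l < h) & (h < s1)

    def _get_support(self, s0, s1, l, h):
        return s0, s1

    def _pdf(self, x, s0, s1, l, h):
        inside = ((s0 <= x) & (x <= l)) | ((h <= x) & (x <= s1))
        return np.where(inside, 1.0 / (s1 - h + l - s0), 0.0)

    def _cdf(self, x, s0, s1, l, h):
        # Mass from the lower piece plus mass from the upper piece;
        # np.clip keeps each term constant outside its own interval.
        width = (l - s0) + (s1 - h)
        return (np.clip(x, s0, l) - s0 + np.clip(x, h, s1) - h) / width

    def _ppf(self, q, s0, s1, l, h):
        # Invert the CDF piecewise: quantiles up to (l - s0) / width land in
        # [s0, l]; the rest are shifted across the gap into [h, s1].
        width = (l - s0) + (s1 - h)
        t = q * width
        return np.where(t <= l - s0, s0 + t, h + (t - (l - s0)))


gapF = gapF_gen(name='gapF')
bf = gapF(s0=-2.2, s1=4.77, l=-1.3, h=3.5)
print(bf.cdf([-1.3, 0.0, 3.5, 4.77]))   # flat across the gap, 1 at s1
samples = bf.rvs(size=1000, random_state=0)
print(((samples > -1.3) & (samples < 3.5)).sum())   # samples inside the gap
```

The _cdf and _ppf signatures use the same shape names as _pdf, which keeps the automatic shape inference consistent. rv_continuous also accepts an explicit shapes='s0, s1, l, h' string in its constructor if the names ever need to be given by hand.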
Building on @ev-bre's hints, the following works:
import numpy as np
from scipy.stats import rv_continuous

class gapF_gen(rv_continuous):
    ''' Class for a flat distribution with a gap in it
        s0, s1: bounds of support
        l, h: gap
        s0 < l < h < s1
    '''
    def _argcheck(self, s0, s1, l, h):
        # elementwise &, so this also works when the shapes are arrays
        return (s0 < l) & (l < h) & (h < s1)

    def _get_support(self, s0, s1, l, h):
        return s0, s1

    def _pdf(self, x, s0, s1, l, h):
        # x arrives as a numpy array: build the mask with & and |, then np.where
        inside = ((s0 <= x) & (x <= l)) | ((h <= x) & (x <= s1))
        return np.where(inside, 1 / (s1 - h + l - s0), 0)

gapF = gapF_gen(name='gapF')
bf = gapF(s0=-2.2, s1=4.77, l=-1.3, h=3.5)
print(bf.pdf(-2.8))
print(bf.pdf([-23.8, 3.8, 2.6, 6.9, 77.9, -2.1]))