
scipy.stats.rv_continuous distribution with gap: problems with support?

I need several continuous distributions with gaps in them to fit some data, and am subclassing scipy.stats.rv_continuous for that purpose. Below is an example of a uniform distribution with a gap in it. The density is flat between s0 and l and between h and s1, and zero inside the gap (l, h).

from scipy.stats import *

class gapF_gen(rv_continuous):
    ''' Class for a flat distribution with a gap in it
    s0, s1: bounds of support
    l, h: gap
    s0 < l < h < s1
    '''
    def _argcheck(self, s0, s1, l, h):  return (s0 < l < h < s1)
        
    def _get_support(self, s0, s1, l, h):   return s0, s1
    
    def _pdf(self, x, s0, s1, l, h):
        if (s0 <= x <= l) or (h <= x <= s1): return 1 / (s1 - h + l - s0)   
        else: return 0

gapF = gapF_gen(name='gapF')

bf = gapF(s0=-2.6, s1=4.77, l=-1.3, h=3.5)
print(bf.pdf(-2.8))  # OK
print(bf.pdf([-23.8, 3.8, 2.6, 6.9, 77.9])) # Not OK

I have defined _pdf to check whether x lies inside the support or in the gap, where the density is zero. This works when a scalar is passed to the automatically generated pdf, but when a list is passed to pdf, the range check fails:

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

On the other hand, if I rename my method so that it overrides pdf directly, then even for scalars I get the error:

TypeError: _parse_args() got an unexpected keyword argument 's0'

Any suggestions on how to solve this?

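The error is simply that you cannot write s0 <= x <= l when x is a numpy array (try it!). Python evaluates the chained comparison as (s0 <= x) and (x <= l), and `and` needs a single truth value, which an array with more than one element does not have. Instead, use the bitwise and: (s0 <= x) & (x <= l). Or, if you prefer something more verbose, use np.logical_and.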
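A minimal sketch of the difference, using the bounds from the question; the array values are made up for illustration:

```python
import numpy as np

s0, l = -2.6, -1.3                   # lower bound and start of the gap, as in the question
x = np.array([-23.8, -2.0, 2.6])     # illustrative inputs, not from the original post

try:
    s0 <= x <= l                     # evaluated as (s0 <= x) and (x <= l)
except ValueError as err:
    print("chained comparison failed:", err)

# Elementwise alternatives: both return a boolean array of the same shape as x.
mask = (s0 <= x) & (x <= l)
same = np.logical_and(s0 <= x, x <= l)
print(mask)                          # only -2.0 lies inside [s0, l]
```

The parentheses around each comparison matter, because & binds more tightly than <= in Python.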

And by the way, you should not override pdf. Subclasses should implement only the underscored methods: _pdf, _cdf, etc.

Building on @ev-bre's hints, the following works:

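import numpy as np
from scipy.stats import rv_continuous

class gapF_gen(rv_continuous):
    ''' Class for a flat distribution with a gap in it
    s0, s1: bounds of support
    l, h: gap
    s0 < l < h < s1
    '''
    def _argcheck(self, s0, s1, l, h):
        # elementwise, so it also works if the shape parameters arrive as arrays
        return (s0 < l) & (l < h) & (h < s1)

    def _get_support(self, s0, s1, l, h):
        return s0, s1

    def _pdf(self, x, s0, s1, l, h):
        # True where x lies in [s0, l] or [h, s1]; np.where applies this elementwise
        inside = ((s0 <= x) & (x <= l)) | ((h <= x) & (x <= s1))
        return np.where(inside, 1 / (s1 - h + l - s0), 0)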

gapF = gapF_gen(name='gapF')

bf = gapF(s0=-2.2, s1=4.77, l=-1.3, h=3.5)
print(bf.pdf(-2.8))
print(bf.pdf([-23.8, 3.8, 2.6, 6.9, 77.9, -2.1]))
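As a quick sanity check, one could compare the density against the expected plateau height and confirm that it integrates to 1. The sketch below is self-contained and only illustrative: the class name GapUniform and the use of scipy.integrate.quad with breakpoints at the gap edges are my own choices, not part of the original answer.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import rv_continuous

class GapUniform(rv_continuous):     # hypothetical name, same logic as gapF_gen
    def _argcheck(self, s0, s1, l, h):
        return (s0 < l) & (l < h) & (h < s1)

    def _get_support(self, s0, s1, l, h):
        return s0, s1

    def _pdf(self, x, s0, s1, l, h):
        inside = ((s0 <= x) & (x <= l)) | ((h <= x) & (x <= s1))
        return np.where(inside, 1 / (s1 - h + l - s0), 0.0)

gap_uniform = GapUniform(name="gap_uniform")
dist = gap_uniform(s0=-2.2, s1=4.77, l=-1.3, h=3.5)

# Height of the two flat pieces: 1 / (total length of the non-gap region).
height = 1 / ((4.77 - 3.5) + (-1.3 - (-2.2)))

density = dist.pdf([-23.8, 3.8, 2.6, 6.9, 77.9, -2.1])
print(density)                       # non-zero only at 3.8 and -2.1

# Integrate over the support, telling quad where the density jumps.
area, _ = quad(dist.pdf, -2.2, 4.77, points=[-1.3, 3.5])
print(area)                          # should be close to 1
```

Passing the gap edges through points helps quad handle the discontinuities; without them the adaptive integrator may warn or lose accuracy.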


 