In NumPy, it is possible to use the `__array_priority__` attribute to take control of binary operators acting on an `ndarray` and a user-defined type. For instance:
```python
import numpy as np

class Foo(object):
    def __radd__(self, lhs):
        return 0
    __array_priority__ = 100

a = np.random.random((100, 100))
b = Foo()
a + b  # calls b.__radd__(a) -> 0
```
The same thing, however, doesn't appear to work for comparison operators. For instance, if I add the following line to `Foo`, then it is never called from the expression `a < b`:

```python
    def __rlt__(self, lhs): return 0
```
I realize that `__rlt__` is not really a Python special name, but I thought it might work. I tried all of `__lt__`, `__le__`, `__eq__`, `__ne__`, `__ge__` and `__gt__`, with and without a preceding `r`, plus `__cmp__`, but I could never get NumPy to call any of them.
Can these comparisons be overridden?
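For reference, Python's data model has no `__rlt__`: the reflection of `x < y` is `y.__gt__(x)`, which the interpreter tries only when `x.__lt__(y)` returns `NotImplemented`. A minimal, NumPy-free sketch of that fallback; `Left` and `Right` are hypothetical stand-ins for `ndarray` and `Foo`:

```python
class Left(object):
    """Stands in for ndarray: gives up on comparisons with unknown types."""
    def __lt__(self, other):
        return NotImplemented

class Right(object):
    """Stands in for Foo: records which comparison method Python calls."""
    def __init__(self):
        self.calls = []

    def __gt__(self, other):
        # Python reflects `left < right` into `right.__gt__(left)`
        self.calls.append("__gt__")
        return "Right.__gt__"

left, right = Left(), Right()
result = left < right
print(result)       # Right.__gt__
print(right.calls)  # ['__gt__']
```

So a type that wants to intercept `a < b` needs `a.__lt__` to give up first, and then implements `__gt__`, not `__rlt__`.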
To avoid confusion, here is a longer description of NumPy's behavior. For starters, here's what the Guide to NumPy book says:
> If the ufunc has 2 inputs and 1 output and the second input is an Object array then a special-case check is performed so that NotImplemented is returned if the second input is not an ndarray, has the array priority attribute, and has an `r<op>` special method.
I think this is the rule that makes + work. Here's an example:
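The quoted rule can be restated as a plain predicate. This is a hypothetical model written only to make the three conditions explicit (`would_defer` is not a NumPy function, and NumPy's real check lives in C):

```python
import numpy as np

def would_defer(other, op="add"):
    """Hypothetical model of the quoted Guide to NumPy rule.

    True when the second operand is not an ndarray, has an
    __array_priority__ attribute, and has the reflected r<op> method.
    """
    return (
        not isinstance(other, np.ndarray)
        and hasattr(other, "__array_priority__")
        and hasattr(other, "__r%s__" % op)
    )

class Bar0(object):
    def __radd__(self, lhs): return 1

class Bar1(object):
    def __radd__(self, lhs): return 1
    __array_priority__ = 100

print(would_defer(Bar0()))        # False: no __array_priority__
print(would_defer(Bar1()))        # True
print(would_defer(np.zeros(2)))   # False: it is an ndarray
```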
```python
import numpy as np

a = np.random.random((2, 2))

class Bar0(object):
    def __add__(self, rhs): return 0
    def __radd__(self, rhs): return 1

b = Bar0()
print(a + b)  # Calls __radd__ four times, returns an array
# [[1 1]
#  [1 1]]

class Bar1(object):
    def __add__(self, rhs): return 0
    def __radd__(self, rhs): return 1
    __array_priority__ = 100

b = Bar1()
print(a + b)  # Calls __radd__ once, returns 1
# 1
```
As you can see, without `__array_priority__`, NumPy interprets the user-defined object as a scalar type and applies the operation at every position in the array. That's not what I want. My type is array-like (but should not be derived from `ndarray`).
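The difference is easy to see by counting calls. A small sketch, assuming Python 3 and a recent NumPy; the `Counting` classes and their `calls` counter are hypothetical, added only for illustration. Without a priority, the object is treated as an object scalar and `__radd__` runs once per element; with it, `ndarray.__add__` returns `NotImplemented` and Python calls `b.__radd__(a)` once with the whole array.

```python
import numpy as np

class Counting(object):
    def __init__(self):
        self.calls = 0

    def __radd__(self, lhs):
        self.calls += 1  # count every invocation
        return 1

class CountingWithPriority(Counting):
    __array_priority__ = 100

a = np.random.random((2, 2))

b0 = Counting()
r0 = a + b0                          # element-wise over an object array
print(b0.calls, r0.dtype, r0.shape)  # 4 object (2, 2)

b1 = CountingWithPriority()
r1 = a + b1                          # ndarray defers to b1.__radd__(a)
print(b1.calls, r1)                  # 1 1
```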
Here's a longer example showing how this fails when all of the comparison methods are defined:
```python
class Foo(object):
    def __cmp__(self, rhs): return 0
    def __lt__(self, rhs): return 1
    def __le__(self, rhs): return 2
    def __eq__(self, rhs): return 3
    def __ne__(self, rhs): return 4
    def __gt__(self, rhs): return 5
    def __ge__(self, rhs): return 6
    __array_priority__ = 100

b = Foo()
print(a < b)  # Calls __cmp__ four times, returns an array
# [[False False]
#  [False False]]
```
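On NumPy 1.13 and later, the documented way to make `ndarray` operators give up in favour of another type is to set `__array_ufunc__ = None` on that type (NEP 13). The operators, comparisons included, then return `NotImplemented`, so Python falls back to the reflected method: `a < b` becomes `b.__gt__(a)` and `a + b` becomes `b.__radd__(a)`. A minimal sketch assuming Python 3 and NumPy >= 1.13; the class name `ArrayLike` is hypothetical and its return values mirror `Foo` above:

```python
import numpy as np

class ArrayLike(object):
    # Opt out of the ufunc machinery: ndarray's operators return
    # NotImplemented for this type, and calling a ufunc such as
    # np.less(a, b) on it raises TypeError instead.
    __array_ufunc__ = None

    def __lt__(self, rhs): return 1
    def __le__(self, rhs): return 2
    def __eq__(self, rhs): return 3
    def __ne__(self, rhs): return 4
    def __gt__(self, rhs): return 5
    def __ge__(self, rhs): return 6
    def __radd__(self, lhs): return 0

a = np.random.random((2, 2))
b = ArrayLike()

print(a < b)   # 5: reflected to b.__gt__(a), called once
print(a >= b)  # 2: reflected to b.__le__(a)
print(a == b)  # 3: reflected to b.__eq__(a)
print(a + b)   # 0: b.__radd__(a)
```

Unlike a global hook, this only changes behaviour for operands of this one type.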
It looks like I can answer this myself. `np.set_numeric_ops` can be used as follows:
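```python
import numpy as np

class Foo(object):
    def __lt__(self, rhs): return 0
    def __le__(self, rhs): return 1
    def __eq__(self, rhs): return 2
    def __ne__(self, rhs): return 3
    def __gt__(self, rhs): return 4
    def __ge__(self, rhs): return 5
    __array_priority__ = 100

def override(name):
    # Wrap the named comparison ufunc so that it gives up when the
    # right-hand operand is a Foo; Python then tries Foo's reflected method.
    def ufunc(x, y):
        if isinstance(y, Foo):
            return NotImplemented
        return getattr(np, name)(x, y)
    return ufunc

# Replace the functions that ndarray's comparison operators dispatch to.
# (np.set_numeric_ops was deprecated in later NumPy releases.)
np.set_numeric_ops(
    **{
        ufunc: override(ufunc)
        for ufunc in ("less", "less_equal", "equal", "not_equal",
                      "greater_equal", "greater")
    }
)

a = np.random.random((2, 2))
b = Foo()
print(a < b)  # a.__lt__(b) returns NotImplemented, so Python calls b.__gt__(a)
# 4
```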
I cannot reproduce your problem. The correct approach is to use the `__cmp__` special method. If I write
```python
import numpy as np

class Foo(object):
    def __radd__(self, lhs):
        return 0
    def __cmp__(self, this):
        return -1
    __array_prioriy__ = 100

a = np.random.random((100, 100))
b = Foo()
print(a < b)
```
and set a breakpoint in the debugger, execution stops at `return -1`.
Btw: `__array_prioriy__` doesn't make any difference here: you have a typo in it!
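A misspelled special name is just an ordinary class attribute, so anything looking up the exact name `__array_priority__` never sees it. A quick check, reusing the class from the answer with its typo (Python 3 syntax; `__cmp__` is omitted because Python 3 ignores it):

```python
class Foo(object):
    def __radd__(self, lhs):
        return 0
    __array_prioriy__ = 100  # typo: "prioriy" is missing a "t"

# The correctly spelled attribute is simply not defined on the class.
print(hasattr(Foo, "__array_priority__"))  # False
print(hasattr(Foo, "__array_prioriy__"))   # True
```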