I was able to subclass numpy.ndarray:
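import numpy as np

class myary(np.ndarray):
    def __new__(cls, arg1, arg2):
        # ... (rest of the setup elided in the original)
        # np.int was removed in NumPy 1.24; the builtin int is the equivalent
        obj = super().__new__(cls, shape=(arg1,), dtype=int)
        # ...
        return obj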
It works, but when I apply NumPy functions such as np.concatenate() or np.stack(), the output is an ndarray instead of a myary. I implemented __array_wrap__, so np.sort, np.add, etc. work, but the functions I mentioned above do not.
    def __array_wrap__(self, out_arr, context=None):  # method of myary
        # zero-argument super() stays correct if myary is subclassed further
        return super().__array_wrap__(out_arr, context)
How can I force all NumPy functions to return the same type as their input?
It looks like adding an __array_function__(self, func, types, args, kwargs) method to your class is the way to go as of NumPy 1.16, according to this scipy reference. (In 1.16 the protocol was opt-in via the NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1 environment variable; it is enabled by default from 1.17.) That description has a helpful, though sparse, example of how to implement numpy.concatenate and numpy.broadcast_to, which made it easy for me to implement. At its simplest (and even sparser), here's how to handle numpy.concatenate and return an instance of your class:
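import numpy as np

class MyClass:
    def __init__(self, data):
        # hypothetical: MyClass simply wraps a plain ndarray
        self.data = np.asarray(data)

    def __array_function__(self, func, types, args, kwargs):
        if func is np.concatenate:
            # unwrap MyClass inputs, concatenate the raw arrays, rewrap the result
            arrays = [a.data if isinstance(a, MyClass) else a for a in args[0]]
            return MyClass(np.concatenate(arrays, *args[1:], **kwargs))
        else:
            # not handled here: let NumPy try other implementations
            return NotImplemented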
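The MyClass example is a plain container class. For an ndarray subclass like myary from the question, one possible shortcut (a minimal sketch, not taken from the reference) is to delegate to ndarray's own __array_function__, which runs NumPy's default implementation, and then view any plain ndarray result as the subclass. arg2 is unused here because the original setup code was elided.

```python
import numpy as np

class myary(np.ndarray):
    def __new__(cls, arg1, arg2):
        # arg2 is unused in this sketch; the original setup code was elided
        return super().__new__(cls, shape=(arg1,), dtype=int)

    def __array_function__(self, func, types, args, kwargs):
        # Let ndarray's default implementation do the actual work ...
        result = super().__array_function__(func, types, args, kwargs)
        # ... then re-view any plain ndarray result as myary
        if isinstance(result, np.ndarray) and not isinstance(result, myary):
            result = result.view(myary)
        return result

a = myary(3, None)  # uninitialized values, so fill them in
a[:] = [1, 2, 3]
print(type(np.concatenate([a, a])).__name__)  # myary
```

Results that are not arrays (tuples, scalars) pass through unchanged, and view() shares memory with the computed result, so no extra copy is made.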
As a side note, when I implemented np.concatenate, I was confused about how NumPy knew to call my method when I passed a list of instances of my class as input (instead of an instance directly). The answer is that each NumPy function has a dispatcher that tells NumPy which arguments are relevant, and for np.concatenate that includes every element of the input sequence. I found this helpful: the method is called whenever a NumPy array function is called, and NumPy decides whether to use your implementation as follows (note the importance of returning NotImplemented if you're not handling a given function):
- NumPy will gather implementations of __array_function__ from all specified inputs and call them in order: subclasses before superclasses, and otherwise left to right. Note that in some edge cases involving subclasses, this differs slightly from the current behavior of Python.
- Implementations of __array_function__ indicate that they can handle the operation by returning any value other than NotImplemented.
- If all __array_function__ methods return NotImplemented, NumPy will raise TypeError.
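To see these rules in action, here is a small sketch with hypothetical classes Base, Sub, and Other whose __array_function__ methods only record that they were called and then decline by returning NotImplemented. Passing instances of all three inside the list given to np.concatenate shows that NumPy finds them inside the list, tries Sub before its superclass Base, tries the rest left to right, and raises TypeError once every implementation has declined.

```python
import numpy as np

calls = []  # records which __array_function__ implementations NumPy tried

class Base:
    def __array_function__(self, func, types, args, kwargs):
        calls.append(type(self).__name__)
        return NotImplemented  # decline every function

class Sub(Base):
    pass  # inherits Base.__array_function__

class Other:
    def __array_function__(self, func, types, args, kwargs):
        calls.append(type(self).__name__)
        return NotImplemented

raised = False
try:
    # the dispatcher for concatenate looks inside the list of inputs
    np.concatenate([Base(), Other(), Sub()])
except TypeError:
    raised = True  # every implementation returned NotImplemented

print(raised, calls)  # True ['Sub', 'Base', 'Other']
```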