Subclassing NumPy: NumPy functions return ndarray instead of the subclassed type?

I was able to subclass np.ndarray:

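    import numpy as np

    class myary(np.ndarray):
        def __new__(cls, arg1, arg2):
            # ... (omitted in the original)
            obj = super(myary, cls).__new__(cls, shape=(arg1,), dtype=int)  # np.int alias was removed in NumPy 1.24
            # ... (omitted in the original)
            return obj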

It works, but when I apply NumPy functions such as np.concatenate() or np.stack(), the output is an ndarray instead of a myary. I implemented __array_wrap__, so np.sort, np.add, etc. work, but the functions mentioned above do not.

    def __array_wrap__(self, out_arr, context=None):
        # name the class explicitly; super(self.__class__, self) recurses forever in a subclass of myary
        return super(myary, self).__array_wrap__(out_arr, context)

How can I make all NumPy functions return the same type as the input I pass in?

It looks like implementing __array_function__(self, func, types, args, kwargs) on your class is the way to go as of NumPy 1.16, according to this SciPy reference (in 1.16 the protocol is opt-in; it is enabled by default from 1.17). Note that kwargs arrives as a plain dict, not as **kwargs. That page has a helpful, though sparse, example of how to implement numpy.concatenate and numpy.broadcast_to, which made it easy for me to implement. At its simplest (and even sparser), here's how to handle numpy.concatenate and return an instance of your class:

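    import numpy as np

    class MyClass:
        def __init__(self, data):
            # hypothetical: MyClass simply wraps a plain ndarray
            self.data = np.asarray(data)

        def __array_function__(self, func, types, args, kwargs):
            if func is np.concatenate:
                # unwrap MyClass inputs, concatenate the raw arrays, re-wrap the result
                arrays = [a.data if isinstance(a, MyClass) else a for a in args[0]]
                return MyClass(np.concatenate(arrays, *args[1:], **kwargs))
            else:
                # not handled here: NumPy tries other implementations or raises TypeError
                return NotImplemented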
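The MyClass example is a standalone container rather than an ndarray. For an actual np.ndarray subclass such as the question's myary, one possible approach is to let ndarray's built-in __array_function__ compute the result and then view any plain ndarray result as the subclass. This is a minimal sketch, assuming NumPy 1.17 or later; MyArray is a hypothetical stand-in for myary:

```python
import numpy as np


class MyArray(np.ndarray):
    """Hypothetical minimal ndarray subclass standing in for the question's myary."""

    def __array_function__(self, func, types, args, kwargs):
        # Let ndarray's own implementation run the function; it returns
        # NotImplemented if some argument type is not an ndarray subclass.
        result = super().__array_function__(func, types, args, kwargs)
        # Re-wrap plain ndarray results as MyArray; NotImplemented, scalars,
        # tuples, lists and results that are already MyArray pass through unchanged.
        if isinstance(result, np.ndarray) and not isinstance(result, MyArray):
            result = result.view(MyArray)
        return result


a = np.arange(3).view(MyArray)
joined = np.concatenate([a, a])
stacked = np.stack([a, a])
print(type(joined).__name__, joined.shape)    # MyArray (6,)
print(type(stacked).__name__, stacked.shape)  # MyArray (2, 3)
```

In this sketch, results that are lists or tuples of arrays are returned as-is; if you need their elements re-wrapped as well, you would have to handle those cases explicitly.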

As a side note, when I implemented np.concatenate, I was confused about how NumPy knew to call my method when I passed a list of instances of my class as input rather than a single instance. I found this explanation helpful: __array_function__ is called whenever a NumPy array function receives an argument that defines it, including arguments nested inside a sequence such as the list passed to np.concatenate, and NumPy decides whether to use your implementation as follows (note the importance of returning NotImplemented for any function you don't handle):

  1. NumPy will gather implementations of __array_function__ from all specified inputs and call them in order: subclasses before superclasses, and otherwise left to right. Note that in some edge cases involving subclasses, this differs slightly from the current behavior of Python.
  2. Implementations of __array_function__ indicate that they can handle the operation by returning any value other than NotImplemented.
  3. If all __array_function__ methods return NotImplemented, NumPy will raise TypeError.
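The dispatch rules can be seen with a small, hypothetical class that handles np.concatenate and nothing else. This is a sketch assuming NumPy 1.17 or later, where the protocol is enabled by default; ConcatOnly is an illustrative name, not part of NumPy:

```python
import numpy as np


class ConcatOnly:
    """Hypothetical class that implements np.concatenate and nothing else."""

    def __array_function__(self, func, types, args, kwargs):
        if func is np.concatenate:
            # Any return value other than NotImplemented means "handled" (rule 2)
            return f"ConcatOnly handled {len(args[0])} inputs"
        return NotImplemented


x = ConcatOnly()

# The dispatcher for concatenate looks inside the list argument,
# so it finds x.__array_function__ even though x was not passed directly.
handled = np.concatenate([x, x])
print(handled)  # ConcatOnly handled 2 inputs

# Every implementation returns NotImplemented for np.stack, so NumPy raises TypeError (rule 3)
try:
    np.stack([x, x])
    stack_error = None
except TypeError as exc:
    stack_error = exc
print(type(stack_error).__name__)  # TypeError
```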
