How do I get __matmul__ of two different numpy.ndarray subclasses to return a particular subclass?

I have two np.ndarray subclasses. Tuple @ Matrix returns a Tuple, but Matrix @ Tuple returns a Matrix. How might I have it return a Tuple instead?

import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0] = x
    m[1, 1] = y
    m[2, 2] = z
    return m

Example:

>>> Tuple([1,2,3,4]) @ scaling(2,2,2)
Tuple([2., 4., 6., 4.])

>>> scaling(2,2,2) @ Tuple([1,2,3,4])
Matrix([2., 4., 6., 4.])   # XXXX I'd like this to be a Tuple

PS: Matrix @ Matrix should still return a Matrix.

You can overload the __matmul__ method of Matrix so that it returns a Tuple. If you want the result to be a Tuple whenever either operand is a Tuple, and a Matrix otherwise, I think this will work:

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        # A @ B == (B.T @ A.T).T; with the Tuple on the left, the result is a Tuple
        return (other.T @ self.T).T if isinstance(other, Tuple) else np.matmul(self, other)
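The transpose trick works because (A @ B).T == B.T @ A.T, so A @ v can be computed as (v.T @ A.T).T with the Tuple as the left operand, and NumPy then gives the result the Tuple type (as the question shows for Tuple @ Matrix). The scaling matrices are diagonal, which would hide a swapped operand order, so here is a quick check against a non-symmetric matrix. It is a sketch that repeats the question's classes; the 2x2 matrix and vector are made up for illustration.

```python
import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if isinstance(other, Tuple):
            # A @ v == (v.T @ A.T).T; with the Tuple on the left,
            # the ufunc result takes the Tuple type.
            return (other.T @ self.T).T
        return np.matmul(self, other)

m = Matrix([[1., 2.], [3., 4.]])   # non-symmetric, so a swap would show up
v = Tuple([1., 1.])

r = m @ v
print(type(r).__name__, r)                            # Tuple [3. 7.]
print(np.allclose(r, np.asarray(m) @ np.asarray(v)))  # True
print(type(m @ m).__name__, type(v @ m).__name__)     # Matrix Tuple
```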

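I made a mistake in copying from the np.matrix example: __array_priority__ has to be a class attribute (with the trailing underscores), not a local variable inside __new__. Corrected version: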

class Tuple(np.ndarray):
    __array_priority__ = 10
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    __array_priority__ = 5.0
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

In [2]: def scaling(x, y, z):
   ...:     m = Matrix(np.identity(4))
   ...:     m[0, 0] = x
   ...:     m[1, 1] = y
   ...:     m[2, 2] = z
   ...:     return m
   ...:
In [3]: Tuple([1,2,3,4]) @ scaling(2,2,2)
Out[3]: Tuple([2., 4., 6., 4.])
In [4]: scaling(2,2,2) @ Tuple([1,2,3,4])
Out[4]: Tuple([2., 4., 6., 4.])
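Why this works: np.matmul has been a ufunc since NumPy 1.16, and when its inputs are different ndarray subclasses the output takes the type of the input with the highest __array_priority__ (on a tie the left operand wins, which is why Tuple @ Matrix already returned a Tuple). A minimal check with a made-up non-symmetric matrix, repeating the class definitions above. One caveat worth verifying in your own code: the priority applies to every ufunc, so mixing the types in other operations, e.g. Matrix + Tuple, also comes back as a Tuple.

```python
import numpy as np

class Tuple(np.ndarray):
    __array_priority__ = 10.0   # higher priority: wins when mixed with Matrix
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class Matrix(np.ndarray):
    __array_priority__ = 5.0
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

m = Matrix([[1., 2.], [3., 4.]])   # non-symmetric on purpose
v = Tuple([1., 1.])

print(type(m @ v).__name__, m @ v)   # Tuple [3. 7.]  (operand order preserved)
print(type(v @ m).__name__, v @ m)   # Tuple [4. 6.]
print(type(m @ m).__name__)          # Matrix

# The priority is not specific to @: other ufuncs pick the Tuple type too.
print(type(m + Tuple(np.ones((2, 2)))).__name__)   # Tuple
```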

=== ===

Taking a clue from the np.matrix definition (numpy.matrixlib.defmatrix.py):

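Add an __array_priority__ attribute. (In this first attempt it was mistakenly assigned as a local variable inside __new__, so it has no effect and Matrix @ Tuple still returns a Matrix.)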

In [382]: class Tuple(np.ndarray):
     ...:     def __new__(cls, input_array, info=None):
     ...:         __array_priority = 10
     ...:         return np.asarray(input_array).view(cls)
     ...:
     ...: class Matrix(np.ndarray):
     ...:     def __new__(cls, input_array, info=None):
     ...:         __array_priority = 5
     ...:         return np.asarray(input_array).view(cls)
     ...:
In [383]: def scaling(x, y, z):
     ...:     m = Matrix(np.identity(4))
     ...:     m[0, 0] = x
     ...:     m[1, 1] = y
     ...:     m[2, 2] = z
     ...:     return m
     ...:
In [384]: Tuple([1,2,3,4]) @ scaling(2,2,2)
Out[384]: Tuple([2., 4., 6., 4.])
In [385]: scaling(2,2,2) @ Tuple([1,2,3,4])
Out[385]: Matrix([2., 4., 6., 4.])
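The attempt above fails because `__array_priority = 10` inside __new__ only binds a local variable of that method (and without the trailing underscores); nothing is stored on the class, so NumPy still sees the default priority. A small sketch contrasting the two placements; BrokenTuple and FixedTuple are hypothetical names used only for this comparison, and Matrix uses the priority from the corrected version.

```python
import numpy as np

class Matrix(np.ndarray):
    __array_priority__ = 5.0
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

class BrokenTuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        __array_priority = 10   # local variable of __new__: NumPy never sees it
        return np.asarray(input_array).view(cls)

class FixedTuple(np.ndarray):
    __array_priority__ = 10     # class attribute, which NumPy looks up
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

m = Matrix(np.identity(2))
print('__array_priority__' in vars(BrokenTuple))   # False
print('__array_priority__' in vars(FixedTuple))    # True
print(type(m @ BrokenTuple([1., 2.])).__name__)    # Matrix
print(type(m @ FixedTuple([1., 2.])).__name__)     # FixedTuple
```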

One way of solving this is to implement a custom __matmul__ in Matrix and __rmatmul__ in Tuple: Matrix.__matmul__ returns NotImplemented for anything that is not a Matrix, so Python falls back to Tuple.__rmatmul__.

import numpy as np

class Tuple(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

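    def __rmatmul__(self, other):
        # Called for `other @ self` after Matrix.__matmul__ returns NotImplemented:
        # compute other @ self (not self @ other) and view the result as a Tuple.
        return np.matmul(np.asarray(other), np.asarray(self)).view(type(self))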

class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        if not isinstance(other, Matrix):
            return NotImplemented
        return super().__matmul__(other)

def scaling(x, y, z):
    m = Matrix(np.identity(4))
    m[0, 0] = x
    m[1, 1] = y
    m[2, 2] = z
    return m

scaling(2,2,2) @ scaling(2,2,2)
# Matrix([[4., 0., 0., 0.],
#         [0., 4., 0., 0.],
#         [0., 0., 4., 0.],
#         [0., 0., 0., 1.]])
Tuple([1,2,3,4]) @ scaling(2,2,2)
# Tuple([2., 4., 6., 4.])
scaling(2,2,2) @ Tuple([1,2,3,4])
# Tuple([2., 4., 6., 4.])
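The pattern relies on Python's binary-operator protocol: if the left operand's __matmul__ returns NotImplemented, Python calls the right operand's __rmatmul__ and passes the left operand as `other`. The reflected method therefore has to compute `other @ self`; for a non-symmetric M, `M @ v` and `v @ M` differ, even though the diagonal scaling matrix hides this. A small sketch with hypothetical Left/Right classes and a made-up matrix:

```python
import numpy as np

class Left:
    def __matmul__(self, other):
        # Declining with NotImplemented makes Python try other.__rmatmul__(self).
        return NotImplemented

class Right:
    def __rmatmul__(self, other):
        # Invoked as right.__rmatmul__(left): `other` is the LEFT operand.
        return (type(other).__name__, type(self).__name__)

print(Left() @ Right())    # ('Left', 'Right')

# If neither side handles the operation, Python raises TypeError.
try:
    Left() @ Left()
except TypeError as exc:
    print("TypeError:", exc)

# Operand order matters for a non-symmetric matrix.
m = np.array([[1., 2.], [3., 4.]])
v = np.array([1., 1.])
print(m @ v, v @ m)        # [3. 7.] [4. 6.]
```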

Just overload __matmul__ of the Matrix class to return a Tuple instead:

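class Matrix(np.ndarray):
    def __new__(cls, input_array, info=None):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        # np.matmul is a ufunc, so this does not recurse back into __matmul__,
        # and it keeps the operand order (self @ other).
        result = np.matmul(self, other)
        # Return a Tuple when the right operand is a Tuple, otherwise keep the Matrix.
        return result.view(Tuple) if isinstance(other, Tuple) else result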
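Returning `other @ self` from Matrix.__matmul__ has two pitfalls worth checking. It reverses the operand order, which only goes unnoticed because the scaling matrices here are diagonal, and for Matrix @ Matrix it calls Matrix.__matmul__ again with the operands swapped, recursing without end. A minimal sketch with a hypothetical SwappingMatrix class whose __matmul__ has exactly that body:

```python
import numpy as np

class SwappingMatrix(np.ndarray):
    # Hypothetical class whose __matmul__ simply swaps the operands.
    def __new__(cls, input_array):
        return np.asarray(input_array).view(cls)

    def __matmul__(self, other):
        return other @ self

a = SwappingMatrix([[1., 2.], [3., 4.]])   # non-symmetric
v = np.array([1., 1.])

print(a @ v)   # [4. 6.]: this is v @ a; the true a @ v would be [3. 7.]

try:
    a @ a      # a.__matmul__(a) evaluates a @ a again, and so on
except RecursionError:
    print("a @ a: RecursionError")
```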
