NumPy: creating a symmetric matrix class

Based on this answer, I was coding a simple class for symmetric matrices in Python with NumPy, but I'm having a (probably very simple) problem. This is the problematic code:

import numpy as np

class SyMatrix(np.ndarray):
    def __init__(self, arr):
        # try to replace the new object by a symmetrized copy of arr
        self = (arr + arr.T)/2.0 - np.diag(np.diag(arr))
    def __setitem__(self, key, val):
        # mirror every write so that m[i, j] == m[j, i]
        i, j = key
        np.ndarray.__setitem__(self, (i, j), val)
        np.ndarray.__setitem__(self, (j, i), val)

Besides this feeling wrong (I don't know whether assigning to self is good practice), when I try to create a new array I get this:

>>> foo = SyMatrix( np.zeros(shape = (2,2)))
Traceback (most recent call last):
   File "<stdin>", line 1, in <module>
TypeError: only length-1 arrays can be converted to Python scalars

I also tried:

import numpy as np

class SyMatrix(np.ndarray):
    def __init__(self, n):
        # try to replace the new object by an n x n zero matrix
        self = np.zeros(shape=(n, n)).view(SyMatrix)
    def __setitem__(self, key, val):
        i, j = key
        np.ndarray.__setitem__(self, (i, j), val)
        np.ndarray.__setitem__(self, (j, i), val)

And then I get:

>>> foo = SyMatrix(2)
>>> foo
SyMatrix([  6.93581448e-310,   2.09933710e-316])
>>> 

where I expected an array with shape=(2,2). What's the correct way to do what I'm trying to do? Is assigning to self problematic?

There are a few issues here.

  1. When subclassing numpy.ndarray, you should override __new__(), not __init__(). Your line

    foo = SyMatrix(2)

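    actually calls numpy.ndarray.__new__() with the argument 2, which it interprets as the shape, so you get an uninitialized one-dimensional array of length 2; your __init__() only runs afterwards and cannot replace it. Likewise, SyMatrix(np.zeros(shape = (2,2))) passes an array where ndarray.__new__() expects a shape, which is incompatible with its signature and raises the TypeError.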

  2. Assigning to self does absolutely nothing useful here. The right-hand side creates a new object, and the assignment merely makes the local name self refer to it. As soon as the function returns, all local names are dropped, and the caller still receives the original object. Assignment in Python neither creates variables nor alters objects; it just binds a name to an existing object.

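  3. Even after fixing the first two issues, your symmetric matrix class won't work as expected. There are literally dozens of methods you would need to override to ensure that the matrix always stays symmetric; for example, in-place operators such as += and ufuncs called with out= write to the array's memory without going through __setitem__().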

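  4. (arr + arr.T)/2.0 - np.diag(np.diag(arr)) most probably isn't what you want: it always has zeros on the diagonal, because the diagonal of (arr + arr.T)/2.0 is already the diagonal of arr, and subtracting np.diag(np.diag(arr)) cancels it. You probably want just (arr + arr.T)/2.0.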
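Items 1 and 2 can be checked directly. The following sketch (Python 3) calls ndarray.__new__() the way the two SyMatrix(...) calls in the question do, then re-creates the second attempt under the hypothetical name Broken: its __init__() rebinds self, yet the object the caller receives is still the 1-D array that ndarray.__new__() built. The exact TypeError message differs between NumPy versions.

```python
import numpy as np

# SyMatrix(...) first calls ndarray.__new__(SyMatrix, ...) with the same arguments,
# and ndarray's first positional parameter is the shape of the new array.
a = np.ndarray(2)  # shape (2,), contents uninitialized (hence the odd values)
print(a.shape)  # (2,)

error = None
try:
    np.ndarray(np.zeros((2, 2)))  # an array of floats is not a valid shape
except TypeError as exc:
    error = exc
    print("TypeError:", exc)  # exact message depends on the NumPy version


class Broken(np.ndarray):
    """Python 3 version of the second attempt (hypothetical class name)."""

    def __init__(self, n):
        # Rebinds the local name `self` only; the caller still receives the
        # length-n 1-D array that ndarray.__new__ has already created.
        self = np.zeros((n, n)).view(Broken)


foo = Broken(2)
print(type(foo).__name__, foo.shape)  # Broken (2,)
```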

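A minimal sketch of the __new__()-based approach, assuming that only element writes of the form m[i, j] = x need to stay symmetric. It symmetrizes the input with (arr + arr.T)/2.0, so the diagonal is kept, and returns the result viewed as the subclass. Since it stores no extra attributes, it does not define __array_finalize__(), the hook NumPy's subclassing guide uses to initialize attributes on views and slices. As item 3 warns, this is not a complete symmetric-matrix type: in-place operators, ufuncs with out=, and row or block assignments can still break symmetry.

```python
import numpy as np


class SyMatrix(np.ndarray):
    """Sketch of a symmetric-matrix ndarray subclass (not a complete type)."""

    def __new__(cls, arr):
        arr = np.asarray(arr, dtype=float)
        if arr.ndim != 2 or arr.shape[0] != arr.shape[1]:
            raise ValueError("SyMatrix needs a square 2-D array")
        # Average arr with its transpose (the diagonal is unchanged) and
        # return the result viewed as an instance of this subclass.
        return ((arr + arr.T) / 2.0).view(cls)

    def __setitem__(self, key, value):
        # Write the requested entries as usual.
        super().__setitem__(key, value)
        # For a 2-tuple key such as (i, j), write the same value at (j, i).
        # This keeps single-element writes symmetric; other keys (e.g. a whole
        # row, m[i] = ...) are not mirrored and can break symmetry.
        if isinstance(key, tuple) and len(key) == 2:
            i, j = key
            super().__setitem__((j, i), value)


m = SyMatrix([[1.0, 2.0], [4.0, 3.0]])  # off-diagonal entries become 3.0
m[0, 1] = 7.0
print(m[1, 0])  # 7.0
foo = SyMatrix(np.zeros((2, 2)))  # the call from the question now yields a 2x2 matrix
```

Overriding __new__() is what lets the constructor decide the shape and contents of the array; __setitem__() then only has to keep individual writes consistent.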