
Overriding dictionary behaviour in Python 3

I'm a beginner in Python and am trying to use the in operator on a dictionary to search for keys that are numpy arrays holding the two coordinates of a point. So, what I want is: a dictionary whose keys are numpy arrays and whose values are integers. The in operator would then compare keys using some tolerance measure (the numpy.allclose function). I understand that numpy arrays are not hashable, so I would have to override the __getitem__ and __setitem__ methods (based on what I found in "How to properly subclass dict and override __getitem__ & __setitem__"). But how do I make the arrays hashable so I can add them as keys to the dictionary? And how do I override the behaviour of the in operator for this case?

Thanks for the help!

Numpy arrays are not hashable, but tuples are, so you can hash an array by turning it into a tuple first. If you also round it beforehand, you get discrete points and can take advantage of the dictionary's fast lookup. You will run into resolution problems when translating back and forth, though, because rounding is done in base 10 while floats are stored in binary. You can circumvent this by storing the key as scaled integers, at the cost of slowing everything down a bit.
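To make the scaled-integer idea concrete, here is a minimal sketch, assuming a precision of 5 decimal places. The helper name to_key is hypothetical and not part of numpy. Keep in mind that this buckets points rather than applying a true tolerance: two points that straddle a rounding boundary still end up with different keys.

```python
import numpy as np

def to_key(point, precision=5):
    """Scale the coordinates, round to the nearest integer and return a
    hashable tuple of Python ints (hypothetical helper for illustration)."""
    scaled = np.rint(np.asarray(point, dtype=float) * 10**precision)
    return tuple(int(v) for v in scaled)

a = np.array([0.1 + 0.2, 1.0])   # first coordinate is 0.30000000000000004
b = np.array([0.3, 1.0])

key_a = to_key(a)
key_b = to_key(b)

lookup = {key_a: 7}
found = key_b in lookup          # both arrays map to the same integer key
```

The raw tuples tuple(a) and tuple(b) differ in the last bit of the first coordinate, so they would be two distinct keys in a plain dict; the scaled-integer keys are identical.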

In the end you just need to write a class that translates back and forth between arrays and tuples on the fly and you're good to go.
An implementation could look like this:

import numpy as np

class PointDict(dict):

    def __init__(self, precision=5):
        super(PointDict, self).__init__()
        self._prec = 10**precision

    def decode(self, tup):
        """
        Turns a tuple that was used as index back into a numpy array.
        """
        return np.array(tup, dtype=float)/self._prec

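    def encode(self, ndarray):
        """
        Scales and rounds a numpy array and turns it into a tuple of ints so
        that it can be used as index for this dict.
        """
        # Round to the nearest integer; a plain int() would truncate, so a
        # value stored as 0.29999... could end up under a different key.
        return tuple(int(x) for x in np.rint(np.asarray(ndarray) * self._prec))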

    def __getitem__(self, item):
        # Only the key needs translating; the stored value is returned as is.
        return super(PointDict, self).__getitem__(self.encode(item))

    def __setitem__(self, item, value):
        return super(PointDict, self).__setitem__(self.encode(item), value)

    def __contains__(self, item):
        return super(PointDict, self).__contains__(self.encode(item))

    def update(self, other):
        for item, value in other.items():
            self[item] = value

    def items(self):
        for item in self:
            yield (item, self[item])

    def __iter__(self):
        for item in super(PointDict, self).__iter__():
            yield self.decode(item)

When looking up a lot of points, a pure numpy solution with vectorized batch write/lookup might be better. This solution, however, is easy to understand and to implement.
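As a rough sketch of what such a vectorized lookup could look like, assume the points are kept in an (N, 2) array and the values in a parallel array. The names points, values and lookup, as well as the -1 sentinel for "not found", are illustrative choices, not an established API.

```python
import numpy as np

# Stored points (one row per point) and their associated integer values.
points = np.array([[0.0, 0.0], [1.02, 2.17], [2.0, 4.2]])
values = np.array([12, 32, 23])

def lookup(queries, atol=1e-8):
    """For each query point, return the value of the first stored point
    whose coordinates all lie within atol, or -1 if there is none."""
    queries = np.atleast_2d(queries)
    # Broadcast to shape (n_queries, n_points, 2), then require that
    # every coordinate of a pair is close.
    close = np.isclose(queries[:, None, :], points[None, :, :],
                       rtol=0.0, atol=atol).all(axis=2)
    hit = close.any(axis=1)       # does the query match any stored point?
    idx = close.argmax(axis=1)    # index of the first match (0 if none)
    return np.where(hit, values[idx], -1)

result = lookup([[1.02, 2.17], [5.0, 5.0]])
near = lookup([1.0, 2.0], atol=0.5)
```

This compares every query against every stored point, so memory grows with n_queries * n_points; for very large sets a spatial index would likely be a better fit.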

Instead of a numpy array, use a 2-tuple of floats as the key. Tuples are hashable since they are immutable.

Python dictionaries use a hash-table in the background to make key lookup fast.

Writing a closeto function isn't that hard:

def closeto(a, b, limit=0.1):
    x, y = a
    p, q = b
    return (x-p)**2 + (y-q)**2 < limit**2

This can be used to find points that are close to a given point. You do have to iterate over all keys, though, because dictionary key lookup is exact. If you do this iteration in a comprehension, it is usually faster than an explicit for-loop.

Testing (in IPython, with Python 3):

In [1]: %cpaste
Pasting code; enter '--' alone on the line to stop or use Ctrl-D.
:    def closeto(a, b, limit=0.1):
:        x, y = a
:        p, q = b
:        return (x-p)**2 + (y-q)**2 < limit**2
:--

In [2]: d = {(0.0, 0.0): 12, (1.02, 2.17): 32, (2.0, 4.2): 23}

In [3]: {k: v for k, v in d.items() if closeto(k, (1.0, 2.0), limit=0.5)}
Out[3]: {(1.02, 2.17): 32}
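If you only need to know whether some key is close, a generator expression with next() stops at the first match instead of building a whole dictionary. This is a small sketch on top of closeto; the helper name find_close is made up for this example.

```python
def closeto(a, b, limit=0.1):
    x, y = a
    p, q = b
    return (x - p)**2 + (y - q)**2 < limit**2

def find_close(d, point, limit=0.1):
    """Return the first key of d within `limit` of point, or None."""
    return next((k for k in d if closeto(k, point, limit)), None)

d = {(0.0, 0.0): 12, (1.02, 2.17): 32, (2.0, 4.2): 23}

key = find_close(d, (1.0, 2.0), limit=0.5)   # a key close to (1.0, 2.0)
missing = find_close(d, (9.0, 9.0))          # nothing nearby
```

The check `find_close(d, point) is not None` then plays the role of a tolerant in test.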

Convert the arrays to tuples, which are hashable:

In [18]: a1 = np.array([0.5, 0.5])

In [19]: a2 = np.array([1.0, 1.5])

In [20]: d = {}

In [21]: d[tuple(a1)] = 14

In [22]: d[tuple(a2)] = 15

In [23]: d
Out[23]: {(0.5, 0.5): 14, (1.0, 1.5): 15}

In [24]: a3 = np.array([0.5, 0.5])

In [25]: a3 in d
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-25-07c81d61b999> in <module>()
----> 1 a3 in d

TypeError: unhashable type: 'numpy.ndarray'

In [26]: tuple(a3) in d
Out[26]: True

Unfortunately, since you want to apply a tolerance to the comparison, you don't have much option but to iterate over all the keys looking for a "close" match, whether you implement this as a function or in-line.
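If you still want the in operator itself to do this, one possible sketch is a dict subclass whose __contains__ performs that linear scan with numpy.allclose. The class name TolerantDict and its atol parameter are assumptions made for illustration; every lookup here is O(n), so the hash table's speed advantage is lost.

```python
import numpy as np

class TolerantDict(dict):
    """dict keyed by coordinate tuples; lookups scan all keys and
    compare them with np.allclose (illustrative sketch only)."""

    def __init__(self, *args, atol=1e-6, **kwargs):
        super().__init__(*args, **kwargs)
        self.atol = atol

    def _match(self, point):
        # Linear scan: return the first stored key within atol, else None.
        for key in dict.keys(self):
            if np.allclose(key, point, rtol=0.0, atol=self.atol):
                return key
        return None

    def __contains__(self, point):
        return self._match(point) is not None

    def __getitem__(self, point):
        key = self._match(point)
        if key is None:
            raise KeyError(point)
        return super().__getitem__(key)

    def __setitem__(self, point, value):
        key = self._match(point)
        if key is None:
            # Store a hashable tuple of plain floats as the new key.
            key = tuple(float(c) for c in point)
        # Reusing a close existing key makes near-duplicates overwrite it.
        super().__setitem__(key, value)

d = TolerantDict(atol=1e-3)
d[np.array([0.5, 0.5])] = 14
d[np.array([1.0, 1.5])] = 15

close_hit = np.array([0.5004, 0.5]) in d
far_miss = np.array([0.6, 0.5]) in d
value = d[np.array([1.0, 1.5001])]
```

Because __contains__ never hashes its argument, an array can appear on the left of in without raising the TypeError shown above.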
