
Find n-dimensional point in numpy array

I am investigating whether storing points in a numpy array helps me search for points, and I have several questions about it.

I have a Point class that represents a 3-dimensional point.

class Point( object ):
  def __init__( self, x, y, z ):
    self.x = x
    self.y = y
    self.z = z

  def __repr__( self ):
    return "<Point (%r, %r, %r)>" % ( self.x, self.y, self.z )

I build a list of Point objects. Notice that the coordinates (1, 2, 3) deliberately occur twice; that is the point I am going to search for.

>>> points = [Point(1, 2, 3), Point(4, 5, 6), Point(1, 2, 3), Point(7, 8, 9)]

I store the Point objects in a numpy array.

>>> import numpy
>>> npoints = numpy.array( points )
>>> npoints
array([<Point (1, 2, 3)>, <Point (4, 5, 6)>, <Point (1, 2, 3)>,
   <Point (7, 8, 9)>], dtype=object)

I search for all points with coordinates (1, 2, 3) in the following manner.

>>> numpy.where( npoints == Point(1, 2, 3) )
(array([], dtype=int64),)

But the result is empty, so this does not seem to be the correct way to do it. Is numpy.where the right tool? Is there another way to express the condition for numpy.where that would succeed?

The next thing I try is to store just the coordinates of the points in a numpy array.

>>> npoints = numpy.array( [(p.x, p.y, p.z) for p in points ])
>>> npoints
array([[1, 2, 3],
      [4, 5, 6],
      [1, 2, 3],
      [7, 8, 9]])

I search for all points with coordinates (1,2,3) in the following manner.

>>> numpy.where( npoints == [1,2,3] )
(array([0, 0, 0, 2, 2, 2]), array([0, 1, 2, 0, 1, 2]))

The result is, at least, something I can deal with. The array of row indexes in the first return value, array([0, 0, 0, 2, 2, 2]), does indeed tell me that the coordinates I am searching for are in rows 0 and 2 of npoints. I could get away with doing something like the following.

>>> rows, cols = numpy.where( npoints == [1,2,3] )
>>> rows
array([0, 0, 0, 2, 2, 2])
>>> cols
array([0, 1, 2, 0, 1, 2])
>>> foundRows = set( rows )
>>> foundRows
set([0, 2])
>>> for r in foundRows:
...   pass  # Do something with npoints[r]

However, I feel that I am not really using numpy.where appropriately, and that I am just getting lucky in this particular situation.
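To make that concern concrete, here is a minimal sketch. The middle row (1, 5, 3) is hypothetical data, not part of the example above; it shares only two coordinates with the target, yet its row index still shows up in the result because the comparison is done element by element.

```python
import numpy

# Hypothetical data: the middle row matches the target in only two coordinates.
npoints = numpy.array([[1, 2, 3],
                       [1, 5, 3],
                       [1, 2, 3]])

# Element-wise comparison: any single matching coordinate contributes its row index.
rows, cols = numpy.where(npoints == [1, 2, 3])
found_rows = sorted(set(rows.tolist()))
print(found_rows)  # [0, 1, 2] -- row 1 is a false positive
```

So collecting the row indexes into a set only works when no other row partially matches the target.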

What is the appropriate way to find all occurrences of an n-dimensional point (i.e., a row with particular values) in a numpy array?

Preserving the order of the array is essential.

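The comparison finds nothing because, by default, == on instances of a user-defined class compares object identity, so a freshly created Point(1, 2, 3) is never equal to the ones stored in the array. You can define the “rich comparison” method __eq__(self, other) in your Point class so that == compares Point objects by their coordinates: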

class Point( object ):
  def __init__( self, x, y, z ):
    self.x = x
    self.y = y
    self.z = z

  def __repr__( self ):
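    return "<Point (%r, %r, %r)>" % ( self.x, self.y, self.z )

  def __eq__( self, other ):
    # Compare by coordinates; let Python handle operands that are not Points.
    if not isinstance( other, Point ):
      return NotImplemented
    return ( self.x, self.y, self.z ) == ( other.x, other.y, other.z )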

import numpy
points = [Point(1, 2, 3), Point(4, 5, 6), Point(1, 2, 3), Point(7, 8, 9)]
npoints = numpy.array( points )
found = numpy.where(npoints == Point(1, 2, 3))
print(found) # => (array([0, 2]),)
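If the coordinates are kept in a plain 2-D array, as in the question's second attempt, one possible alternative is to require that every coordinate in a row matches before the row counts as found. The following is a minimal sketch under that assumption; the names target, mask and matches are only illustrative.

```python
import numpy

points = numpy.array([[1, 2, 3],
                      [4, 5, 6],
                      [1, 2, 3],
                      [7, 8, 9]])
target = numpy.array([1, 2, 3])

# Broadcasting compares target against each row; all(axis=1) keeps only
# the rows in which every coordinate matched.
mask = (points == target).all(axis=1)

# flatnonzero returns the matching row indexes in ascending order,
# so the original order of the array is preserved.
matches = numpy.flatnonzero(mask)
print(matches)  # [0 2]
```

This works for any number of dimensions as long as each point is one row. For floating-point coordinates an exact == may be too strict, and a tolerance-based comparison such as numpy.isclose might be a better fit.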

