
np.ndarray with Periodic Boundary Conditions

Problem

The goal is to impose periodic boundary conditions on an np.ndarray, as laid out below.

Details

  • Wrap the indexing of a Python np.ndarray around its boundaries in n dimensions.
  • This is a periodic boundary condition, forming an n-dimensional torus.
  • Wrapping only occurs when the value returned is a scalar (a single point).
  • Slices are treated as normal and are not wrapped.
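The wrapping rule above maps an out-of-range integer index i on an axis of length s to i mod s. As a minimal illustration using only standard NumPy (and not the approach taken in the answer below), np.ravel_multi_index with mode='wrap' applies exactly this mapping when converting a full n-dimensional index to a flat position. The helper name wrapped_scalar is hypothetical.

```python
import numpy as np

a = np.arange(27).reshape(3, 3, 3)

def wrapped_scalar(arr, index):
    """Hypothetical helper: return arr[index] with every axis treated as periodic.

    index must contain exactly one integer per axis (a scalar lookup).
    """
    # mode='wrap' reduces each coordinate modulo the matching dimension,
    # including negative coordinates, before computing the flat position
    flat = np.ravel_multi_index(index, arr.shape, mode='wrap')
    return arr.flat[flat]

print(wrapped_scalar(a, (3, 3, 3)))   # 0, the same element as a[0, 0, 0]
print(wrapped_scalar(a, (-1, 4, 2)))  # 23, the same element as a[2, 1, 2]
```

This only covers full scalar lookups; deciding when an index is such a lookup is the part that the subclass approach has to handle.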

An example and a counterexample are given below:

import numpy as np

a = np.arange(27).reshape(3,3,3)
b = Periodic_Lattice(a) # the class to be implemented

# example: returning a scalar that shouldn't be accessible
print(b[3,3,3] == b[0,0,0]) # returns a scalar so invokes the wrapping condition
try: a[3,3,3] # the value is out of bounds in the original np.ndarray
except IndexError: print('error')

# counterexample: returning a slice
try: b[3,3] # this returns a slice and so shouldn't invoke the wrap
except IndexError: print('error')

which should give the output:

True
error
error

I anticipate that I should be overriding __getitem__ and __setitem__ in a subclass of np.ndarray, but how to proceed is not entirely clear, and many of the implementations on SO fail for a number of test cases.
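Whether a lookup returns a scalar depends on what Python actually passes to __getitem__. A quick way to check is a throw-away subclass, here the hypothetical IndexSpy, that only records its argument before delegating to normal ndarray indexing:

```python
import numpy as np

class IndexSpy(np.ndarray):
    """Hypothetical helper: records the raw index passed to __getitem__."""
    last_index = None

    def __getitem__(self, index):
        IndexSpy.last_index = index  # store exactly what Python handed over
        return super().__getitem__(index)

spy = np.zeros((3, 3, 3)).view(IndexSpy)

spy[1, 2, 0]
print(IndexSpy.last_index)  # (1, 2, 0): one int per axis, so a scalar
spy[1, 2]
print(IndexSpy.last_index)  # (1, 2): fewer ints than axes, so a 1-D slice
spy[1]
print(IndexSpy.last_index)  # 1: a bare int, not a tuple
spy[0:2, 1]
print(IndexSpy.last_index)  # (slice(0, 2, None), 1)
```

So a scalar lookup arrives as a tuple with exactly one integer per axis, which is the condition a wrapping helper needs to test for.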

Wrap function

A simple function can be written with the modulo operator (% in basic Python) and generalised to operate on an n-dimensional tuple index, given a specific lattice shape.

import numpy as np

def latticeWrapIdx(index, lattice_shape):
    """returns periodic lattice index
    for a given tuple index

    Required Inputs:
        index :: tuple :: one integer for each axis
        lattice_shape :: the shape of the lattice to index to
    """
    if not isinstance(index, tuple) or len(index) != len(lattice_shape):
        return index  # not one index per axis (e.g. a bare int or slice): not a scalar
    if not all(isinstance(i, (int, np.integer)) for i in index):
        return index  # slices, Ellipsis, None or arrays are not wrapped
    # periodic indexing of scalars; for s > 0, Python's % always gives 0 <= i % s < s
    return tuple(i % s for i, s in zip(index, lattice_shape))
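The modulo step works for negative and very large indices because Python's % takes the sign of the divisor: for a positive axis length s, i % s already lies in [0, s). A C-style guard such as (i%s + s)%s is therefore harmless but unnecessary. A minimal check, using a hypothetical stripped-down helper wrap_index:

```python
import numpy as np

def wrap_index(index, shape):
    """Hypothetical helper: wrap one integer per axis onto a periodic lattice."""
    return tuple(i % s for i, s in zip(index, shape))

print(-1 % 4, 5 % 4, 10 % 4)        # 3 1 2
print(wrap_index((4, -1), (4, 4)))  # (0, 3)
# NumPy integers follow the same sign convention as Python ints
print(np.int64(-5) % 4)             # 3
```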

This is tested as:

import numpy as np

arr = np.array([[ 11.,  12.,  13.,  14.],
                [ 21.,  22.,  23.,  24.],
                [ 31.,  32.,  33.,  34.],
                [ 41.,  42.,  43.,  44.]])
test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
             [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]

passed = all(arr[latticeWrapIdx(idx, (4,4))] == act for idx, act in test_vals)
print("Iterating test values. Result: {}".format(passed))

and gives the output:

Iterating test values. Result: True

Subclassing NumPy

The wrapping function can be incorporated into a subclass of np.ndarray, following the NumPy documentation on subclassing:

import numpy as np

class Periodic_Lattice(np.ndarray):
    """Creates an n-dimensional ring that joins on boundaries w/ numpy

    Required Inputs
        array :: np.array :: n-dim numpy array to wrap
    Optional Inputs
        lattice_spacing :: stored on the instance; not used for indexing

    Only currently supports single point selections wrapped around the boundary
    """
    def __new__(cls, input_array, lattice_spacing=None):
        """__new__ is called by numpy when an explicit constructor is used:
        obj = MySubClass(params); otherwise we must rely on __array_finalize__
        """
        # Input array is an already formed ndarray instance (or array-like)
        # We first cast to be our class type
        obj = np.asarray(input_array).view(cls)

        # add the new attributes to the created instance
        # (obj.shape rather than input_array.shape, so lists also work)
        obj.lattice_shape = obj.shape
        obj.lattice_dim = obj.ndim
        obj.lattice_spacing = lattice_spacing

        # Finally, we must return the newly created object:
        return obj

    def __getitem__(self, index):
        index = self.latticeWrapIdx(index)
        return super().__getitem__(index)

    def __setitem__(self, index, item):
        index = self.latticeWrapIdx(index)
        super().__setitem__(index, item)

    def __array_finalize__(self, obj):
        """ndarray.__new__ passes __array_finalize__ the new object
        of our own class (self) as well as the object from which the view has been taken (obj).
        See
        http://docs.scipy.org/doc/numpy/user/basics.subclassing.html#simple-example-adding-an-extra-attribute-to-ndarray
        for more info
        """
        # ``self`` is a new object resulting from
        # ndarray.__new__(Periodic_Lattice, ...), therefore it only has
        # attributes that the ndarray.__new__ constructor gave it -
        # i.e. those of a standard ndarray.
        #
        # We could have got to the ndarray.__new__ call in 3 ways:
        #   1. From an explicit constructor - e.g. Periodic_Lattice():
        #       obj is None
        #       (we're in the middle of the Periodic_Lattice.__new__
        #       constructor, and the lattice attributes will be set when
        #       we return to Periodic_Lattice.__new__)
        if obj is None: return
        #   2. From view casting - e.g arr.view(Periodic_Lattice):
        #       obj is arr
        #       (type(obj) can be Periodic_Lattice)
        #   3. From new-from-template - e.g lattice[:3]
        #       type(obj) is Periodic_Lattice
        #
        # Note that it is here, rather than in the __new__ method,
        # that we set the default values for the lattice attributes, because
        # this method sees all creation of default objects - with the
        # Periodic_Lattice.__new__ constructor, but also with
        # arr.view(Periodic_Lattice).
        #
        # The shape is taken from the new object itself, so that a slice
        # such as lattice[:2] wraps around its own boundaries rather than
        # those of its parent; the spacing is inherited from obj.
        self.lattice_shape = self.shape
        self.lattice_dim = self.ndim
        self.lattice_spacing = getattr(obj, 'lattice_spacing', None)

    def latticeWrapIdx(self, index):
        """returns periodic lattice index
        for a given tuple index

        Required Inputs:
            index :: tuple :: one integer for each axis

        This is NOT compatible with slicing: any other index is returned
        unchanged and handled by normal numpy indexing
        """
        if not isinstance(index, tuple) or len(index) != len(self.lattice_shape):
            return index  # not one index per axis (e.g. a bare int or slice): not a scalar
        if not all(isinstance(i, (int, np.integer)) for i in index):
            return index  # slices, Ellipsis, None or arrays are not wrapped
        # periodic indexing of scalars; for s > 0, Python's % always gives 0 <= i % s < s
        return tuple(i % s for i, s in zip(index, self.lattice_shape))
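The subtle part of the subclass is that new instances are created not only by the constructor but also by view casting and slicing, and only __array_finalize__ sees all three routes. A minimal sketch with a hypothetical subclass Tagged, carrying a single extra attribute, shows how an attribute propagates (or falls back to a default):

```python
import numpy as np

class Tagged(np.ndarray):
    """Hypothetical minimal subclass with one extra attribute, `tag`."""

    def __new__(cls, input_array, tag=None):
        obj = np.asarray(input_array).view(cls)  # view casting calls __array_finalize__
        obj.tag = tag
        return obj

    def __array_finalize__(self, obj):
        # obj is None only for an explicit ndarray.__new__ call; otherwise it is
        # the array we were viewed or sliced from, so copy its tag (or a default)
        if obj is None:
            return
        self.tag = getattr(obj, 'tag', None)

t = Tagged(np.arange(6).reshape(2, 3), tag='lattice')
print(type(t[0]).__name__, t[0].tag)   # Tagged lattice  (a slice keeps the tag)
print(np.arange(3).view(Tagged).tag)   # None  (plain view casting: default)
```

Periodic_Lattice follows the same pattern for lattice_spacing, while its shape attributes are recomputed from each new object.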

Testing demonstrates that the lattice overrides indexing correctly:

import numpy as np

arr = np.array([[ 11.,  12.,  13.,  14.],
                [ 21.,  22.,  23.,  24.],
                [ 31.,  32.,  33.,  34.],
                [ 41.,  42.,  43.,  44.]])
test_vals = [[(1,1), 22.], [(3,3), 44.], [( 4, 4), 11.], # [index, expected value]
             [(3,4), 41.], [(4,3), 14.], [(10,10), 33.]]

periodic_arr  = Periodic_Lattice(arr)
passed = bool((periodic_arr == arr).all())
passed = passed and all(periodic_arr[idx] == act for idx, act in test_vals)
print("Iterating test values. Result: {}".format(passed))

and gives the output:

Iterating test values. Result: True

Finally, running the code from the original problem gives:

True
error
error
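The subclass targets single-point lookups. When the goal is a whole-lattice operation, such as summing each site's nearest neighbours on the torus, a common alternative (not the approach of this answer) is np.roll, which shifts an array along an axis with periodic wrap-around. A sketch, using a hypothetical helper neighbour_sum:

```python
import numpy as np

def neighbour_sum(lattice):
    """Hypothetical helper: sum of the 2*ndim nearest neighbours of every site,
    with periodic boundaries on every axis."""
    total = np.zeros_like(lattice)
    for axis in range(lattice.ndim):
        # np.roll re-inserts elements shifted past one edge at the opposite edge
        total += np.roll(lattice, 1, axis=axis)
        total += np.roll(lattice, -1, axis=axis)
    return total

arr = np.array([[11., 12., 13., 14.],
                [21., 22., 23., 24.],
                [31., 32., 33., 34.],
                [41., 42., 43., 44.]])
print(neighbour_sum(arr)[0, 0])  # 88.0 = 41 + 21 + 14 + 12
```

Because every shift is a vectorised array operation, this avoids a Python-level loop over sites.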

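Disclaimer: the technical posts on this site are licensed under CC BY-SA 4.0. If you repost, please credit this site's URL or the original URL. For any questions, contact: yoyou2525@163.com.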

© 2020-2024 STACKOOM.COM