
Python equivalent of Matlab shiftdim()

I am currently converting some Matlab code to Python, and I am wondering whether there is a function similar to Matlab's shiftdim(A, n).

B = shiftdim(A,n) shifts the dimensions of an array A by n positions. shiftdim shifts the dimensions to the left when n is a positive integer and to the right when n is a negative integer. For example, if A is a 2-by-3-by-4 array, then shiftdim(A,2) returns a 4-by-2-by-3 array.
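The rule in that quote is pure shape arithmetic, so it can be sketched without touching any array data. The helper below, `shiftdim_shape`, is a hypothetical name used only for illustration. It assumes that positive shifts are circular (n is reduced modulo the number of dimensions) and that negative shifts pad with leading size-1 dimensions.

```python
def shiftdim_shape(shape, n):
    """Shape that shiftdim(A, n) would produce for an array of `shape`.

    Hypothetical helper for illustration; it only does tuple arithmetic.
    """
    shape = tuple(shape)
    if n >= 0:
        if not shape:
            return shape
        # Assumed circular left shift: the first k dimensions wrap to the end.
        k = n % len(shape)
        return shape[k:] + shape[:k]
    # Negative n: shift right by padding with -n leading singleton dimensions.
    return (1,) * (-n) + shape


print(shiftdim_shape((2, 3, 4), 2))    # (4, 2, 3)
print(shiftdim_shape((2, 3, 4), -2))   # (1, 1, 2, 3, 4)
```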

If you use NumPy, you can use np.moveaxis.

From the docs:

>>> x = np.zeros((3, 4, 5))
>>> np.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> np.moveaxis(x, -1, 0).shape
(5, 3, 4)
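The docs example is already a one-step shiftdim. Moving axis 0 to the end rotates the dimensions left by one, like shiftdim(x, 1), and moving the last axis to the front is a left rotation by ndim - 1, like shiftdim(x, 2) for this 3-D array. A quick check, assuming NumPy is installed:

```python
import numpy as np

x = np.zeros((3, 4, 5))

# Moving axis 0 to the end rotates the dimensions left by one: like shiftdim(x, 1).
left_by_one = np.moveaxis(x, 0, -1)
print(left_by_one.shape)   # (4, 5, 3)

# Moving the last axis to the front is a left rotation by ndim - 1: like shiftdim(x, 2) here.
left_by_two = np.moveaxis(x, -1, 0)
print(left_by_two.shape)   # (5, 3, 4)

# np.moveaxis returns views with permuted axes, so the data is not copied.
print(np.shares_memory(x, left_by_one), np.shares_memory(x, left_by_two))   # True True
```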

numpy.moveaxis(a, source, destination)

Parameters:
  a : np.ndarray
      The array whose axes should be reordered.
  source : int or sequence of int
      Original positions of the axes to move. These must be unique.
  destination : int or sequence of int
      Destination positions for each of the original axes. These must also be unique.
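Because source and destination accept sequences, a left shift by n can be done in a single call by moving the first n axes to the last n positions. This sketch assumes 0 <= n <= a.ndim (no wrap-around, no padding); the second call shows the uniqueness requirement from the parameter list being enforced.

```python
import numpy as np

x = np.zeros((2, 3, 4, 5))
n = 2

# Move axes 0..n-1 to the last n positions; the remaining axes slide to the front.
shifted = np.moveaxis(x, list(range(n)), list(range(-n, 0)))
print(shifted.shape)   # (4, 5, 2, 3)

# Repeating an axis violates the "must be unique" rule and raises ValueError.
try:
    np.moveaxis(x, [0, 0], [1, 2])
except ValueError as err:
    print("ValueError:", err)
```

Negative n and n larger than ndim need extra handling, which is where shiftdim differs from a plain axis move.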

shiftdim's behavior is a bit more complex than just moving axes around:

  • For input shiftdim(A, n): if n is positive, shift the axes to the left by n (i.e., rotate them circularly); if n is negative, shift the axes to the right by prepending -n leading dimensions of size 1.
  • For input shiftdim(A): remove any leading dimensions of size 1, and also return how many were removed.
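from collections import deque
import numpy as np

def shiftdim(array, n=None):
    if n is not None:
        if n >= 0:
            # Circular left shift: send axis i to position (i - n) % ndim.
            axes = tuple(range(array.ndim))
            new_axes = deque(axes)
            new_axes.rotate(n)
            return np.moveaxis(array, axes, tuple(new_axes))
        # Negative n: prepend -n leading dimensions of size 1
        # (a tuple axis for expand_dims needs NumPy >= 1.18).
        return np.expand_dims(array, axis=tuple(range(-n)))
    else:
        # No n given: count the leading dimensions of size 1 ...
        idx = 0
        for dim in array.shape:
            if dim == 1:
                idx += 1
            else:
                break
        axes = tuple(range(idx))
        # ... and squeeze them away. Note that this returns a tuple of 2 results,
        # mirroring Matlab's [B, nshifts] = shiftdim(A).
        return np.squeeze(array, axis=axes), len(axes)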

Using the same examples as the Matlab docs:

a = np.random.uniform(size=(4, 2, 3, 5))
print(shiftdim(a, 2).shape)      # prints (3, 5, 4, 2)
print(shiftdim(a, -2).shape)     # prints (1, 1, 4, 2, 3, 5)

a = np.random.uniform(size=(1, 1, 3, 2, 4))
b, nshifts = shiftdim(a)
print(nshifts)                   # prints 2
print(b.shape)                   # prints (3, 2, 4)
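A more compact variant of the same idea builds the permutation with np.roll and applies it with np.transpose, while negative n pads with leading singleton axes via reshape. The name `shiftdim_transpose` is hypothetical (it is not part of NumPy), and the sketch only covers the two-argument form shiftdim(A, n). Shifting the result left by the remaining ndim - n positions restores the original array, which makes a handy sanity check.

```python
import numpy as np


def shiftdim_transpose(array, n):
    """Sketch of Matlab's shiftdim(A, n) for an integer n (hypothetical helper)."""
    if n >= 0:
        # Rolling [0, 1, ..., ndim-1] left by n gives the new axis order;
        # np.roll wraps around, so n larger than ndim also behaves circularly.
        order = np.roll(np.arange(array.ndim), -n)
        return np.transpose(array, order)
    # Negative n: prepend -n leading axes of length 1.
    return array.reshape((1,) * (-n) + array.shape)


a = np.random.uniform(size=(4, 2, 3, 5))
b = shiftdim_transpose(a, 2)
print(b.shape)                              # (3, 5, 4, 2)
print(shiftdim_transpose(a, -2).shape)      # (1, 1, 4, 2, 3, 5)

# Undo a left shift by n with a further left shift of ndim - n.
restored = shiftdim_transpose(b, a.ndim - 2)
print(np.array_equal(restored, a))          # True
```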
