
Einstein summation in NumPy for a non-linear cost function

I'm trying to set up the non-linear cost functions (1) and (2) below using NumPy's einsum function.

[Image: equations (1) and (2), originally hosted on Imgur; the docstring in the code summarizes them]

I translated them to Python. R is an array of shape (14,3,3), v and vt (vt denotes v' from equation (2)) have shape (14,3), and old_points and new_points, representing p and p' respectively, have shape (6890,3).
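A small sketch of the shape bookkeeping, using made-up random arrays with the shapes above. It assumes weights has shape (14, 6890), which is what the 'ji,ijk->ik' subscripts in the code require. It shows what the implicit subscripts 'jkl,il' actually produce:

```python
import numpy as np

# Made-up random data with the shapes described above.
n_joints, n_points = 14, 6890
rng = np.random.default_rng(0)
R = rng.standard_normal((n_joints, 3, 3))       # one 3x3 matrix per joint
old_points = rng.standard_normal((n_points, 3))
weights = rng.random((n_joints, n_points))      # w_ji: joint j, point i (assumed layout)

# Without '->', einsum orders the free indices alphabetically, so 'jkl,il' means 'jkl,il->ijk':
# Rp[i, j, k] = sum_l R[j, k, l] * old_points[i, l], i.e. R_j p_i for every (point, joint) pair.
Rp_implicit = np.einsum('jkl,il', R, old_points)
Rp_explicit = np.einsum('jkl,il->ijk', R, old_points)
print(Rp_implicit.shape)  # (6890, 14, 3)

# Weighted sum over joints: sum_j w_ji * Rp[i, j, :], one 3-vector per point.
wRp = np.einsum('ji,ijk->ik', weights, Rp_explicit)
print(wRp.shape)  # (6890, 3)
```

Writing the output subscripts explicitly ('->ijk') avoids relying on einsum's alphabetical ordering of the free indices.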

def costfunc(x, old_points, new_points, weights, n_joints):
    """
    Set up non-linear cost functions by using equations from LBS:
    (1) p'_i = sum{j}(w_ji (R_j p_i + v_j))
    (2) p_i - sum{j}(w_ji (R_j p'_i + v'_j))
    where Rt denotes the transpose of R.
    :param old_points: original vertex positions
    :param new_points: transformed vertex positions
    :param weights: weight matrix obtained from spectral clustering
    :param n_joints: number of joints
    :return: non-linear cost functions as in (1), (2) to find the root of
    """

    # Extract rotations R, Rt and offsets v, v' from x (15 values per joint)
    R = np.array([(np.array(x[j * 15:j * 15 + 9]).reshape(3, 3)) for j in range(n_joints)])
    Rt = np.array([R[j].T for j in range(n_joints)])
    v = np.array([(np.array(x[j * 15 + 9:j * 15 + 12])) for j in range(n_joints)])
    vt = np.array([(np.array(x[j * 15 + 12:j * 15 + 15])) for j in range(n_joints)])

    ## Use equations (1) and (2) for the non-linear pass.
    # R_j p_i
    Rp = np.einsum('jkl,il', R, old_points)
    Rtv = np.einsum('jkl,il', Rt, v)
    # Rt_j p'_i
    Rtp = np.einsum('jkl,il', Rt, new_points)
    Rvt = np.einsum('jkl,il', R, vt)

    # sum_j w_ji (R_j p_i - R_j v'_j)
    wRpv = np.einsum('ji,ijk->ik', weights, Rp - Rvt)
    # sum_j w_ji (Rt_j p'_i - Rt_j v_j)
    wRtpv = np.einsum('ji,ijk->ik', weights, Rtp - Rtv)

    # Set up a non-linear cost function, then compute the squared norm.
    d = new_points - wRpv
    dt = old_points - wRtpv

    norm = np.linalg.norm(d, axis=1)
    normt = np.linalg.norm(dt, axis=1)

    result = np.concatenate([norm, normt])

    return np.power(result, 2)

Right now, the lines that compute wRpv and wRtpv fail with ValueError: operands could not be broadcast together with shapes (6890,14,3) (14,14,3). How do I resolve this? Any help is very much appreciated!
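A minimal reproduction of the error with random data of the same shapes (the values are arbitrary). It suggests that 'jkl,il' applied to a (14,3) offset array pairs every joint with every offset, and that sharing the joint index, 'jkl,jl->jk', gives the intended per-joint product:

```python
import numpy as np

n_joints, n_points = 14, 6890
rng = np.random.default_rng(1)
R = rng.standard_normal((n_joints, 3, 3))
vt = rng.standard_normal((n_joints, 3))
old_points = rng.standard_normal((n_points, 3))

Rp = np.einsum('jkl,il', R, old_points)   # (6890, 14, 3): every point against every joint
# Same subscripts, but now 'i' runs over the 14 offsets instead of the 6890 points,
# so every R_j is applied to every v'_i: an outer product over joints.
Rvt_wrong = np.einsum('jkl,il', R, vt)    # (14, 14, 3)
print(Rvt_wrong.shape)

broadcast_failed = False
try:
    Rp - Rvt_wrong  # last two dims (14, 3) match, but the leading 6890 vs 14 does not
except ValueError as err:
    broadcast_failed = True
    print(err)

# Sharing the joint index between R and vt keeps only the matching pairs R_j v'_j.
Rvt = np.einsum('jkl,jl->jk', R, vt)      # (14, 3)
print((Rp - Rvt).shape)                   # (6890, 14, 3): broadcasts over points
```

The correct per-joint products are exactly the diagonal Rvt_wrong[j, j] of the outer-product result.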

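I got it now. The problem was that np.einsum('jkl,il', R, vt) and np.einsum('jkl,il', Rt, v) apply every rotation to every offset, producing shape (14,14,3), when only the matching products R_j v'_j and Rt_j v_j (shape (14,3)) are needed. Computing those per joint fixes the broadcast. This is the solution: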

import numpy as np


def costfunc(x, old_points, new_points, weights, n_joints):
    """
    Set up non-linear cost functions by using equations from LBS:
    (1) p'_i = sum{j}(w_ji (R_j p_i + v_j))
    (2) p_i - sum{j}(w_ji (R_j p'_i + v'_j))
    where Rt denotes the transpose of R.
    :param x: flat parameter vector, 15 entries per joint (9 for R_j, 3 for v_j, 3 for v'_j)
    :param old_points: original vertex positions
    :param new_points: transformed vertex positions
    :param weights: weight matrix obtained from spectral clustering
    :param n_joints: number of joints
    :return: squared residuals of (1) and (2), one per point for each equation, to find the root of
    """

    # Extract rotations R, Rt and offsets v, v' from x (15 values per joint)
    R = np.array([(np.array(x[j * 15:j * 15 + 9]).reshape(3, 3)) for j in range(n_joints)])
    Rt = np.array([R[j].T for j in range(n_joints)])
    v = np.array([(np.array(x[j * 15 + 9:j * 15 + 12])) for j in range(n_joints)])
    vt = np.array([(np.array(x[j * 15 + 12:j * 15 + 15])) for j in range(n_joints)])

    ## Use equations (1) and (2) for the non-linear pass.
    # R_j p_i for every point i and joint j, shape (n_points, n_joints, 3)
    Rp = np.einsum('jkl,il->ijk', R, old_points)
    # Rt_j p'_i, same shape
    Rtp = np.einsum('jkl,il->ijk', Rt, new_points)

    # R_j v'_j
    Rvt = np.array([R[i] @ vt[i] for i in range(n_joints)])
    # Rt_j v_j
    Rtv = np.array([Rt[i] @ v[i] for i in range(n_joints)])

    # sum_j w_ji (R_j p_i - R_j v'_j)
    wRpv = np.einsum('ji,ijk->ik', weights, Rp - Rvt)
    # sum_j w_ji (Rt_j p'_i - Rt_j v_j)
    wRtpv = np.einsum('ji,ijk->ik', weights, Rtp - Rtv)

    # Set up a non-linear cost function, then compute the squared norm.
    d = new_points - wRpv
    dt = old_points - wRtpv

    norm = np.linalg.norm(d, axis=1)
    normt = np.linalg.norm(dt, axis=1)

    result = np.concatenate([norm, normt])

    return np.power(result, 2)
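For reference, a loop-free sketch of the same solution. It assumes x is laid out as in the code above (15 values per joint); costfunc_vectorized is a hypothetical name. The transposes are taken by swapping index letters in the einsum subscripts instead of building Rt, and the per-joint products use 'jkl,jl->jk' in place of the list comprehensions. The usage at the end checks it on synthetic data built so that both residuals should vanish under the code's convention (one joint, all weights 1, v = -R v'):

```python
import numpy as np


def costfunc_vectorized(x, old_points, new_points, weights, n_joints):
    """Hypothetical loop-free variant of costfunc (same residuals).

    Assumes x holds 15 values per joint: 9 for R_j (row-major), 3 for v_j, 3 for v'_j.
    """
    params = np.asarray(x, dtype=float)[:n_joints * 15].reshape(n_joints, 15)
    R = params[:, :9].reshape(n_joints, 3, 3)
    v = params[:, 9:12]
    vt = params[:, 12:15]

    Rp = np.einsum('jkl,il->ijk', R, old_points)    # R_j p_i, (n_points, n_joints, 3)
    Rtp = np.einsum('jlk,il->ijk', R, new_points)   # R_j^T p'_i, transpose via index order
    Rvt = np.einsum('jkl,jl->jk', R, vt)            # R_j v'_j, (n_joints, 3)
    Rtv = np.einsum('jlk,jl->jk', R, v)             # R_j^T v_j, (n_joints, 3)

    wRpv = np.einsum('ji,ijk->ik', weights, Rp - Rvt)
    wRtpv = np.einsum('ji,ijk->ik', weights, Rtp - Rtv)

    d = new_points - wRpv
    dt = old_points - wRtpv
    # Squared Euclidean norm per point, same as np.linalg.norm(..., axis=1) ** 2.
    return np.concatenate([np.sum(d**2, axis=1), np.sum(dt**2, axis=1)])


# Usage: one joint, all weights 1, data generated exactly as p' = R (p - v'), with v = -R v'.
rng = np.random.default_rng(2)
theta = 0.3
R0 = np.array([[np.cos(theta), -np.sin(theta), 0.0],
               [np.sin(theta),  np.cos(theta), 0.0],
               [0.0, 0.0, 1.0]])
vt0 = rng.standard_normal(3)
v0 = -R0 @ vt0
p = rng.standard_normal((100, 3))
p_new = (p - vt0) @ R0.T          # row-wise R0 (p_i - v'_0)
x0 = np.concatenate([R0.ravel(), v0, vt0])
res = costfunc_vectorized(x0, p, p_new, np.ones((1, 100)), n_joints=1)
print(res.shape, res.max())       # (200,) and a value at round-off level
```

Because R0 is orthogonal, R0^T R0 = I, so the second residual p - (R0^T p' - R0^T v) also reduces to zero.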
