How can I efficiently solve a quadratic equation system with PyTorch?

I need to solve the nonlinear system of equations shown below:

import torch

def trySolveEquation(V, L):
    # The equations to solve are:
    # {     (V . Ct) ^ 2 = 1
    # {     (L + u) . Ct = 0
    # C and u are the unknowns: C is a vector, u is a scalar, and Ct is C transposed (a column vector).
    # V is a known vector with the same dimension as C; L is a known square matrix.
    # u is a Lagrange multiplier.
    # '.' means matrix multiplication.

    dim = V.shape[0]
    assert L.shape[0] == dim and L.shape[1] == dim
    C = torch.zeros((dim))
    Ct = C.view((dim, 1))
    u = 0

    # TODO: solve the equations here.

    print('C=', C)
    print('u=', u)
    return C, u

C has about ten dimensions, and this system has to be solved up to a billion times, so I would like to implement it in PyTorch to make use of the GPU. Is there any method better than gradient descent?
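A short derivation that may help when weighing alternatives to gradient descent. It assumes that "L + u" in the second equation means L + uI (u times the identity matrix, the usual reading when u is a Lagrange multiplier) and writes c for the column vector Ct. Here λ_k and q_k denote an eigenvalue/eigenvector pair of L; note that c ≠ 0 is forced by the first equation.

```latex
(L + uI)\,c = 0,\; c \neq 0 \;\Longleftrightarrow\; L\,c = -u\,c
\quad\Rightarrow\quad u = -\lambda_k,\qquad c = \alpha\, q_k
\\[4pt]
(V \cdot c)^2 = \alpha^2\,(V \cdot q_k)^2 = 1
\quad\Rightarrow\quad \alpha = \pm\frac{1}{V \cdot q_k}
```

So for a simple eigenvalue λ_k with V · q_k ≠ 0 there are exactly two solutions (c and -c). If V · q_k = 0, that eigenvalue gives no solution. If λ_k is repeated, c can be any vector in its eigenspace with V · c = ±1, so the solution is not unique.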

PyTorch's built-in solvers only handle systems of linear equations (e.g. torch.linalg.solve, or the older torch.solve, which is deprecated). Beyond that, you can try, for example:
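As a sketch of what the built-in linear solver offers for the "many small systems on a GPU" use case: torch.linalg.solve accepts a batch of matrices, so thousands of 10×10 systems can be solved in one call. The helper name batched_linear_solve and the random, well-conditioned test data are hypothetical illustrations, not part of the original post, and this alone does not handle the nonlinear system in the question.

```python
import torch


def batched_linear_solve(A: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: solve A[i] @ x[i] = b[i] for every i in one call.

    A: (B, d, d), b: (B, d)  ->  x: (B, d)
    """
    # A trailing dimension makes b an unambiguous batch of column vectors
    return torch.linalg.solve(A, b.unsqueeze(-1)).squeeze(-1)


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
B, d = 1000, 10  # many small systems, dimension ~10 as in the question
# Random matrices shifted by 10*I so they are well-conditioned (test data only)
A = torch.randn(B, d, d, device=device) + 10 * torch.eye(d, device=device)
x_true = torch.randn(B, d, device=device)
b = (A @ x_true.unsqueeze(-1)).squeeze(-1)
x = batched_linear_solve(A, b)
print("max abs error:", (x - x_true).abs().max().item())
```

The whole batch is dispatched as one operation, which is what keeps a GPU busy; looping over systems one at a time in Python would not.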

locuslab/qpth

A fast and differentiable QP solver for PyTorch.
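A different technique, not the one suggested in the answer: if "L + u" is read as L + u·I (an assumption), then (L + u·I)·Ct = 0 says Ct is an eigenvector of L with eigenvalue -u, and (V·Ct)^2 = 1 only fixes its scale. The sketch below therefore solves all systems exactly with one batched eigen-decomposition and no iterative solver. It additionally assumes every L is symmetric so torch.linalg.eigh applies; for a non-symmetric L, torch.linalg.eig would be needed and can return complex eigenpairs. The function name solve_quadratic_system and the eps threshold are hypothetical.

```python
import torch


def solve_quadratic_system(V: torch.Tensor, L: torch.Tensor, eps: float = 1e-8):
    """Hypothetical helper: solve (V . c)^2 = 1 and (L + u*I) c = 0 for a batch.

    Assumes "L + u" means L + u*I and that every L is symmetric.
    V: (B, d), L: (B, d, d).
    Returns C: (B, d, d) where C[b, k] is the solution built from the k-th
    eigenvector, u: (B, d), and valid: (B, d), which marks usable solutions
    (entries where valid is False hold an unscaled eigenvector, not a solution).
    -C[b, k] with the same u[b, k] is also a solution.
    """
    # eigvals: (B, d) in ascending order; eigvecs[b, :, k] is the k-th eigenvector
    eigvals, eigvecs = torch.linalg.eigh(L)
    # proj[b, k] = V[b] . q_k, the component of V along each eigenvector
    proj = torch.einsum("bi,bik->bk", V, eigvecs)
    # If V is (numerically) orthogonal to q_k, no multiple of q_k works
    valid = proj.abs() > eps
    safe_proj = torch.where(valid, proj, torch.ones_like(proj))
    # Scale each eigenvector so that V . c = 1, then put solutions on the rows
    C = (eigvecs / safe_proj.unsqueeze(-2)).transpose(-1, -2)
    # L q_k = lambda_k q_k, so (L + u*I) q_k = 0 requires u = -lambda_k
    u = -eigvals
    return C, u, valid


device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)
B, d = 1000, 10  # many small systems, dimension ~10 as in the question
A = torch.randn(B, d, d, device=device)
L = A + A.transpose(-1, -2)  # symmetric test matrices
V = torch.randn(B, d, device=device)
C, u, valid = solve_quadratic_system(V, L)
```

Each system yields up to d solution pairs (one per eigenvalue), so you would still need to pick the one you want, e.g. by the value of u. For up to a billion systems, the batch would have to be fed in chunks sized to fit GPU memory; that chunk size is a tuning choice, not something from the post.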
