簡體   English   中英

在 PyTorch 中求解 Sylvester 方程

[英]Solving Sylvester equations in PyTorch

我正在嘗試解決以下形式的西爾維斯特矩陣方程

AX + XB = C

據我所見,這些方程通常用 Bartels-Stewart 算法求解,采用連續的 Schur 分解。 我知道scipy.linalg已經有一個solve_sylvester function,但我正在將西爾維斯特方程的解集成到神經網絡中,所以我需要一種方法來計算關於X的梯度。 目前,我只是使用 Kronecker 乘積和矢量化技巧來解決帶有torch.linalg.solve的線性系統,但這具有可怕的運行時復雜性。 我還沒有發現任何 PyTorch 支持 Sylvester 方程,更不用說 Schur 分解了,但是在我嘗試在 GPU 上實現 Barters-Stewart 之前,有沒有更簡單的方法來找到梯度?

如果您接受復雜的解決方案X ,那么您可以在Bartels-Stewart 算法中使用特征分解而不是 schur 分解

import torch

def sylvester(A, B, C):
    R, U = torch.linalg.eig(A)
    S, V = torch.linalg.eig(B)
    F = U.transpose(-1, -2) @ (C + 0j) @ V
    W = R[..., :, None] - S[..., None, :]
    Y = F @ torch.linalg.inv(W)
    return U @ Y @ V.transpose(-1, -2)

可以在 GPU 上驗證

batch_size = 10
device='cuda'
N = 12
A = torch.randn((batch_size, N, N), device=device)
B = torch.randn((batch_size, N, N), device=device)
X = torch.randn((batch_size, N, N), device=device)
C = A @ X - X @ B
X_ = sylvester(A, B, C)
C_ = A @ X - X @ B
torch.allclose(C, C_)

# back-propagation works out of the box
X_.sum().backward()

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM