[英]Solving Sylvester equations in PyTorch
我正在嘗試解決以下形式的西爾維斯特矩陣方程
AX + XB = C
據我所見,這些方程通常用 Bartels-Stewart 算法求解,采用連續的 Schur 分解。 我知道scipy.linalg
已經有一個solve_sylvester
function,但我正在將西爾維斯特方程的解集成到神經網絡中,所以我需要一種方法來計算關於X
的梯度。 目前,我只是使用 Kronecker 乘積和矢量化技巧來解決帶有torch.linalg.solve
的線性系統,但這具有可怕的運行時復雜性。 我還沒有發現任何 PyTorch 支持 Sylvester 方程,更不用說 Schur 分解了,但是在我嘗試在 GPU 上實現 Barters-Stewart 之前,有沒有更簡單的方法來找到梯度?
如果您接受復雜的解決方案X
,那么您可以在Bartels-Stewart 算法中使用特征分解而不是 schur 分解
import torch
def sylvester(A, B, C):
R, U = torch.linalg.eig(A)
S, V = torch.linalg.eig(B)
F = U.transpose(-1, -2) @ (C + 0j) @ V
W = R[..., :, None] - S[..., None, :]
Y = F @ torch.linalg.inv(W)
return U @ Y @ V.transpose(-1, -2)
可以在 GPU 上驗證
batch_size = 10
device='cuda'
N = 12
A = torch.randn((batch_size, N, N), device=device)
B = torch.randn((batch_size, N, N), device=device)
X = torch.randn((batch_size, N, N), device=device)
C = A @ X - X @ B
X_ = sylvester(A, B, C)
C_ = A @ X - X @ B
torch.allclose(C, C_)
# back-propagation works out of the box
X_.sum().backward()
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.