[英]How can you implement Householder based QR decomposition in Python?
我目前正在嘗試為矩形矩陣實現基於 Householder 的 QR 分解,如http://eprints.ma.man.ac.uk/1192/1/qrupdating_12nov08.pdf (第 3、4、5 頁)中所述。
顯然我弄錯了一些偽代碼,因為 (1) 我的結果不同於numpy.qr.linalg()
和 (2) 我的例程產生的矩陣R
不是上三角矩陣。
我的代碼(也可在https://pyfiddle.io/fiddle/afcc2e0e-0857-4cb2-adb5-06ff9b80c9d3/?i=true下找到)
import math
import argparse
import numpy as np
from typing import Union
def householder(alpha: float, x: np.ndarray) -> Union[np.ndarray, int]:
"""
Computes Householder vector for alpha and x.
:param alpha:
:param x:
:return:
"""
s = math.pow(np.linalg.norm(x, ord=2), 2)
v = x
if s == 0:
tau = 0
else:
t = math.sqrt(alpha * alpha + s)
v_one = alpha - t if alpha <= 0 else -s / (alpha + t)
tau = 2 * v_one * v_one / (s + v_one * v_one)
v /= v_one
return v, tau
def qr_decomposition(A: np.ndarray, m: int, n: int) -> Union[np.ndarray, np.ndarray]:
"""
Applies Householder-based QR decomposition on specified matrix A.
:param A:
:param m:
:param n:
:return:
"""
H = []
R = A
Q = A
I = np.eye(m, m)
for j in range(0, n):
# Apply Householder transformation.
x = A[j + 1:m, j]
v_householder, tau = householder(np.linalg.norm(x), x)
v = np.zeros((1, m))
v[0, j] = 1
v[0, j + 1:m] = v_householder
res = I - tau * v * np.transpose(v)
R = np.matmul(res, R)
H.append(res)
return Q, R
m = 10
n = 8
A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A, m, n)
print("*****")
print(Q)
print(q)
print("-----")
print(R)
print(r)
所以我不清楚如何將零引入我的R
矩陣/關於我的代碼的哪一部分不正確。 我很高興有任何指示! 非常感謝您的時間。
您鏈接到的筆記中有很多問題/缺少細節。 在咨詢了一些其他來源(包括這本非常有用的教科書)之后,我能夠想出一個類似的工作實現。
下面是qr_decomposition
工作版本的qr_decomposition
:
import numpy as np
from typing import Union
def householder(x: np.ndarray) -> Union[np.ndarray, int]:
alpha = x[0]
s = np.power(np.linalg.norm(x[1:]), 2)
v = x.copy()
if s == 0:
tau = 0
else:
t = np.sqrt(alpha**2 + s)
v[0] = alpha - t if alpha <= 0 else -s / (alpha + t)
tau = 2 * v[0]**2 / (s + v[0]**2)
v /= v[0]
return v, tau
def qr_decomposition(A: np.ndarray) -> Union[np.ndarray, np.ndarray]:
m,n = A.shape
R = A.copy()
Q = np.identity(m)
for j in range(0, n):
# Apply Householder transformation.
v, tau = householder(R[j:, j])
H = np.identity(m)
H[j:, j:] -= tau * v.reshape(-1, 1) @ v
R = H @ R
Q = H @ Q
return Q[:n].T, R[:n]
m = 5
n = 4
A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A)
with np.printoptions(linewidth=9999, precision=20, suppress=True):
print("**** Q from qr_decomposition")
print(Q)
print("**** Q from np.linalg.qr")
print(q)
print()
print("**** R from qr_decomposition")
print(R)
print("**** R from np.linalg.qr")
print(r)
輸出:
**** Q from qr_decomposition
[[ 0.5194188817843675 -0.10699353671401633 0.4322294754656072 -0.7293293270703678 ]
[ 0.5218635773595086 0.11737804362574514 -0.5171653705211056 0.04467925806590414]
[ 0.34858177783013133 0.6023104248793858 -0.33329256746256875 -0.03450824948274838]
[ 0.03371048915852807 0.6655221685383623 0.6127023580593225 0.28795294754791 ]
[ 0.5789790833500734 -0.411189947884951 0.24337120818874305 0.618041080584351 ]]
**** Q from np.linalg.qr
[[-0.5194188817843672 0.10699353671401617 0.4322294754656068 0.7293293270703679 ]
[-0.5218635773595086 -0.11737804362574503 -0.5171653705211053 -0.044679258065904115]
[-0.3485817778301313 -0.6023104248793857 -0.33329256746256863 0.03450824948274819 ]
[-0.03371048915852807 -0.665522168538362 0.6127023580593226 -0.2879529475479097 ]
[-0.5789790833500733 0.41118994788495106 0.24337120818874317 -0.6180410805843508 ]]
**** R from qr_decomposition
[[ 0.6894219296137802 1.042676051151294 1.3418719684631446 1.2498925815126485 ]
[ 0.00000000000000000685 0.7076056836914905 0.29883043386651403 0.41955370595004277 ]
[-0.0000000000000000097 -0.00000000000000007292 0.5304551654027297 0.18966088433421135 ]
[-0.00000000000000000662 0.00000000000000008718 0.00000000000000002322 0.6156558913022807 ]]
**** R from np.linalg.qr
[[-0.6894219296137803 -1.042676051151294 -1.3418719684631442 -1.2498925815126483 ]
[ 0. -0.7076056836914905 -0.29883043386651376 -0.4195537059500425 ]
[ 0. 0. 0.53045516540273 0.18966088433421188]
[ 0. 0. 0. -0.6156558913022805 ]]
這個版本的qr_decomposition
幾乎完全再現了np.linalg.qr
的輸出。 下面對差異進行了評論。
np.linalg.qr
和qr_decomposition
的輸出值匹配高精度。 然而, qr_decomposition
用於在R
產生零的計算組合並沒有完全取消,因此零實際上並不完全等於零。
事實證明, np.linalg.qr
並沒有做任何花哨的浮點技巧來確保其輸出中的零為0.0
。 它只是調用np.triu
,它強制將這些值設置為0.0
。 所以要達到相同的結果,只需將qr_decomposition
的return
行qr_decomposition
為:
return Q[:n].T, np.triu(R[:n])
+
/ -
) Q 和 R 中的一些+/-
符號在np.linalg.qr
和qr_decomposition
的輸出中是不同的,但這並不是真正的問題,因為符號有許多有效的選擇(請參閱有關Q 和 R )。 您可以通過使用替代算法生成v
和tau
來完全匹配np.linalg.qr
的符號約定:
def householder_vectorized(a):
"""Use this version of householder to reproduce the output of np.linalg.qr
exactly (specifically, to match the sign convention it uses)
based on https://rosettacode.org/wiki/QR_decomposition#Python
"""
v = a / (a[0] + np.copysign(np.linalg.norm(a), a[0]))
v[0] = 1
tau = 2 / (v.T @ v)
return v,tau
np.linalg.qr
的輸出qr_decomposition
,這個版本的qr_decomposition
將與np.linalg.qr
的輸出完全匹配:
import numpy as np
from typing import Union
def qr_decomposition(A: np.ndarray) -> Union[np.ndarray, np.ndarray]:
m,n = A.shape
R = A.copy()
Q = np.identity(m)
for j in range(0, n):
# Apply Householder transformation.
v, tau = householder_vectorized(R[j:, j, np.newaxis])
H = np.identity(m)
H[j:, j:] -= tau * (v @ v.T)
R = H @ R
Q = H @ Q
return Q[:n].T, np.triu(R[:n])
m = 5
n = 4
A = np.random.rand(m, n)
q, r = np.linalg.qr(A)
Q, R = qr_decomposition(A)
with np.printoptions(linewidth=9999, precision=20, suppress=True):
print("**** Q from qr_decomposition")
print(Q)
print("**** Q from np.linalg.qr")
print(q)
print()
print("**** R from qr_decomposition")
print(R)
print("**** R from np.linalg.qr")
print(r)
輸出:
**** Q from qr_decomposition
[[-0.10345123000824041 0.6455437884382418 0.44810714367794663 -0.03963544711256745 ]
[-0.55856415402318 -0.3660716543156899 0.5953932791844518 0.43106504879433577 ]
[-0.30655198880585594 0.6606757192118904 -0.21483067305535333 0.3045011114089389 ]
[-0.48053620675695174 -0.11139783377793576 -0.6310958848894725 0.2956864520726446 ]
[-0.5936453158283703 -0.01904935140131578 -0.016510508076204543 -0.79527388379824 ]]
**** Q from np.linalg.qr
[[-0.10345123000824041 0.6455437884382426 0.44810714367794663 -0.039635447112567376]
[-0.5585641540231802 -0.3660716543156898 0.5953932791844523 0.4310650487943359 ]
[-0.30655198880585594 0.6606757192118904 -0.21483067305535375 0.30450111140893893 ]
[-0.48053620675695186 -0.1113978337779356 -0.6310958848894725 0.29568645207264455 ]
[-0.5936453158283704 -0.01904935140131564 -0.0165105080762043 -0.79527388379824 ]]
**** R from qr_decomposition
[[-1.653391466100325 -1.0838054573405895 -1.0632037969249921 -1.1825735233596888 ]
[ 0. 0.7263519982452554 0.7798481878600413 0.5496287509656425 ]
[ 0. 0. -0.26840760341581243 -0.2002757085967938 ]
[ 0. 0. 0. 0.48524469321440966]]
**** R from np.linalg.qr
[[-1.6533914661003253 -1.0838054573405895 -1.0632037969249923 -1.182573523359689 ]
[ 0. 0.7263519982452559 0.7798481878600418 0.5496287509656428]
[ 0. 0. -0.2684076034158126 -0.2002757085967939]
[ 0. 0. 0. 0.4852446932144096]]
除了尾隨數字中不可避免的舍入錯誤之外,輸出現在匹配。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.