簡體   English   中英

使用LAPACK dgesvd_對非方陣進行SVD

[英]SVD on a non-square matrix using LAPACK dgesvd_

我必須對非方陣計算SVD,為此我使用LAPACK的dgesvd_例程。 對於方陣沒有問題:與MATLAB相比,我可以得到預期的值。 但是對於4x5矩陣,我無法得到預期的結果。 我知道結果應該與MATLAB的一致,因為返回的奇異值是按降序排列的。 不過我發現,在SVD之後,部分奇異值出現在原始輸入數組A中。 這表明我要麼調用dgesvd_的方式有誤,要麼引用結果的方式有誤,可能與數組的前導維度(leading dimension)有關。

在每種情況下,我首先以LWORK = -1發出一次調用,向LAPACK查詢最優值,然後把這些最優值傳給下一次調用以計算SVD。 我不確定所有返回值的含義,也不確定它們是否有效、是否需要修改等。 我假定它們沒有問題,因此在隨後計算SVD的調用中直接使用它們。

因此,此代碼按預期工作(3x3矩陣):

 41 /* Reference data. */
 42 double ref_array_A[3][3] = {
 43     { 1, 2, 3},
 44     { 2, 4, 5 },
 45     { 3, 5, 6 }
 46 };
 47 
 48 double ref_array_U[3][3] = {
 49     { -0.327985, -0.736976, -0.591009 },
 50     { -0.591009, -0.327985, 0.736976 },
 51     { -0.736976, 0.591009, -0.327985 }
 52 };
 53 
 54 double ref_array_Sigma[3][1] = {
 55     { 11.344814 },
 56     { 0.515729 },
 57     { 0.170915 }
 58 };
 59 
 60 double ref_array_VT[3][3] = {
 61     { -0.327985, -0.591009, -0.736976 },
 62     { 0.736976, 0.327985, -0.591009 },
 63     { -0.591009, 0.736976, -0.327985 }
 64 };
 66 /* MATLAB result
 67  *
 68  *  >> A = [ 1, 2, 3; 2, 4, 5; 3, 5, 6]
 69  *
 70  *  A = 
 71  *      1     2     3
 72  *      2     4     5
 73  *      3     5     6
 74  *
 75  *  >> [U, S, V] = svd(A)
 76  *
 77  *  U =
 78  *      -0.3280   -0.7370   -0.5910
 79  *      -0.5910   -0.3280    0.7370
 80  *      -0.7370    0.5910   -0.3280
 81  *
 82  *  S =
 83  *      11.3448     0           0
 84  *      0           0.5157      0
 85  *      0           0           0.1709
 86  *
 87  *  V =
 88  *      -0.3280    0.7370   -0.5910
 89  *      -0.5910    0.3280    0.7370
 90  *      -0.7370   -0.5910   -0.3280
 91  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 3;
213     N = 3;
214     LDA = 3;            /* (out) */
215     LDU = 3;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 3;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(dd_min(M, N), sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:\n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:\n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:\n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:\n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
262     if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %f\n", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %d\n", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completed\n");
292 
293     fprintf(stderr, "Result A:\n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U**T:\n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298     fprintf(stderr, "Result U:\n");
299     dd_walk_dbl_arr_colwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
300 
301 
302     fprintf(stderr, "Result S:\n");
303     dd_walk_dbl_arr_rowwise(S, dd_min(M, N), 1, cb_dbl, cb_dbl_row_end);
304 
305     fprintf(stderr, "Result VT:\n");
306     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
307 
308     free(WORK);
309     free(A);
310     free(S);
311     free(U);
312     free(VT);
313 
314     return 0;

正確的結果:

peter@xx:~$ ./test4
Reference array A:
    1.000000    2.000000    3.000000
    2.000000    4.000000    5.000000
    3.000000    5.000000    6.000000
Reference array U:
    -0.327985   -0.736976   -0.591009
    -0.591009   -0.327985   0.736976
    -0.736976   0.591009    -0.327985
Reference array Sigma:
    11.344814
    0.515729
    0.170915
Reference array VT:
    -0.327985   -0.591009   -0.736976
    0.736976    0.327985    -0.591009
    -0.591009   0.736976    -0.327985
LAPACK's dgesvd_ query optimal results: LDA 3, LDU 3, LDVT 3, LWORK 201, WORK_QUERY 201.000000
Rest of params: M 3, N 3
LAPACK's dgesvd_ SVD completed
Result A:
    -3.741657   0.421793    0.632690
    10.643576   1.261481    -0.720622
    0.478213    -0.279401   -0.211863
Result U**T:
    -0.327985   -0.591009   -0.736976
    -0.736976   -0.327985   0.591009
    -0.591009   0.736976    -0.327985
Result U:
    -0.327985   -0.736976   -0.591009
    -0.591009   -0.327985   0.736976
    -0.736976   0.591009    -0.327985
Result S:
    11.344814
    0.515729
    0.170915
Result VT:
    -0.327985   0.736976    -0.591009
    -0.591009   0.327985    0.736976
    -0.736976   -0.591009   -0.327985

但是這段代碼(4x5矩陣)卻不行:

 39 /* Reference data. */
 40 double ref_array_A[4][5] = {
 41     { 1, 0, 0, 0, 2 },
 42     { 0, 0, 3, 0, 0 },
 43     { 0, 0, 0, 0, 0 },
 44     { 0, 2, 0, 0, 0 }
 45 };
 46 
 47 double ref_array_U[4][4] = {
 48     { 0, 0, 1, 0 },
 49     { 0, 1, 0, 0 },
 50     { 0, 0, 0, -1 },
 51     { 1, 0, 0, 0 }
 52 };
 53 
 54 double ref_array_Sigma[4][5] = {
 55     { 2, 0, 0, 0, 0 },
 56     { 0, 3, 0, 0, 0 },
 57     { 0, 0, 2.236068, 0, 0 },
 58     { 0, 0, 0, 0, 0 }
 59 };
 60 
 61 double ref_array_VT[5][5] = {
 62     { 0, 1, 0, 0, 0 },
 63     { 0, 0, 1, 0, 0 },
 64     { 0.447214, 0, 0, 0, 0.894427 },
 65     { 0, 0, 0, 1, 0 },
 66     { -0.894427, 0, 0, 0, -0.447214 }
 67 };
 68 
 69 /* MATLAB result
 70  *
 71  * >> A = [ 1 0 0 0 2; 0 0 3 0 0 ; 0 0 0 0 0 ;0 2 0 0 0 ];
 72  * >> [U, S, V] = svd(A)
 73  *
 74  * U =
 75  *      0     1     0     0
 76  *      1     0     0     0
 77  *      0     0     0    -1
 78  *      0     0     1     0
 79  *
 80  * S =
 81  *      3.0000      0           0           0           0
 82  *      0           2.2361      0           0           0
 83  *      0           0           2.0000      0           0
 84  *      0           0           0           0           0
 85  *
 86  * V =
 87  *      0           0.4472      0           0           -0.8944
 88  *      0           0           1.0000      0           0
 89  *      1.0000      0           0           0           0
 90  *      0           0           0           1.0000      0
 91  *      0           0.8944      0           0           0.4472
 92  */
double WORK_QUERY = 0;
206 
207 
208     /* Call dgesvd_ with lwork = -1 to query optimal workspace size. */
209 
210     JOBU = 'A';
211     JOBVT = 'A';
212     M = 4;
213     N = 5;
214     LDA = 4;            /* (out) */
215     LDU = 4;            /* (out) */
216     S = NULL;           /* (don't care) */
217     U = NULL;           /* (don't care) */
218     VT = NULL;          /* (don't care) */
219     LDVT = 5;           /* (out) */
220     WORK = NULL;        /* (out) , because LWORK is 0 do not care */
221     LWORK = 4 * M * N * M *N + 6 * M * N + dd_max(M, N);
222 
223     A = calloc(M * N, sizeof(double));
224     if (!A) {
225         goto ddt2_fail_sys;
226     }
227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }
232 
233     S = calloc(M * N, sizeof(double));
234     if (!S) {
235         goto ddt2_fail_sys;
236     }
237 
238     U = calloc(LDU * M, sizeof(double));
239     if (!U) {
240         goto ddt2_fail_sys;
241     }
242 
243     VT = calloc(LDVT * N, sizeof(double));
244     if (!A) {
245         goto ddt2_fail_sys;
246     }
247 
248     fprintf(stderr, "Reference array A:\n");
249     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
250 
251     fprintf(stderr, "Reference array U:\n");
252     dd_walk_dbl_arr_rowwise(&ref_array_U[0][0], M, M, cb_dbl, cb_dbl_row_end);
253 
254     fprintf(stderr, "Reference array Sigma:\n");
255     dd_walk_dbl_arr_rowwise(&ref_array_Sigma[0][0], M, N, cb_dbl, cb_dbl_row_end);
256 
257     fprintf(stderr, "Reference array VT:\n");
258     dd_walk_dbl_arr_rowwise(&ref_array_VT[0][0], N, N, cb_dbl, cb_dbl_row_end);
259 
260     LWORK = -1;
261     dgesvd_("A", "A", &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, &WORK_QUERY, &LWORK, &INFO);
if (INFO != 0) {
263         if (INFO < 0) {
264             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
265         } else {
266             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
267         }
268         return -1;
269     }
270 
271     LWORK = (int) WORK_QUERY;
272     WORK = calloc(LWORK, sizeof(double));
273     if (!WORK) {
274         goto ddt2_fail_sys;
275     }
276 
277     fprintf(stderr, "LAPACK's dgesvd_ query optimal results: LDA %d, LDU %d, LDVT %d, LWORK %d, WORK_QUERY %f\n", LDA, LDU, LDVT, LWORK, WORK_QUERY);
278     fprintf(stderr, "Rest of params: M %d, N %d\n", M, N);
279 
280     /* Compute SVD. */
281     dgesvd_(&JOBU, &JOBVT, &M, &N, A, &LDA, S, U, &LDU, VT, &LDVT, WORK, &LWORK, &INFO);
282     if (INFO != 0) {
283         if (INFO < 0) {
284             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"the %d-th argument had illegal value\"\n", INFO);
285         } else {
286             fprintf(stderr, "Error on LAPACK's dgesvd_ query: \"DBDSDC didn't converge, updating process failed\"\n");
287         }
288         return -1;
289     }
290 
291     fprintf(stderr, "LAPACK's dgesvd_ SVD completed\n");
292 
293     fprintf(stderr, "Result A:\n");
294     dd_walk_dbl_arr_rowwise(A, M, N, cb_dbl, cb_dbl_row_end);
295 
296     fprintf(stderr, "Result U:\n");
297     dd_walk_dbl_arr_rowwise(U, LDU, M, cb_dbl, cb_dbl_row_end);
298 
299     fprintf(stderr, "Result S:\n");
300     dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end);
301 
302     fprintf(stderr, "Result VT:\n");
303     dd_walk_dbl_arr_rowwise(VT, LDVT, N, cb_dbl, cb_dbl_row_end);
304 
305     free(WORK);
306     free(A);
307     free(S);
308     free(U);
309     free(VT);
310 
311     return 0;

錯誤的結果:

peter@xx:~/$ ./test2
Reference array A:
    1.000000    0.000000    0.000000    0.000000    2.000000
    0.000000    0.000000    3.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    2.000000    0.000000    0.000000    0.000000
Reference array U:
    0.000000    0.000000    1.000000    0.000000
    0.000000    1.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    -1.000000
    1.000000    0.000000    0.000000    0.000000
Reference array Sigma:
    2.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    3.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    2.236068    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
Reference array VT:
    0.000000    1.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    1.000000    0.000000    0.000000
    0.447214    0.000000    0.000000    0.000000    0.894427
    0.000000    0.000000    0.000000    1.000000    0.000000
    -0.894427   0.000000    0.000000    0.000000    -0.447214
LAPACK's dgesvd_ query optimal results: LDA 4, LDU 4, LDVT 5, LWORK 300, WORK_QUERY 300.000000
Rest of params: M 4, N 5
LAPACK's dgesvd_ SVD completed
Result A:
    -3.000000   -2.000000   0.000000    -1.000000   0.500000
    -2.236068   0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.500000    -0.236068   0.000000    0.000000
Result U:
    0.707107    0.000000    0.000000    0.707107
    -0.707107   0.000000    -0.000000   0.707107
    0.000000    0.000000    1.000000    0.000000
    0.000000    1.000000    0.000000    0.000000
Result S:
    3.872983    1.732051    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
    0.000000    0.000000    0.000000    0.000000    0.000000
Result VT:
    0.182574    -0.408248   0.000000    0.000000    -0.894427
    0.912871    0.408248    0.000000    0.000000    0.000000
    -0.000000   -0.000000   1.000000    0.000000    0.000000
    -0.000000   -0.000000   0.000000    1.000000    0.000000
    0.365148    -0.816497   0.000000    0.000000    0.447214

在一般矩陣情況下,我該怎么辦?

函數dgesvd_期望矩陣以列優先的順序排列,而您的代碼以行優先的樣式提供數據:

227     for (i = 0; i < M; ++i) {
228         for (j = 0; j < N; ++j) {
229             A[i * N + j] = ref_array_A[i][j];
230         }
231     }

實際上,您的代碼計算的是下面這個矩陣的SVD:

[ 1 2 0 0 2 ]   [ 1 0 0 0 ] ^ T
[ 0 0 0 0 0 ] = [ 2 0 0 3 ]
[ 0 0 0 0 0 ]   [ 0 0 0 0 ]
[ 0 3 0 0 0 ]   [ 0 0 0 0 ]
                [ 2 0 0 0 ]

其奇異值確實約為3.87和1.73,與您輸出的Result S相符。

在第一個示例中,由於矩陣是方陣( M=N )且對稱,按行優先與按列優先存儲的結果相同,因此這個錯誤沒有顯現出來。

同樣,參數S應該只是一個一維數組(如您的第一個示例)。 由於您隨後用dd_walk_dbl_arr_rowwise(S, M, N, cb_dbl, cb_dbl_row_end);以行優先格式將其打印出來,這些值會連續顯示在第一行中...

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM