Select Armadillo sub-matrix with non-contiguous indices
I am porting Python code to C++, and I come across Python expressions like this:
J11 = dS_dVa[array([pvpq]).T, pvpq].real
Here, J11 and dS_dVa are sparse matrices, and pvpq is an array of indices in increasing, but not necessarily contiguous, order (e.g. 1, 2, 5, 7, 9).
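For reference, NumPy-style indexing with a column of indices (array([pvpq]).T) and a row of indices (pvpq) broadcasts the two index arrays, so element (i, j) of the result is dS_dVa[pvpq[i], pvpq[j]]: the rows and the columns listed in pvpq, in that order. The library-free C++ sketch below spells out that rule on a dense matrix; the type Dense and the function select_rows_cols are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense row-major matrix type, used only to illustrate the indexing rule.
using Dense = std::vector<std::vector<double>>;

// Returns B with B[i][j] = A[idx[i]][idx[j]], i.e. the rows and columns of A
// listed in idx (the meaning of A[array([idx]).T, idx] in NumPy).
Dense select_rows_cols(const Dense &A, const std::vector<std::size_t> &idx) {
    Dense B(idx.size(), std::vector<double>(idx.size()));
    for (std::size_t i = 0; i < idx.size(); ++i) {
        for (std::size_t j = 0; j < idx.size(); ++j) {
            assert(idx[i] < A.size() && idx[j] < A[idx[i]].size());
            B[i][j] = A[idx[i]][idx[j]];
        }
    }
    return B;
}
```

For a 4x4 matrix with A[r][c] = 10r + c and indices {1, 3}, it returns {{11, 13}, {31, 33}}.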
Looking at the Armadillo documentation, I inferred the following:
arma::Row<int> pvpq(calc->pqpv);
arma::sp_mat J11 = arma::real(dS_dVa.submat(pvpq, pvpq));
where calc->pqpv is of type std::vector<int>.
However, GCC reports:
engine.h:2436: error: no matching function for call to ‘arma::SpMat<std::complex<double> >::submat(arma::Row<int>&, arma::Row<int>&)’
arma::sp_mat J11 = arma::real(dS_dVa.submat(pvpq, pvpq));
^
How do I fix this? Is it telling me that sparse matrices do not have a submat method?
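Regarding the compiler error: in Armadillo, the index-vector forms of submat (documented for dense matrices) take an arma::uvec, i.e. a column vector of unsigned arma::uword, so a Row<int> cannot match those overloads; whether sparse matrices accept index vectors at all may depend on the Armadillo version, which is worth checking against the installed documentation. Either way, the std::vector<int> coming from calc->pqpv has to become a vector of unsigned indices. A minimal plain C++ sketch of that conversion, assuming the indices must be strictly increasing (as the "must be sorted" requirement on rows in a CSC-based extraction implies); to_sorted_indices is a hypothetical helper name:

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <vector>

// Converts signed indices (e.g. calc->pqpv) into unsigned ones.
// dim is the size of the dimension being indexed (number of rows or columns).
// Throws if an index is negative, >= dim, or not larger than the previous one.
std::vector<std::size_t> to_sorted_indices(const std::vector<int> &in, std::size_t dim) {
    std::vector<std::size_t> out;
    out.reserve(in.size());
    for (int v : in) {
        if (v < 0 || static_cast<std::size_t>(v) >= dim)
            throw std::out_of_range("index out of range");
        const std::size_t u = static_cast<std::size_t>(v);
        if (!out.empty() && u <= out.back())
            throw std::invalid_argument("indices must be strictly increasing");
        out.push_back(u);
    }
    return out;
}
```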
After a while, I wrote my own function. It works directly on the internal CSC (compressed sparse column) structure.
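In CSC storage, which Armadillo exposes through the read-only SpMat members values, row_indices and col_ptrs, the non-zeros of column j occupy positions col_ptrs[j] up to (but not including) col_ptrs[j + 1] of values and row_indices, with row indices ascending within each column. The library-free sketch below mirrors that layout and walks one column; the Csc struct and column_entries are hypothetical names, not Armadillo API.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical minimal CSC container mirroring Armadillo's values / row_indices / col_ptrs.
struct Csc {
    std::size_t n_rows = 0, n_cols = 0;
    std::vector<double> values;            // non-zero values, stored column by column
    std::vector<std::size_t> row_indices;  // row index of each stored value
    std::vector<std::size_t> col_ptrs;     // size n_cols + 1; column j is [col_ptrs[j], col_ptrs[j + 1])
};

// Lists the (row, value) pairs stored in column j.
std::vector<std::pair<std::size_t, double>> column_entries(const Csc &A, std::size_t j) {
    assert(j < A.n_cols && A.col_ptrs.size() == A.n_cols + 1);
    std::vector<std::pair<std::size_t, double>> out;
    for (std::size_t k = A.col_ptrs[j]; k < A.col_ptrs[j + 1]; ++k)
        out.emplace_back(A.row_indices[k], A.values[k]);
    return out;
}
```

For the 3x3 matrix [[1, 0, 2], [0, 0, 3], [4, 0, 0]] the arrays are values = {1, 4, 2, 3}, row_indices = {0, 2, 0, 1} and col_ptrs = {0, 2, 2, 4}; column 1 is empty because col_ptrs[1] == col_ptrs[2].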
#include <armadillo>
#include <cstddef>
#include <vector>

/**
 * @brief sp_submatrix Extract a subset of rows and columns from a sparse matrix
 * @param A    Sparse matrix (real or complex)
 * @param rows Row indices to keep (must be sorted ascending, without duplicates)
 * @param cols Column indices to keep (any order)
 * @return Sparse matrix of size rows.size() x cols.size() with the indicated entries
 */
template <typename eT>
arma::SpMat<eT> sp_submatrix(const arma::SpMat<eT> &A,
                             const std::vector<std::size_t> &rows,
                             const std::vector<std::size_t> &cols) {
    const std::size_t n_rows = rows.size();
    const std::size_t n_cols = cols.size();
    // upper bound on the number of non-zeros in the selected columns
    std::size_t capacity = 0;
    for (std::size_t j : cols)
        capacity += A.col_ptrs[j + 1] - A.col_ptrs[j];
    arma::Col<eT> new_val(capacity);
    arma::uvec new_row_ind(capacity);
    arma::uvec new_col_ptr(n_cols + 1);
    new_col_ptr(0) = 0;
    std::size_t n = 0;  // number of non-zeros stored so far
    for (std::size_t p = 0; p < n_cols; ++p) {  // for every column in the cols vector
        const std::size_t j = cols[p];
        // k indexes the "values" and "row_indices" entries that belong to column j
        for (std::size_t k = A.col_ptrs[j]; k < A.col_ptrs[j + 1]; ++k) {
            // search row_indices[k] in rows; keep the value only if it is found
            for (std::size_t r = 0; r < n_rows; ++r) {
                if (A.row_indices[k] == rows[r]) {
                    new_val(n) = A.values[k];  // store the value
                    new_row_ind(n) = r;        // store the position of the row inside "rows"
                    ++n;
                    break;
                }
            }
        }
        new_col_ptr(p + 1) = n;  // column p of the result ends after n stored entries
    }
    // shrink the vectors to the actual number of stored elements
    new_val.resize(n);
    new_row_ind.resize(n);
    return arma::SpMat<eT>(new_row_ind, new_col_ptr, new_val, n_rows, n_cols);
}
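The linear search over rows costs up to rows.size() comparisons for every stored non-zero. A common alternative is to build a lookup table once that maps each original row to its position in rows, or marks it as dropped, so each non-zero is handled in constant time. The sketch below shows this on a hypothetical plain C++ CSC struct (same layout as Armadillo's values / row_indices / col_ptrs) rather than on arma::sp_mat; the names Csc and csc_submatrix are illustrative assumptions, and rows must still be sorted ascending.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Hypothetical CSC container (same layout as Armadillo's values / row_indices / col_ptrs).
struct Csc {
    std::size_t n_rows = 0, n_cols = 0;
    std::vector<double> values;
    std::vector<std::size_t> row_indices;
    std::vector<std::size_t> col_ptrs;  // size n_cols + 1
};

// Keeps the given rows (sorted ascending, no duplicates) and columns (any order).
Csc csc_submatrix(const Csc &A, const std::vector<std::size_t> &rows,
                  const std::vector<std::size_t> &cols) {
    const std::size_t dropped = std::numeric_limits<std::size_t>::max();
    std::vector<std::size_t> new_row(A.n_rows, dropped);  // original row -> new row
    for (std::size_t r = 0; r < rows.size(); ++r) {
        assert(rows[r] < A.n_rows && (r == 0 || rows[r - 1] < rows[r]));
        new_row[rows[r]] = r;
    }
    Csc B;
    B.n_rows = rows.size();
    B.n_cols = cols.size();
    B.col_ptrs.push_back(0);
    for (std::size_t j : cols) {
        assert(j < A.n_cols);
        for (std::size_t k = A.col_ptrs[j]; k < A.col_ptrs[j + 1]; ++k) {
            const std::size_t r = new_row[A.row_indices[k]];
            if (r != dropped) {  // row is kept: O(1) lookup instead of a search
                B.row_indices.push_back(r);
                B.values.push_back(A.values[k]);
            }
        }
        B.col_ptrs.push_back(B.values.size());
    }
    return B;
}
```

With Armadillo, the same loops can read A.col_ptrs, A.row_indices and A.values as sp_submatrix does and hand the resulting arrays to the same CSC constructor.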