[英]Eigen find set difference between two matrices
我想獲得兩個本征矩陣之間的差集。 代碼:
void diffMatrix(
MatrixXi &M1, // First Matrix
MatrixXi &M2, // Second Matrix
MatrixXi &M3, // Matrix set difference
VectorXi &I3 // Matrix set difference indices
)
{
// find rows in first matrix that aren't in second matrix
// cols of M1 = M2
assert(M1.cols() == M2.cols());
M3.resize(M1.rows(), M1.cols());
I3.resize(M1.rows());
bool m2r_nonex;
size_t k = 0;
// get M1 rows
for (size_t i = 0; i < M1.rows(); i++)
{
m2r_nonex = true;
auto m1r = M1.row(i);
// NOTE: this is slow
// check M1 row is in M2
for (size_t j = 0; j < M2.rows(); j++)
{
auto m2r = M2.row(j);
if (m1r == m2r)
m2r_nonex = false;
}
// if it's not in m2, add it to M3
if (m2r_nonex)
{
M3.row(k) = m1r;
I3(k) = i;
k++;
}
}
M3.conservativeResize(k, NoChange);
I3.conservativeResize(k, NoChange);
}
MatrixXi M1, M2, M3;
VectorXi I3;
M1.resize(3, 3);
M2.resize(2, 3);
M1 << 0, 0, 0, 1, 1, 1, 2, 2, 2;
M2 << 1, 1, 1, 2, 2, 2;
diffMatrix(M1, M2, M3, I3);
===========================================
M3 (Rows: 1 Cols: 3)
===========================================
[[0, 0, 0]]
提供的代碼當然可以工作,但是速度很慢。 理想情況下,人們會將內部 for 循環替換為一些更緊湊的表達式,也許可以在單個語句中計算 M1 行在 M2 行中的所有出現次數……這可能嗎?
=====編輯=====
(根據 Homer512 的回答)
是的,它有效。 了解性能提升...加載 10000 條記錄后:
MatrixXi M1, M2, M3;
VectorXi I3;
size_t rows = 10000;
M1.resize(rows, 3);
M2.resize(rows, 3);
for (size_t i = 0; i < rows; i++)
{
M1(i,0) = i;
M1(i,1) = i;
M1(i,2) = i;
M2(i,0) = i + 1;
M2(i,1) = i + 1;
M2(i,2) = i + 1;
}
第一種方法建議經過的時間是:43.826125 秒; 提議的第二種方法經過的時間是:0.017632 秒
它快了幾個數量級......謝謝你。
這是一個應該相對有效的版本。
我們從轉置輸入矩陣開始。 您想要比較行。 但是Eigen 組織它的矩陣列優先,這意味着一行中的連續元素不會連續存儲在 memory 中。這使得你做的每行都很慢,而你做的每列很快。 理想情況下,您希望跳過此步驟並簡單地從正確的方向開始。 或者切換到行主矩陣,如鏈接的 Eigen 文檔中所述。
const Eigen::MatrixXi left_transp = M1.transpose();
const Eigen::MatrixXi right_transp = M2.transpose();
現在是節目的真正明星:我們想使用 hash 集合來檢查集合中是否有元素。 我們有一個可以使用的 hash 集合,即std::unordered_set
,但我們需要定義一個合適的鍵來表示向量,理想情況下無需復制數據。
值得慶幸的是 C++17 介紹了std::string_view
和相關的類型定義,包括std::u32string_view
; 非常適合二進制 int32 數據。 這些類型帶有std::hash
專業化和比較。
此時我們可以為M2中的所有條目建一個hash表。
std::unordered_set<std::u32string_view> right_set;
right_set.reserve(static_cast<std::size_t>(right_transp.cols()));
const std::size_t rows = static_cast<std::size_t>(right_transp.rows());
for(auto col: right_transp.colwise()) {
static_assert(sizeof(char32_t) == sizeof(int));
assert(col.innerStride() == 1);
const char32_t* ptr = reinterpret_cast<const char32_t*>(col.data());
right_set.emplace(ptr, rows);
}
在下一步中,我們可以對 M1 執行相同的操作,並在 hash 表中搜索重復項。
const Eigen::Index left_n = left_transp.cols();
I3.resize(left_n);
Eigen::Index count = 0;
for(Eigen::Index i = 0; i < left_n; ++i) {
const char32_t* ptr = reinterpret_cast<const char32_t*>(
left_transp.col(i).data());
const std::u32string_view view{ ptr, rows };
if(! right_set.count(view))
I3[count++] = i;
}
I3.conservativeResize(count);
剩下的就是使用索引來構建 M3。 同樣,最好不要移調。
M3.resize(static_cast<Eigen::Index>(rows), count);
for(Eigen::Index i = 0; i < count; ++i)
M3.col(i) = left_transp.col(I3[i]);
M3.transposeInPlace();
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.