[英]C++ Bitset algorithm
我得到一個填充有 1 或 0 的 nxn 網格。我想計算角塊全為 1 的子網格的數量。 我的解決方案遍歷所有行對並計算匹配 1 的數量,然后使用公式 numOf1s * (numOf1s-1)/2 並添加到結果中。 但是,當我在https://cses.fi/problemset/task/2137上提交我的解決方案時,在 n = 3000 的輸入上沒有 output (可能是由某些錯誤引起的)。 錯誤可能是什么?
int main()
{
int n; cin>> n;
vector<bitset<3000>> grid(n);
for(int i=0;i<n;i++){
cin >> grid[i];
}
long result = 0;
for(int i=0;i<n-1;i++){
for(int j=i+1;j<n;j++){
int count = (grid[i]&grid[j]).count();
result += (count*(count-1))/2;
}
}
cout << result;
}
此解決方案將導致超出時間限制。 在最壞的情況下,bitset::count() 是 O(n)。 您的代碼的總復雜度為 O(n^3)。 在最壞的情況下,操作數將是 3000^3 > 10^10,這太大了。
我不確定這個解決方案是你能想出的最好的解決方案,但它基於原始解決方案,並帶有 bitset 的自制替代方案。 這允許我使用 64 位塊,並使用快速popcnt()
。 硬件版本會更好,因為它可以與 AVX 寄存器一起使用,但這應該更便攜,並且可以在cses.fi
上運行。 function count_common()
基本上不是生成一個長的交集位集然后計算一個的數量,而是制作一個交集並立即使用它來計算一個。
stream 提取器可能會得到改進,從而節省更多時間。
#include <iostream>
#include <array>
#include <cstdint>
#include <climits>
uint64_t popcnt(uint64_t v) {
v = v - ((v >> 1) & (uint64_t)~(uint64_t)0 / 3);
v = (v & (uint64_t)~(uint64_t)0 / 15 * 3) + ((v >> 2) & (uint64_t)~(uint64_t)0 / 15 * 3);
v = (v + (v >> 4)) & (uint64_t)~(uint64_t)0 / 255 * 15;
uint64_t c = (uint64_t)(v * ((uint64_t)~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT;
return c;
}
struct line {
uint64_t cells_[47] = { 0 }; // 3000/64 = 47
uint64_t& operator[](int pos) { return cells_[pos]; }
const uint64_t& operator[](int pos) const { return cells_[pos]; }
};
uint64_t count_common(const line& a, const line& b) {
uint64_t u = 0;
for (int i = 0; i < 47; ++i) {
u += popcnt(a[i] & b[i]);
}
return u;
}
std::istream& operator>>(std::istream& is, line& ln) {
is >> std::ws;
int pos = 0;
uint64_t val = 0;
while (true) {
char ch = is.get();
if (is && ch == '\n') {
break;
}
if (ch == '1') {
val |= 1LL << (63 - pos % 64);
}
if ((pos + 1) % 64 == 0) {
ln[pos / 64] = val;
val = 0;
}
++pos;
}
if (pos % 64 != 0) {
ln[pos / 64] = val;
}
return is;
}
struct grid {
int n_;
std::array<line, 3000> data_;
line& operator[](int r) {
return data_[r];
}
};
std::istream& operator>>(std::istream& is, grid& g) {
is >> g.n_;
for (int r = 0; r < g.n_; ++r) {
is >> g[r];
}
return is;
}
int main()
{
grid g;
std::cin >> g;
uint64_t count = 0;
for (int r1 = 0; r1 < g.n_; ++r1) {
for (int r2 = r1 + 1; r2 < g.n_; ++r2) {
uint64_t n = count_common(g[r1], g[r2]);
count += n * (n - 1) / 2;
}
}
std::cout << count << '\n';
return 0;
}
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.