簡體   English   中英

C++ 位集算法

[英]C++ Bitset algorithm

我得到一個填充有 1 或 0 的 nxn 網格。我想計算角塊全為 1 的子網格的數量。 我的解決方案遍歷所有行對並計算匹配 1 的數量,然后使用公式 numOf1s * (numOf1s-1)/2 並添加到結果中。 但是,當我在https://cses.fi/problemset/task/2137上提交我的解決方案時,在 n = 3000 的輸入上沒有 output (可能是由某些錯誤引起的)。 錯誤可能是什么?

int main()
    {
    
        int n; cin>> n;
        vector<bitset<3000>> grid(n);
        for(int i=0;i<n;i++){
            cin >> grid[i];
        }
        long result = 0;
        for(int i=0;i<n-1;i++){
            for(int j=i+1;j<n;j++){
                int count = (grid[i]&grid[j]).count();
                result += (count*(count-1))/2;
            }
        }
        cout << result;
    }

此解決方案將導致超出時間限制。 在最壞的情況下,bitset::count() 是 O(n)。 您的代碼的總復雜度為 O(n^3)。 在最壞的情況下,操作數將是 3000^3 > 10^10,這太大了。

我不確定這個解決方案是你能想出的最好的解決方案,但它基於原始解決方案,並帶有 bitset 的自制替代方案。 這允許我使用 64 位塊,並使用快速popcnt() 硬件版本會更好,因為它可以與 AVX 寄存器一起使用,但這應該更便攜,並且可以在cses.fi上運行。 function count_common()基本上不是生成一個長的交集位集然后計算一個的數量,而是制作一個交集並立即使用它來計算一個。

stream 提取器可能會得到改進,從而節省更多時間。

#include <iostream>
#include <array>
#include <cstdint>
#include <climits>
 
uint64_t popcnt(uint64_t v) {
    v = v - ((v >> 1) & (uint64_t)~(uint64_t)0 / 3);
    v = (v & (uint64_t)~(uint64_t)0 / 15 * 3) + ((v >> 2) & (uint64_t)~(uint64_t)0 / 15 * 3);
    v = (v + (v >> 4)) & (uint64_t)~(uint64_t)0 / 255 * 15;
    uint64_t c = (uint64_t)(v * ((uint64_t)~(uint64_t)0 / 255)) >> (sizeof(uint64_t) - 1) * CHAR_BIT;
    return c;
}
 
struct line {
    uint64_t cells_[47] = { 0 }; // 3000/64 = 47
 
    uint64_t& operator[](int pos) { return cells_[pos]; }
    const uint64_t& operator[](int pos) const { return cells_[pos]; }
};
 
uint64_t count_common(const line& a, const line& b) {
    uint64_t u = 0;
    for (int i = 0; i < 47; ++i) {
        u += popcnt(a[i] & b[i]);
    }
    return u;
}
 
std::istream& operator>>(std::istream& is, line& ln) {
    is >> std::ws;
    int pos = 0;
    uint64_t val = 0;
    while (true) {
        char ch = is.get();
        if (is && ch == '\n') {
            break;
        }
        if (ch == '1') {
            val |= 1LL << (63 - pos % 64);
        }
        if ((pos + 1) % 64 == 0) {
            ln[pos / 64] = val;
            val = 0;
        }
        ++pos;
    }
    if (pos % 64 != 0) {
        ln[pos / 64] = val;
    }
    return is;
}
 
struct grid {
    int n_;
    std::array<line, 3000> data_;
 
    line& operator[](int r) {
        return data_[r];
    }
};
 
std::istream& operator>>(std::istream& is, grid& g) {
    is >> g.n_;
    for (int r = 0; r < g.n_; ++r) {
        is >> g[r];
    }
    return is;
}
 
int main()
{
    grid g;
    std::cin >> g;
 
    uint64_t count = 0;
    for (int r1 = 0; r1 < g.n_; ++r1) {
        for (int r2 = r1 + 1; r2 < g.n_; ++r2) {
            uint64_t n = count_common(g[r1], g[r2]);
            count += n * (n - 1) / 2;
        }
    }
    std::cout << count << '\n';
    return 0;
}

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM