
Can someone please explain this bit manipulation code to me?

I am new to competitive programming. I recently took part in a Codeforces Div. 3 contest. Even though I solved problem C, I found this code from one of the top programmers really interesting. I have been trying to understand it, but I seem to be too much of a beginner to follow it without someone explaining it to me.

Here is the code.

#include <algorithm>
#include <iostream>
#include <string>
using namespace std;

int main(){
    int S;
    cin >> S;
    int ans = 1e9;
 
    for (int mask = 0; mask < 1 << 9; mask++) {
        int sum = 0;
        string num;
 
        for (int i = 0; i < 9; i++)
            if (mask >> i & 1) {
                sum += i + 1;
                num += char('0' + (i + 1));
            }
 
        if (sum != S)
            continue;
 
        ans = min(ans, stoi(num));
    }
 
    cout << ans << '\n';
}

The problem is to find the minimum number whose digits sum to a given number S, such that every digit in the result is unique.

E.g. S = 20, Ans = 389 (3+8+9 = 20)

The mask is 9 bits long, and each bit represents one of the digits 1-9. So the loop counts from 0 up to 511, stopping before 512 (1 << 9). Each mask value corresponds to one possible solution: a set of distinct digits. Find every set that sums to the target value, and remember the smallest resulting number.

For example, if mask is 235, in binary it is

011101011   // bit representation of 235
987654321   // corresponding digit

==> 124678  // number for this example: "digits" with a 1-bit above
            // and with lowest digits to the left
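To check the decoding by hand, here is a minimal sketch of a helper (the name digitsFromMask is hypothetical, not from the original code) that performs the same bit-to-digit conversion as the inner loop:

```cpp
#include <string>

// Hypothetical helper: decode a 9-bit mask into its digit string.
// Bit i (counting from 0 at the least significant end) stands for digit i + 1.
// Scanning i upward emits the digits in ascending order.
std::string digitsFromMask(int mask) {
    std::string num;
    for (int i = 0; i < 9; i++) {
        if ((mask >> i) & 1) {
            num += char('0' + (i + 1));  // '0' + 1 == '1', ..., '0' + 9 == '9'
        }
    }
    return num;
}
```

Calling digitsFromMask(235) should give "124678", matching the example above.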

There are a few observations:

  • you want the smallest digits in the most significant places of the result, so the digits appear in ascending order: a 1 always comes before any larger digit.

  • there is no need for a zero in the answer; it doesn't affect the sum and only makes the result larger

This loop converts each set bit into its corresponding digit and adds that digit both to the sum and to "num", the string that is later converted to a number and compared.

    for (int i = 0; i < 9; i++)
        if (mask >> i & 1) {  // check bit i in the mask
            sum += i + 1;     // numeric sum
            num += char('0' + (i + 1));  // output as a string
        }

(mask >> i) shifts bit i into the lowest position, and & 1 then clears every bit except that lowest one. The result is either 0 or 1: the value of bit i.
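If the precedence in mask >> i & 1 looks surprising: >> binds tighter than & in C++, so the expression already means (mask >> i) & 1. A hypothetical helper with the parentheses spelled out:

```cpp
// Hypothetical helper: is bit i of mask set?
// `mask >> i & 1` already parses as `(mask >> i) & 1` because shifts have
// higher precedence than bitwise AND; the parentheses only make that explicit.
bool bitIsSet(int mask, int i) {
    return ((mask >> i) & 1) != 0;
}
```

For 235 (binary 011101011), bits 0, 1, 3, 5, 6 and 7 are set, while bits 2, 4 and 8 are not.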

The num could have been accumulated in an int instead of a string (initialized to 0, then for each digit: multiply by 10, then add the digit), which is more efficient, but they didn't.
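A sketch of that integer-accumulating variant, assuming the same brute force over all 512 masks. minVariedNumber is a made-up name, and it returns the original 1e9 sentinel when no set of digits matches:

```cpp
// Hypothetical sketch: same brute force, but `num` is built as an int.
// Returns 1000000000 (the original `ans = 1e9` sentinel) if no set of
// distinct digits 1-9 sums to S.
int minVariedNumber(int S) {
    int ans = 1000000000;
    for (int mask = 0; mask < (1 << 9); mask++) {
        int sum = 0;
        int num = 0;
        for (int i = 0; i < 9; i++) {
            if ((mask >> i) & 1) {
                sum += i + 1;
                num = num * 10 + (i + 1);  // append digit i + 1 on the right
            }
        }
        if (sum == S && num < ans) ans = num;  // keep the smallest match
    }
    return ans;
}
```

minVariedNumber(20) should return 389. As a side effect, this version never calls stoi, which throws std::invalid_argument on an empty string (what the original would hit for mask 0 if S were 0).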

The way to understand what a snippet of code is doing is to A) understand what it does at a macro-level, which you have done and B) go through each line and understand what it does, then C) work your way backward and forward from what you know, gaining progress a bit at a time. Let me show you what I mean using your example.

Let's start by seeing, broadly (top-down) what the code is doing:

int main(){

    // Set up some initial state
    int S;
    cin >> S;
    int ans = 1e9;
 
    // Create a mask, that's neat, we'll look at this later.
    for (int mask = 0; mask < 1 << 9; mask++) {

        // Loop state
        int sum = 0;
        string num;
 
        // This loop seems to come up with candidate sums, somehow.
        for (int i = 0; i < 9; i++)
            if (mask >> i & 1) {
                sum += i + 1;
                num += char('0' + (i + 1));
            }
 
        // Stop if the sum we've found isn't the target
        if (sum != S)
            continue;
 
        // Keep track of the smallest value we've seen so far
        ans = min(ans, stoi(num));
    }
 
    // Print out the smallest value
    cout << ans << '\n';
}

So, going from what we knew about the function at a macro level, we've found that there are really only two obscure spots: the two loops. (If anything outside of those is confusing to you, please clarify.)

So now let's go bottom-up, line by line through those loops.

// The number 9 appears often; it probably represents the digits 1-9.
// The syntax 1 << 9 means 1 shifted left by 9 bits.
// Each left shift is a multiplication by 2.
// So this is equal to 1 * (2^9), or 512.
// mask will be 9 bits long, and every combination of bits will be covered.
for (int mask = 0; mask < 1 << 9; mask++) {

  // Here's that number 9 again.
  // This time, we're looping from 0 to 8.
  for (int i = 0; i < 9; i++) {

    // The syntax mask >> i shifts mask down by i bits.
    // This is like dividing mask by 2^i.
    // The syntax & 1 means get just the lowest bit.
    // Together, this returns true if mask's ith bit is 1, false if it's 0.
    if (mask >> i & 1) {

      // sum is the value of summing the digits together
      // So the mask seems to be telling us which digits to use.
      sum += i + 1;

      // num is the string representation of the number whose sum we're finding.
      // '0'+(i+1) is a way to convert numbers 1-9 into characters '1'-'9'.
      num += char('0' + (i + 1));
    }
  }
}

Now we know what each line does, but not yet why the code works. To get there, we have to meet in the middle: combine our overall understanding of what the code does with the low-level understanding of the specific lines.

We know that this code considers at most 9 digits. Why? Because there are only 9 distinct non-zero digits (1, 2, 3, 4, 5, 6, 7, 8, 9), and the problem says the digits have to be unique.
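To see concretely that 512 subsets are all there is, and that every target from 1 to 45 (= 1 + 2 + ... + 9) is reachable, one could tally the masks by their digit sum. A sketch, with countMasksBySum as a hypothetical name:

```cpp
#include <array>

// Hypothetical helper: counts[s] = number of the 2^9 = 512 digit subsets
// whose digits sum to s. The largest possible sum is 1 + 2 + ... + 9 = 45,
// so indices 0..45 cover every case.
std::array<int, 46> countMasksBySum() {
    std::array<int, 46> counts{};  // value-initialized to all zeros
    for (int mask = 0; mask < (1 << 9); mask++) {
        int sum = 0;
        for (int i = 0; i < 9; i++) {
            if ((mask >> i) & 1) sum += i + 1;
        }
        counts[sum]++;
    }
    return counts;
}
```

For example, counts[3] is 2 ({3} and {1, 2}) and counts[45] is 1 (all nine digits).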

Where's zero? Zero doesn't contribute to the sum. A number like 209 will always be larger than its counterparts without the zero, 29 or 92. So we don't even look at zero.

We also know that this code never varies the order of the digits: it always builds them in ascending order. If digit 2 is in the number, it always comes before digit 5. In other words, the code never looks at the number 52, only 25. Why? Because the smallest anagram (a number with the same digits in a different order) always starts with the smallest digit, then the second smallest, and so on.

So, putting this all together:

#include <algorithm>
#include <iostream>
#include <string>
using namespace std;

int main(){
    // Read in the target sum S
    int S;
    cin >> S;

    // Set ans to a value higher than anything possible:
    // the largest number with distinct non-zero digits is 987654321, which is below 1e9.
    int ans = 1e9;
 
    // Go through each combination of digits, from 1 to 9.
    for (int mask = 0; mask < 1 << 9; mask++) {
        int sum = 0;
        string num;
 
        for (int i = 0; i < 9; i++)
            // If this combination includes the digit i+1,
            // Then add it to the sum, and append to the string representation.
            if (mask >> i & 1) {
                sum += i + 1;
                num += char('0' + (i + 1));
            }
 
        // If this combination does not yield the right sum, try the next combination.
        if (sum != S)
            continue;
 
        // If this combination does yield the right sum, 
        // see if it's smaller than our previous smallest.
        ans = min(ans, stoi(num));
    }
 
    // Print the smallest combination we found.
    cout << ans << '\n';
}

I hope this helps!

The for loop iterates over all binary numbers of up to 9 bits (0 to 511) and turns each one into a string of decimal digits: if bit n is on, the digit n + 1 is appended to the decimal number.

Generating the numbers this way ensures that the digits are unique and that zero never appears.

But as @Welbog mentions in comments this solution to the problem is way more complicated than it needs to be. The following will be faster, and I think is clearer:

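int smallest_number_with_unique_digits_summing_to_s(int s) {
    int tens = 1;      // place value of the next digit to write
    int answer = 0;
    // Try digits from largest to smallest; large digits go into the
    // least significant places, so the result has as few digits as possible.
    for (int n = 9; n > 0 && s > 0; --n) {
        if (s >= n) {
            answer += n * tens;  // put digit n in the current place
            tens *= 10;
            s -= n;
        }
    }
    return answer;
}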
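One way to gain confidence that the greedy version agrees with the brute force is to compare them for every target from 1 to 45 (the largest possible digit sum). A sketch, where bruteForce and greedyMatchesBruteForce are hypothetical helpers:

```cpp
// Greedy version from the answer above.
int smallest_number_with_unique_digits_summing_to_s(int s) {
    int tens = 1;
    int answer = 0;
    for (int n = 9; n > 0 && s > 0; --n) {
        if (s >= n) {
            answer += n * tens;
            tens *= 10;
            s -= n;
        }
    }
    return answer;
}

// Hypothetical brute force: try all 2^9 digit subsets, keep the smallest match.
int bruteForce(int s) {
    int best = 1000000000;
    for (int mask = 0; mask < (1 << 9); mask++) {
        int sum = 0, num = 0;
        for (int i = 0; i < 9; i++) {
            if ((mask >> i) & 1) {
                sum += i + 1;
                num = num * 10 + (i + 1);
            }
        }
        if (sum == s && num < best) best = num;
    }
    return best;
}

// Hypothetical check: both methods agree for every reachable target 1..45.
bool greedyMatchesBruteForce() {
    for (int s = 1; s <= 45; s++) {
        if (smallest_number_with_unique_digits_summing_to_s(s) != bruteForce(s)) return false;
    }
    return true;
}
```

Note that the greedy function assumes s is at most 45; for larger s it returns a partial answer instead of signalling that none exists.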

Here is a quick explanation of how the code works.

First you need to know which digits sum to S. Since each digit is unique, you can assign each one a bit in a binary number, like this:

Bit number    Digit
0             1
1             2
2             3
...
8             9

So you can check every number below 2^9 = 512 and see whether the digits represented by its set bits sum to S. For example, with S = 17:

384 -> 1 1000 0000 -> bit 8 = digit 9 and bit 7 = digit 8 -> sum of digits = 8+9=17
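The binary form in this example can be reproduced with std::bitset. A sketch; maskToBinary and digitSum are hypothetical helper names:

```cpp
#include <bitset>
#include <string>

// Hypothetical helper: 9-bit pattern of mask, most significant bit first,
// so bit 8 (digit 9) is the leftmost character.
std::string maskToBinary(int mask) {
    return std::bitset<9>(mask).to_string();
}

// Hypothetical helper: sum of the digits selected by mask (bit k = digit k + 1).
int digitSum(int mask) {
    int sum = 0;
    for (int k = 0; k < 9; k++) {
        if ((mask >> k) & 1) sum += k + 1;
    }
    return sum;
}
```

maskToBinary(384) gives "110000000" and digitSum(384) gives 17, as in the example.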

Once you know the sum is correct, you can build the number from the digits you found, smallest digit first.
