Count Inversions with merge sort in C++

I'm working on my first few algorithms to build my C++ skills and am currently coding up a method of counting inversions with merge sort. I've managed to get a working merge sort together but I'm having a bit of trouble keeping track of the number of inversions. Any ideas of where to go from here? How can I keep track of the number of inversions on a recursive algorithm like this? Additionally I've seen a couple different implementations of this in my internet travels and have found most people stray away from the std::vector method, any idea why? Thanks for any help, my code is below!

#include <iostream>
#include <math.h>
#include <vector>

using namespace std;

vector<int> print(vector<int> input){

    for(int i=0; i<input.size(); i++){
        cout<<input[i]<<",";
    }
    cout<<endl;
    return input;
}


vector<int> merge(vector<int> left,vector<int> right){

    //set up some variables
    vector<int> output;
    int i=0;
    int j=0;

    //loop through the lists and merge
    while(i<left.size() && j<right.size()){

        //push the smallest of the two to the vector output
        if(left[i]<=right[j]){
            output.push_back(left[i]);
            i+=1;
        }
        else if(left[i]>right[j]){
            output.push_back(right[j]);
            j+=1;
        }
    }

    //push the remnants of the vectors to output
    for(i; i<left.size(); i++){
        output.push_back(left[i]);
    }

    for(j; j<right.size(); j++){
        output.push_back(right[j]);
    }

    return output;
}//end merge

vector<int> merge_sort(vector<int> input){
    //check the size of the vector
    if(input.size()<2){
        return input;
    }

    else{

    //init new vectors
    vector<int> left;
    vector<int> right;
    vector<int> output;

    //find the middle of the input vector
    int middle=(input.size())/2;

    //build the left vector
    for(int i=0; i<middle; i++){
        left.push_back(input[i]);
    }

    //build the right vector
    for(int i=middle; i<input.size(); i++){
        right.push_back(input[i]);
    }

    //make recursive calls
    left=merge_sort(left);
    right=merge_sort(right);

    //call merge
    output=merge(left,right);


    return output;
    }
}


int main()
{
    vector<int> output;
    vector<int> input;

    input.push_back(2);
    input.push_back(1);
    input.push_back(10);
    input.push_back(4);

    output=merge_sort(input);

    print(output);


}

Good news: counting inversions is pretty easy from here.

Think about your "merge" method. Every time you put an element from the left vector into output, you are not changing its position relative to elements from the right. On the other hand, every time you add an element from the right vector, you are putting it "before" all the elements still to be processed in the left vector, when it was previously "after" them. Each of those pairs is one inversion, so that step accounts for (left.size() - i) inversions.
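
To make the counting rule concrete, here is a minimal sketch that only counts the inversions between two halves without building the merged output. The helper name count_cross_inversions is my own (it is not part of your code), and it assumes both inputs are already sorted in ascending order.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper for illustration: counts pairs (l, r) with l taken from
// `left`, r taken from `right`, and l > r.
// Assumes both vectors are already sorted in ascending order.
long long count_cross_inversions(const std::vector<int>& left,
                                 const std::vector<int>& right) {
    long long count = 0;
    std::size_t i = 0, j = 0;
    while (i < left.size() && j < right.size()) {
        if (left[i] <= right[j]) {
            // the left element keeps its position relative to the right half
            ++i;
        } else {
            // right[j] jumps ahead of left[i], left[i+1], ...: that is
            // left.size() - i elements that used to come before it
            count += static_cast<long long>(left.size() - i);
            ++j;
        }
    }
    return count;
}
```

For left = {2, 10} and right = {1, 4}, pushing 1 adds 2 (it passes both 2 and 10) and pushing 4 adds 1 (it passes 10), for a total of 3.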

You can prove this easily by induction if needed.
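
If you would rather check the result than prove it, a brute-force counter is a handy reference to compare against. This is only a sketch for testing on small inputs: it is O(n^2), and the name count_inversions_brute is my own.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference implementation: checks every pair (a, b) with a < b
// and counts those where the earlier element is strictly larger.
long long count_inversions_brute(const std::vector<int>& v) {
    long long count = 0;
    for (std::size_t a = 0; a < v.size(); ++a) {
        for (std::size_t b = a + 1; b < v.size(); ++b) {
            if (v[a] > v[b]) {
                ++count;
            }
        }
    }
    return count;
}
```

For your input {2, 1, 10, 4} this gives 2, namely the pairs (2, 1) and (10, 4). Any merge-sort-based count should agree with it.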

So the answer is simply: pass an int* to your merge method, and increase the value it points to by (left.size() - i) every time you push an element from the right vector.


EDIT: Working code sample

#include <iostream>
#include <vector>
// removed useless dependency math.h

using namespace std;

// void type -> does not return anything
void print (vector<int> input) {
    // range-based for loop (since C++11)
    // no braces -> only one instruction in for loop
    for(int i : input)
        cout << i << ",";
}

vector<int> merge (vector<int> left, vector<int> right, int * inv_count) {
    vector<int> output;
    // multiple variable definition of the same type
    int i=0, j=0;

    // spaces around "<", after "while", before "{" for readability
    while (i < left.size() && j < right.size()) {

        // one-instruction trick again
        if (left[i] <= right[j])
            // i++ is evaluated to <previous value of i> and then increments i
            // this is strictly equivalent to your code, but shorter
            // check the difference with ++i
            output.push_back(left[i++]);
        // else because the two conditions were complementary
        else {
            output.push_back(right[j++]);
            // adds to the int that inv_count points to (the pointer itself is unchanged)
            *inv_count += (left.size() - i);
        }
    }

    // first field of for omitted because there is no need to initialize i
    for(; i < left.size(); i++)
        output.push_back(left[i]);

    for(; j < right.size(); j++)
        output.push_back(right[j]);

    return output;
}

vector<int> merge_sort (vector<int> input, int * inv_count) {
    // no-braces-idiom again
    // spaces around "<" and after "if" for readability
    if (input.size() < 2)
        return input;

    // no need for else keyword because of the return

    // multiple variable definition
    vector<int> left, right;

    int middle = input.size() / 2;

    // one-instruction for loop
    for(int i=0; i < middle; i++)
        left.push_back(input[i]);

    for(int i=middle; i < input.size(); i++)
        right.push_back(input[i]);

    // no need for intermediate variable
    return merge( merge_sort(left, inv_count),
                  merge_sort(right, inv_count),
                  inv_count);
}

// consistent convention : brace on the same line as function name with a space
int main () {
    // vector initialization (valid only since C++ 11)
    vector<int> input = {2, 1, 10, 4, 42, 3, 21, 7};

    int inv_count = 0;

    // No need for intermediate variables again, you can chain functions
    print( merge_sort(input, &inv_count) );

    // The value inv_count was modified although not returned
    cout << "-> " << inv_count << " inversions" << endl;
}

I modified your code to include a few usual C++ idioms. Because you used the C++14 tag, I also used tricks available only since C++11. I do not recommend using all of these tricks everywhere; they are included here because it is a good learning experience.

I suggest you read about pointers before diving deeper into C++.
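
As a starting point, the pointer usage in the sample comes down to two operators: &inv_count takes the address of a variable, and *inv_count reaches the int stored at that address. A minimal sketch, with helper names (add_to, add_to_ref) that are my own, and with a reference version added as a common alternative rather than something the code above uses:

```cpp
// Hypothetical helper: adds `amount` to the int that `counter` points to.
// The caller's variable changes because the function receives its address.
void add_to(int* counter, int amount) {
    *counter += amount;  // *counter is the pointed-to int, not the pointer
}

// Same effect with a reference: no & at the call site, no * in the body.
void add_to_ref(int& counter, int amount) {
    counter += amount;
}
```

With int c = 0; the call add_to(&c, 3) leaves c equal to 3, and a following add_to_ref(c, 2) leaves it equal to 5.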

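Also note that this code is in no way optimal: the vectors are passed and returned by value, and new left, right and output vectors are built at every level of the recursion. None of that is needed here; plain arrays, or index ranges over a single buffer, would be enough. But I'll leave this for another time.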
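
For the curious, here is one possible shape for a leaner version. It is a sketch under my own assumptions, not the only way to do it: it sorts index ranges of one vector in place, reuses a single scratch buffer, and returns the count as a long long instead of writing through a pointer. The names sort_and_count and count_inversions are hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Sorts data[lo, hi) in place and returns the number of inversions in that
// range. `scratch` must be at least as large as `data`.
long long sort_and_count(std::vector<int>& data, std::vector<int>& scratch,
                         std::size_t lo, std::size_t hi) {
    if (hi - lo < 2)
        return 0;  // zero or one element: nothing to sort or count

    std::size_t mid = lo + (hi - lo) / 2;
    long long count = sort_and_count(data, scratch, lo, mid)
                    + sort_and_count(data, scratch, mid, hi);

    // merge data[lo, mid) and data[mid, hi) into scratch[lo, hi)
    std::size_t i = lo, j = mid, k = lo;
    while (i < mid && j < hi) {
        if (data[i] <= data[j]) {
            scratch[k++] = data[i++];
        } else {
            // data[j] passes the mid - i elements left in the first half
            count += static_cast<long long>(mid - i);
            scratch[k++] = data[j++];
        }
    }
    while (i < mid) scratch[k++] = data[i++];
    while (j < hi)  scratch[k++] = data[j++];

    // copy the merged range back so the caller sees it sorted
    for (k = lo; k < hi; ++k)
        data[k] = scratch[k];
    return count;
}

// Takes the vector by value on purpose, so the caller's data is not reordered.
long long count_inversions(std::vector<int> data) {
    std::vector<int> scratch(data.size());
    return sort_and_count(data, scratch, 0, data.size());
}
```

This still uses std::vector for storage, but only two of them exist for the whole run, which is the point about intermediate vectors. A long long counter also leaves more headroom than int, since a large input can have roughly n*n/2 inversions.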
