
3x3 magic square, solved recursively

I'm trying to find all possible solutions to the 3x3 magic square (the numbers 1 to 9 arranged so that every row, column and diagonal has the same sum). There should be exactly 8 solutions.

My code finds all of them, but with a lot of repeats, and I'm having a hard time tracing the recursive steps to see where the repeats come from.

// This program finds all solutions to the magic square for a 3x3
// square, where each column, row and diagonal has the same sum

#include <iostream>
using namespace std;
#define SQUARE_SIZE 9

int anyLine = 0;
int currLine = 0;
int numSolutions = 0;


// swap two values in the square.
void swap(int arr[], int idxa, int idxb)
{
    int tmp = arr[idxa];
    arr[idxa] = arr[idxb];
    arr[idxb] = tmp;
}

void printArray(int arr[])
{
    for (int i = 0; i < SQUARE_SIZE; i++)
    {
        cout << arr[i] << " ";
        if ((i + 1) % 3 == 0)
            cout << endl;
    }
    cout << endl;
}

// This function tests whether we have a "good" arrangement of numbers,
// i.e. the sum of each row, column and diagonal is equal

bool checkArr(int arr[])
{
    anyLine = arr[0] + arr[1] + arr[2];
    currLine = 0;
    for (int i = 0; i < SQUARE_SIZE; i++)
    {
        currLine += arr[i];
        if ((i + 1) % 3 == 0)
        {
            if (currLine != anyLine)
                return false;
            currLine = 0;
        }
    }

    // check vertically
    for (int col = 0; col < 3; col++)
    {
        for (int row = 0; row < 3; row++)
        {
            currLine += arr[col + 3 * row];
        }

        if (currLine != anyLine)
            return false;

        currLine = 0;
    }

    // check the diagonals
    if ((arr[2] + arr[4] + arr[6]) != anyLine)
        return false;

    if ((arr[0] + arr[4] + arr[8]) != anyLine)
        return false;

    return true;
}

void solve(int arr[], int pos)
{
    if (pos == 8)
    {
        if (checkArr(arr))
        {
            printArray(arr);
            numSolutions++;
        }
    } else 
    {
        for (int i = 0; i < 9; i++)
        {
            if (i == pos) continue;

            if (checkArr(arr))
            {
                printArray(arr);
                numSolutions++;
            }
            swap(arr, pos, i);
            solve(arr, pos + 1);
        }
    }
}

int main()
{
    int arr[SQUARE_SIZE] = { 1, 2, 3, 4, 5, 6, 7, 8, 9 };

    solve(arr, 0);
    cout << "number of solutions is: " << numSolutions << endl;

    return 0;
}
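A short derivation shows why `checkArr` can compare every line against the first row's sum (`anyLine`) instead of a hard-coded 15. It assumes, as in `main()`, that the cells hold 1 to 9 exactly once. The three rows cover every cell exactly once, so if all rows share a sum $S$:

$$3S = \sum_{k=1}^{9} k = \frac{9 \cdot 10}{2} = 45 \quad\Longrightarrow\quad S = 15.$$

Once the row check passes, `anyLine` is therefore 15, and the column and diagonal checks are effectively comparisons against 15.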

Basically, you are generating all permutations of the array with a recursive permutation algorithm.

There are four things you need to change:

First, start the loop from pos, not 0, so that positions already fixed by earlier levels of the recursion are never disturbed.

Second, swap the elements back after recursing (backtracking), so that every iteration of the loop starts from the same arrangement.

Third, only test once you have generated a complete permutation (when pos == 8); otherwise you will test the same permutations more than once.

Fourth, swapping an element with itself (i.e. not swapping it) is a valid step, because elements are allowed to stay in their original positions, so don't skip i == pos.

void solve(int arr[], int pos)
{
    if (pos == 8)
    {
        if (checkArr(arr))
        {
            printArray(arr);
            numSolutions++;
        }
    }
    else
    {
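        for (int i = pos; i < 9; i++)      // cells before pos are already fixed
        {
            swap(arr, pos, i);             // put arr[i] at pos (i == pos keeps it in place)
            solve(arr, pos + 1);
            swap(arr, pos, i);             // undo the swap (backtrack) before the next i
        }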
    }
}


Your code calls printArray from two places: the base case of the recursion (i.e. when pos == 8) and inside the loop, before the call to swap. The second call is unnecessary, because you would get the same square anyway once the recursion reaches pos == 8.

Removing that call brings the number of duplicates down, but it does not eliminate them, because of the way your loop generates squares. You need to keep track of what has already been printed. One way to do this is to keep a set of the solutions you have found and check it before printing a newly found one:

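#include <set>

set<int> seen;   // keys of the squares printed so far

// Encode the square as one decimal number, one digit per cell
// (only valid while every cell holds a single digit 0-9).
int key(int arr[]) {
    return arr[0]
    + 10 * arr[1]
    + 100 * arr[2]
    + 1000 * arr[3]
    + 10000 * arr[4]
    + 100000 * arr[5]
    + 1000000 * arr[6]
    + 10000000 * arr[7]
    + 100000000 * arr[8];
}

void printArray(int arr[]) {
    // insert() returns a pair whose .second is false when the key was already present
    if (!seen.insert(key(arr)).second) {
        return;   // duplicate: this square has already been printed
    }
    numSolutions++;
    for (int i = 0; i < SQUARE_SIZE; i++) {
        cout << arr[i] << " ";
        if ((i + 1) % 3 == 0)
            cout << endl;
    }
    cout << endl;
}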


A few things to note about the solution above:

  • key(int[]) converts the square to a single decimal number with one digit per cell, so this approach only works for squares whose entries are single decimal digits. For arbitrary numbers you would need a different strategy, for example a set of comma-separated strings.
  • Counting of solutions is moved to printArray(int[]). You could drop numSolutions altogether and use seen.size() instead; it gives the same answer.

If you don't need to solve this recursively as an exercise, I'd recommend using std::next_permutation instead:

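#include <algorithm>  // std::sort, std::next_permutation
#include <iterator>   // std::begin, std::end

// pos is unused; it is kept only so the existing call in main() still compiles
void solve(int(&arr)[SQUARE_SIZE], int pos)
{
    // next_permutation visits every arrangement only when it starts from sorted order
    std::sort(std::begin(arr), std::end(arr));
    do {
        if (checkArr(arr)) {
            numSolutions++;
            printArray(arr);
        }
    } while (std::next_permutation(std::begin(arr), std::end(arr)));
}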

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. Any question please contact: yoyou2525@163.com.

 