简体   繁体   中英

Effective backtracking algorithm

I have this code:

#include <stdio.h>
#include <stdlib.h>
#include <limits.h>

/* Cheapest total accepted so far; starts at LONG_MAX as an "infinity" sentinel. */
long result = LONG_MAX;

/* Exchange the two ints pointed to by a and b. */
void swap(int *a, int *b) {
    const int saved = *b;
    *b = *a;
    *a = saved;
}

/*
 * Report whether the device stored at array[arrayIndex] shares a wire with
 * any device stored in array[0 .. maxBound-1].
 *
 * rules holds the (size - 1) wires as consecutive pairs, so the partner of
 * rules[r] is rules[r + 1] for even r and rules[r - 1] for odd r.
 * Returns 1 on the first match found, 0 otherwise.
 */
int hasConnection(int *array, int arrayIndex, int maxBound, int *rules, int size) {
    const int ruleCount = (size - 1) * 2;

    for (int r = 0; r < ruleCount; r++) {
        if (rules[r] != array[arrayIndex]) {
            continue;
        }

        int partner = (r % 2) ? rules[r - 1] : rules[r + 1];

        /* scan the lower slots, nearest first */
        for (int slot = maxBound - 1; slot >= 0; slot--) {
            if (array[slot] == partner) {
                return 1;
            }
        }
    }
    return 0;
}

/*
 * For the wire between slots inConnectionIndex and outConnectionIndex,
 * return 1 if some device placed strictly between those two slots is wired
 * (per hasConnection) to a device in a slot below inConnectionIndex;
 * return 0 otherwise.
 */
int isCrossed(int *array, int outConnectionIndex, int inConnectionIndex, int *rules, int size) {
    int slot = inConnectionIndex + 1;

    while (slot < outConnectionIndex) {
        if (hasConnection(array, slot, inConnectionIndex, rules, size)) {
            return 1;
        }
        slot++;
    }
    return 0;
}

/*
 * Return 1 if a wire from the device at array[arrayIndex] back to an earlier
 * slot j is crossed according to isCrossed(), 0 otherwise.
 *
 * Only slots 1 .. arrayIndex-2 are tried as the far end: a partner in slot 0
 * has no slot below it for isCrossed to find, and a partner in slot
 * arrayIndex-1 leaves no slot in between.
 */
int isWiredInsideAndCrossed(int *array, int arrayIndex, int *rules, int size) {
    const int ruleCount = 2 * (size - 1);

    for (int r = 0; r < ruleCount; r++) {
        if (rules[r] != array[arrayIndex]) {
            continue;
        }

        int partner = (r % 2) ? rules[r - 1] : rules[r + 1];

        for (int j = 1; j < arrayIndex - 1; j++) {
            if (array[j] == partner && isCrossed(array, arrayIndex, j, rules, size)) {
                return 1;
            }
        }
    }
    return 0;
}

/*
 * Score one complete arrangement: array[slot] is a device id and its price is
 * priceMap[slot * size + array[slot]]. The running total is checked slot by
 * slot. Scoring stops without side effects as soon as the total can no longer
 * beat the global result, or isWiredInsideAndCrossed reports a crossing for
 * that slot. If every slot passes, the total becomes the new result.
 */
void trySequence(int * array, int size, int *priceMap, int *rules) {
    int total = 0;
    int slot = 0;

    while (slot < size) {
        total += priceMap[slot * size + array[slot]];
        if (total >= result) {
            return;
        }
        if (isWiredInsideAndCrossed(array, slot, rules, size)) {
            return;
        }
        slot++;
    }
    result = total;
}

void permute(int *array, int i, int size, int *priceMap, int *rules) {
    if (size == i) {
        trySequence(array, size, priceMap, rules);
        return;
    }
    int j = i;
    for (j = i; j < size; j++) {
        swap(array + i, array + j);
        permute(array, i + 1, size, priceMap, rules);
        swap(array + i, array + j);
    }
    return;
}

/*
 * Read the device count, the size x size price matrix and the size-1 wires
 * (pairs of device ids). Then print the cheapest non-crossing arrangement.
 * Returns EXIT_FAILURE on malformed input or allocation failure.
 */
int main(int argc, char** argv) {
    (void) argc;
    (void) argv;

    int size;
    if (scanf("%d", &size) != 1 || size < 1) {
        fprintf(stderr, "error: expected a positive device count\n");
        return EXIT_FAILURE;
    }

    size_t n = (size_t) size;
    if (n > ((size_t) -1) / n) {
        fprintf(stderr, "error: device count too large\n");
        return EXIT_FAILURE;
    }
    size_t priceCount = n * n;
    size_t rulesCount = 2 * (n - 1);

    int status = EXIT_FAILURE;
    int *priceMap = calloc(priceCount, sizeof *priceMap);
    /* +1 so that size == 1 (no wires) never requests 0 elements, which may yield NULL */
    int *rules = calloc(rulesCount + 1, sizeof *rules);
    /* heap instead of a VLA: size comes from untrusted input */
    int *arrayToPermute = calloc(n, sizeof *arrayToPermute);
    if (priceMap == NULL || rules == NULL || arrayToPermute == NULL) {
        fprintf(stderr, "error: out of memory\n");
        goto cleanup;
    }

    for (size_t k = 0; k < priceCount; k++) {
        if (scanf("%d", &priceMap[k]) != 1) {
            fprintf(stderr, "error: truncated price matrix\n");
            goto cleanup;
        }
    }
    for (size_t k = 0; k < rulesCount; k++) {
        if (scanf("%d", &rules[k]) != 1) {
            fprintf(stderr, "error: truncated wire list\n");
            goto cleanup;
        }
    }

    for (int j = 0; j < size; j++) {
        arrayToPermute[j] = j;
    }

    permute(arrayToPermute, 0, size, priceMap, rules);
    printf("%ld\n", result);
    status = EXIT_SUCCESS;

cleanup:
    free(arrayToPermute);
    free(rules);
    free(priceMap);
    return status;
}

It's supposed to do this

EDIT: Basically, the task is to find the cheapest arrangement of N devices placed in N slots on the edge of a circle. The input contains a cost matrix that tells what each device would cost to install in each slot. The devices are also connected to each other by wires that mustn't cross; there are always N-1 connections. Details and examples are on the link above.

My problem is that my solution is way too slow: I need it to solve a drill head of size 13 in generally less than two seconds. To be precise, I need this input solved in less than 1 s:

12
 27 25 21 27 25 30 27 26 22 28 27 26
 21 22 26 30 25 28 21 21 22 23 22 30
 20 21 22 20 30 30 30 22 30 26 23 26
 27 30 24 21 20 24 26 24 22 22 24 22
 29 26 20 29 22 23 27 28 23 28 30 27
 21 21 20 30 20 22 25 29 22 29 27 24
 26 21 30 24 23 23 29 29 29 28 23 22
 25 27 21 24 20 24 27 23 27 28 25 26
 26 27 23 27 23 27 29 30 25 24 20 23
 20 22 25 20 23 26 20 29 21 24 25 20
 27 28 25 20 25 22 26 23 24 21 26 23
 23 21 28 23 26 30 22 30 25 26 26 20
0 1
1 2
2 3
3 4
4 5
5 6
6 7
7 8
8 9
9 10
10 11

(solution is 262) and this in less than 2s:

13
 52  9 42 65 54 47 16 62 35 47 63  2 48
 25  4 12 25 58 12 45 62 70 60 40 17 33
 28 64 64 62  1 28  3 26 56 15 59 64 17
  7 23 70 20 57 70 46  5  6  1 21 12 40
 62 53  5 15 22 43 57 15 26 42 51 16 38
 20 13 64  3 51 22 28  1 18 27  4 36  9
 11 20 41 65 29 63 54 28 31 63 27 59 41
 44 21 42 16 59 10 60 11  3 53 52 53 37
 41 51 18  4 38  6 22 49 15 51 54 61  7
 54  6  5 24 47 35 46 11 26 17 53 37 25
 34 42  6 54 40 47 59 25 53 53 37  9 64
 69 63 68  5 37 16 17 61 33 51 19 39 44
  6 47  4  6 21 17 23 24 13 29 34 54 33
0 1
0 2
0 3
1 4
1 5
1 6
2 7
2 8
2 9
3 10
3 11
3 12

(solution is 165). So I'm wondering whether anybody has an idea how to do this — I'm completely lost.

Well, you only asked for an idea: instead of trying all n! permutations, you need to use backtracking to skip permutations that will surely make wires cross.

I.e., if you have the permutation 1 2 3 4 ... 11 and the prefix 1 2 3 already makes wires cross, you can skip all permutations of the remaining 4 ... 11 part.

Here's some pseudo-code for implementation details:

int n;              // number of devices (= number of slots)
int cost[n][n];     // cost[i][j]: cost for putting device i into slot j
                    // (check the input orientation: the C programs on this page
                    //  index the input matrix as [slot][device])
bool used[n]={0};   // used[i]: device i is already placed in an earlier slot
int slots[n];       // slots[x]: device placed in slot x
int edges[n-1][2];  // the n-1 wires of the tree, as pairs of devices
int ats = INF;      // best total cost found so far

bool cross();       // checks whether the device just placed makes any wires cross;
                    // it seems you already wrote something similar, so I skipped this

void solve(int x, int total)    // tries putting the remaining devices into slots x..n-1
{
    if (total >= ats)           // pruning: with non-negative costs this branch
        return;                 // can no longer beat the best total found so far

    if (x == n)                 // all slots are filled: this is a new best arrangement
    {
        ats = total;
        return;
    }

    for (int i = 0; i < n; i++)
    {
        if (used[i])            // device i already sits in an earlier slot
            continue;

        slots[x] = i;
        used[i] = true;
        if (!cross())           // if putting device i into slot x makes some lines cross, skip it
            solve(x + 1, total + cost[i][x]);
        used[i] = false;        // undo the choice before trying the next device
    }
}

main()
{
    for (int i = 0; i < n; i++) // try every device in slot 0
    {
        used[i] = true;
        slots[0] = i;
        solve(1, cost[i][0]);
        used[i] = false;
    }

    print(ats);
}

Your program generates all possible permutations and then tests whether each permutation is valid. The test itself involves nested loops. You could try to optimize your data structures so that your checks are more efficient, or you could try to find dead ends in the search space early, as Photon suggests.

A more efficient approach might be to design the program so that only valid permutations are created. This reduces the search space and also gets rid of the tests.

If you look at the example in the problem description, the wire network is an acyclic graph:

                            5          
                            |         
                        1   6
                        |   |
                    0---2---7
                        |   |
                        3   8               
                        |
                        4

If you start with tool 0 and put it in slot 0, the next step is to place a permutation of tool 2 and its "descendants", that is, of tools 1, 7 and 3 and their respective connected tools. From the perspective of tool 0, you can turn this into a tree:

                  [1237]
                / /  |  \
              1  2  [34] [678]
                   / |    |  \ \
                  3  4   [56] 7  8
                          | \
                          5  6

Here, each leaf corresponds to a single tool, while each branch contains several tools. All permutations of the branches form valid arrangements.

0 (1 2 (3 4) ((5 6) 7 8))

Of course, every permutation of [56] must be combined with every permutation of the other branches. You can achieve this by implementing a kind of odometer that doesn't go through the numbers from 0 to 9 in each space, but through the possible permutations of the branches.

Every generated permutation is valid, but this technique doesn't create all possible permutations yet. Remember that we've fixed tool 0 in slot 0, but that needn't be the case. The topology of each valid layout can be reused to generate 8 other layouts by rotating it, i.e. putting tool 0 into slots 1, 2, and so on.

This technique reduces your search space from 9! to 9 · 4! · 2! · 3! · 2!, i.e. by a factor of 70 (362880 vs. 5184). There are also no validity tests, at the cost of a more complicated data structure.

(The reduction is extreme in your 12-tool example, where the network of wires is really just a straight line without bifurcations.)

This code implements the described technique:

#include <stdio.h>
#include <stdlib.h>
#include <limits.h>



/* Hard-coded upper bound on the number of tools; sizes every fixed array below. */
enum { N = 16 };

/*
 * One tool together with its node in the wire tree.
 */
struct tool {
    int conn[N - 1];                // wire connections: ids of directly wired tools
    int nconn;                      // number of valid entries in conn[]

    int desc[N];                    // "descendants" of the tree node: the node itself
                                    // plus its child nodes; optimize() permutes this list
    int ndesc;                      // number of valid entries in desc[]

    int cost[N];                    // row of the cost matrix; optimize() reads it as
                                    // tool[position].cost[tool id]
    int used;                       // flag for recursive descent: set while the node is
                                    // on the current descend() path (and for the root)
};

/*
 * Whole problem instance plus search state.
 */
struct drill {
    int n;                          // number of tools; the cost matrix is n x n
    struct tool tool[N];

    int root;                       // root node    
    int branch[N];                  // indices of branch nodes (tools whose desc[] has
                                    // more than one entry), in descend() post-order
    int nbranch;                    // permutating branches: number of entries in branch[]

    int opt;                        // current optimum: cheapest total cost found so far
};

/* Exchange a[i] and a[j]. */
void swap(int a[], int i, int j)
{
    int s = a[i]; a[i] = a[j]; a[j] = s;
}

/* Reverse the elements a[i .. n-1] in place. */
void reverse(int a[], int i, int n)
{
    while (i < --n) swap(a, i++, n);
}

/*
 *      Turn an array into the next lexicographically greater permutation.
 *      When the permutation is already the highest, reset the array to the
 *      smallest permutation (ascending order) and return 0. Otherwise,
 *      return 1.
 *
 *      The comparisons are non-strict, so arrays containing equal elements
 *      step through their distinct permutations instead of cycling forever.
 *      For distinct elements the sequence is the usual lexicographic one.
 */
int next_perm(int a[], int n)
{
    if (n < 2) return 0;

    /* k = start of the longest non-increasing suffix */
    int k = n - 1;
    while (k > 0 && a[k - 1] >= a[k]) k--;

    if (k == 0) {
        reverse(a, 0, n);       // highest permutation: wrap around
        return 0;
    }

    int pivot = k - 1;

    /* rightmost element strictly greater than the pivot (exists since a[k] > a[pivot]) */
    int i = n - 1;
    while (a[i] <= a[pivot]) i--;

    swap(a, i, pivot);
    reverse(a, pivot + 1, n);

    return 1;
}

/*
 *      Sort a[0 .. len-1] into ascending order (insertion sort). Used once
 *      per branch so that next_perm() starts from the smallest permutation.
 */
void sort(int a[], int len)
{
    for (int next = 1; next < len; next++) {
        int value = a[next];
        int pos = next;

        /* shift larger elements one place right to open a gap for value */
        while (pos > 0 && a[pos - 1] > value) {
            a[pos] = a[pos - 1];
            pos--;
        }
        a[pos] = value;
    }
}

/*
 *      Build the desc[] list of node n and of every node below it: the node
 *      itself plus each neighbour not currently marked used (the parent and
 *      other ancestors are marked while the walk is below them). Nodes with
 *      more than one entry get a sorted desc[] and are appended to
 *      dr->branch[] in post-order.
 */
void descend(struct drill *dr, int n)
{
    struct tool *node = &dr->tool[n];

    node->desc[0] = n;
    node->ndesc = 1;
    node->used = 1;                 /* on the current path: children must not walk back here */

    for (int c = 0; c < node->nconn; c++) {
        int child = node->conn[c];

        if (dr->tool[child].used)
            continue;

        node->desc[node->ndesc] = child;
        node->ndesc++;
        descend(dr, child);
    }

    if (node->ndesc > 1) {
        sort(node->desc, node->ndesc);
        dr->branch[dr->nbranch] = n;
        dr->nbranch++;
    }

    node->used = 0;
}

/*
 *      Fill the array a with the current arrangement in the tree.
 */
int evaluate(struct drill *dr, int a[], int n)
{
    struct tool *t = dr->tool + n;
    int m = 0;

    if (n == dr->root) {
        a[0] = dr->root;
        return 1 + evaluate(dr, a + 1, dr->tool[n].conn[0]);
    }

    for (int i = 0; i < t->ndesc; i++) {
        int d = t->desc[i];

        if (d == n) {
            a[m++] = d;
        } else {
            m += evaluate(dr, a + m, d);
        }
    }

    return m;
}

/*
 *      Evaluate all possible permutations and find the optimum.
 *
 *      The branch nodes' desc[] lists act as the digits of an odometer.
 *      Each odometer state yields one arrangement a[] (via evaluate), which
 *      is then tried in every rotation around the circle:
 *      position i receives tool a[(i + rot) % n] at price
 *      dr->tool[i].cost[tool]. The cheapest total ends up in dr->opt.
 *
 *      When there are no branch nodes (e.g. only two tools), exactly one
 *      arrangement exists; it is evaluated once and the search stops.
 */
void optimize(struct drill *dr)
{
    dr->opt = INT_MAX;

    for (;;) {
        int a[2 * N];

        /* the arrangement depends only on the odometer state, not on the rotation */
        evaluate(dr, a, dr->root);

        for (int rot = 0; rot < dr->n; rot++) {
            int cost = 0;

            for (int pos = 0; pos < dr->n; pos++) {
                cost += dr->tool[pos].cost[a[(pos + rot) % dr->n]];
            }

            if (cost < dr->opt) dr->opt = cost;
        }

        /*
         *  Advance the odometer: step the first branch; whenever a branch
         *  wraps around to its smallest permutation, carry into the next.
         */
        int b = 0;
        while (b < dr->nbranch) {
            struct tool *t = dr->tool + dr->branch[b];

            if (next_perm(t->desc, t->ndesc)) break;
            b++;
        }

        if (b == dr->nbranch) return;   /* every branch wrapped: all states seen */
    }
}

/*
 *      Read and prepare drill data, then optimize.
 *
 *      Input: n, then an n x n cost matrix (row j is stored in tool[j].cost),
 *      then n-1 wires given as pairs of tool ids. The input is rejected
 *      (EXIT_FAILURE) if n is outside 1..N, if it is truncated, if a wire
 *      names a tool out of range, or if the wires do not form a tree. The
 *      fixed-size arrays in struct drill depend on these checks.
 */
int main(void)
{
    struct drill dr = {0};
    const int maxTools = (int) (sizeof dr.tool / sizeof dr.tool[0]);
    int comp[sizeof dr.tool / sizeof dr.tool[0]];   /* union-find links for cycle detection */

    if (scanf("%d", &dr.n) != 1 || dr.n < 1 || dr.n > maxTools) {
        fprintf(stderr, "error: tool count must be between 1 and %d\n", maxTools);
        return EXIT_FAILURE;
    }

    for (int j = 0; j < dr.n; j++) {
        for (int i = 0; i < dr.n; i++) {
            if (scanf("%d", &dr.tool[j].cost[i]) != 1) {
                fprintf(stderr, "error: truncated cost matrix\n");
                return EXIT_FAILURE;
            }
        }
    }

    for (int i = 0; i < dr.n; i++) comp[i] = i;

    for (int i = 1; i < dr.n; i++) {
        int a, b;

        if (scanf("%d %d", &a, &b) != 2 ||
            a < 0 || a >= dr.n || b < 0 || b >= dr.n) {
            fprintf(stderr, "error: invalid wire\n");
            return EXIT_FAILURE;
        }

        /* n-1 wires without a cycle form a tree, which bounds every node's degree */
        int ra = a, rb = b;
        while (comp[ra] != ra) ra = comp[ra];
        while (comp[rb] != rb) rb = comp[rb];
        if (ra == rb) {
            fprintf(stderr, "error: wires must form a tree\n");
            return EXIT_FAILURE;
        }
        comp[ra] = rb;

        dr.tool[a].conn[dr.tool[a].nconn++] = b;
        dr.tool[b].conn[dr.tool[b].nconn++] = a;
    }

    /* a single tool has no wires; the tree walk below would recurse forever */
    if (dr.n == 1) {
        printf("%d\n", dr.tool[0].cost[0]);
        return 0;
    }

    /* a tree with at least two nodes always has a leaf; use the first one as root */
    while (dr.tool[dr.root].nconn > 1) dr.root++;
    dr.tool[dr.root].used = 1;

    descend(&dr, dr.tool[dr.root].conn[0]);
    optimize(&dr);

    printf("%d\n", dr.opt);

    return 0;
}

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM