简体   繁体   中英

AVL Tree Balancing - C

I've been trying to write up a simple AVL Tree implementation in C. It supports duplicate values as well. Everything seems to work fine but every now and then I get a poorly balanced tree. To me, the rotation functions seem to be working fine like they should. I'm thinking there is a problem with the height checks but I can't seem to find the problem.

The tree I get just from the inserts is unbalanced, so the insert is problematic. Then, before this, after deletion the tree is usually poorly balanced. It is sometimes balanced properly though, which I can't seem to identify how.

The code for this implementation is as follows:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <time.h>
#include <math.h>

#define SPACE_PER_NODE 2
#define MAX(x, y) (x) > (y) ? (x) : (y)


enum delete_flags {
    DELETE_NO_FORCE,
    DELETE_FORCE
};

typedef unsigned int uint;

struct tree_elem {
    int data;
    uint dup_count;
    int height;
    struct tree_elem* left;
    struct tree_elem* right;
};

typedef struct tree_elem node;

node* create_bst();
void insert(node**, int);
void delete_elem(node**, int, uint);
node* search(node*, int);
node* get_parent(node*, node*);
node* find_min(node*);
node* get_successor(node*, node*);
uint max_depth(node*);
void display_helper(node*, int);
void display_tree(node*);
int get_height(node*);
void rotate_once_left(node**);
void rotate_once_right(node**);
void rotate_twice_left(node**);
void rotate_twice_right(node**);

void* s_malloc (const uint t) {
    void* p = malloc(t);
    if(!p) {
        printf("Out of memory.\n");
        exit(EXIT_FAILURE);
    }
    return p;
}

void s_free (void* p) {
    if(!p) {
        printf("Error: Tried to free NULL ptr.\n");
        exit(EXIT_FAILURE);     
    }
    else
        free(p);
}

node* create_bst(int data) {
    node* tree = (node*) s_malloc(sizeof(node));
    tree->left = tree->right = NULL;
    tree->data = data;
    return tree;
}

void insert(node** t, int val) {
    if(!(*t)) {
        *t = (node*) s_malloc(sizeof(node));
        (*t)->data = val;
        (*t)->left = (*t)->right = NULL;
        (*t)->dup_count = 0;
        (*t)->height = 0;
        return;
    }
    if((*t)->data < val) {
        insert(&(*t)->right, val);

        if(get_height((*t)->right) - get_height((*t)->left) >= 2) {
            if((*t)->right->data < val)
                rotate_once_right(&(*t));
            else if((*t)->right->data > val)
                rotate_twice_right(&(*t));
        }
    }
    else if((*t)->data > val) {
        insert(&(*t)->left, val);

        if(get_height((*t)->left) - get_height((*t)->right) >= 2) {
            if((*t)->left->data > val)
                rotate_once_left(&(*t));
            else if((*t)->left->data < val)
                rotate_twice_left(&(*t));
        }
    }
    else {
        ++(*t)->dup_count;
        return;                                                             // this is important! if there are duplicates, they might cause an unwanted height change!
    }
    (*t)->height = MAX(get_height((*t)->left), get_height((*t)->right)) + 1;
}

node* get_successor(node* t, node* s) {
    if(s->right)
        return find_min(s->right);
    node* suc = NULL;
    node* temp = t;
    // Start from root and search for successor in the tree
    while (temp) {
        if (s->data < temp->data) {
            suc = temp;
            temp = temp->left;
        }
        else if (s->data > temp->data)
            temp = temp->right;
        else
           break;
    }
    return suc;
}

void free_tree (node* t) {
    if (!t)
        return;
    free_tree(t->left);
    free_tree(t->right);
    free(t);
}

node* search(node* t, int val) {
    if(!t)
        return NULL;
    if(t->data == val)
        return t;
    else if(t->data < val)
        return search(t->right, val);
    return search(t->left, val);
}

node* find_min(node* t) {
    node* temp = t;
    while(temp->left)
        temp = temp->left;
    return temp;
}

uint max_depth(node* t) {
   if (!t)
       return 0;
   int ldepth = max_depth(t->left);
   int rdepth = max_depth(t->right);
   if (ldepth > rdepth)
       return ldepth + 1;
   return rdepth + 1;
}

void display_helper(node* t, int spaces) {
    int width = ceil(log10(max_depth(t)+0.01)) + 2;
    wchar_t* sp64 = L"                                                                ";
    if (!t) {
        wprintf(L"\n");
        return;
    }
    display_helper(t->right, spaces + width);
    wprintf(L"%*.*s%d\n", 0, spaces, sp64, t->data);
    display_helper(t->left, spaces + width);
}

void display_tree(node* t) {
    if(t)
        display_helper(t, SPACE_PER_NODE);
}

int get_height(node* t) {
    if(!t)
        return 0;
    return t->height;
}

void rotate_once_left(node** k1) {
    node* temp = (*k1)->left;
    (*k1)->left = temp->right;
    temp->right = *k1;

    (*k1)->height = MAX(get_height((*k1)->left), get_height((*k1)->right)) + 1;
    temp->height = MAX(get_height(temp->left), (*k1)->height) + 1;

    *k1 = temp;
}


void rotate_once_right(node** k1) {
    node* temp = (*k1)->right;
    (*k1)->right = temp->left;
    temp->left = *k1;

    (*k1)->height = MAX(get_height((*k1)->left), get_height((*k1)->right)) + 1;
    temp->height = MAX(get_height(temp->right), (*k1)->height) + 1;

    *k1 = temp;
}

void rotate_twice_left(node** k1) {
    rotate_once_right(&(*k1)->left);
    rotate_once_left(k1);
}

void rotate_twice_right(node** k1) {
    rotate_once_left(&(*k1)->right);
    rotate_once_right(k1);
}

int main() {
    srand(time(NULL));
    node* tree = create_bst(rand() % 15 + 1);
    for(uint i = 0; i < 14; ++i) {
        int elem;
        // create unique elements from 1 to 20.
        do {
            elem = rand() % 15 + 1;
        } while (search(tree, elem));
        insert(&tree, elem);
    }
    display_tree(tree);
    int input;
    do {
        printf("Enter value to delete: ");
        scanf("%d", &input);
        delete_elem(&tree, input, DELETE_NO_FORCE);
        display_tree(tree);
    } while(input != -1);
    return 0;
}

One place to look is your MAX macro.

MAX(get_height((*t)->left), get_height((*t)->right)) + 1;

probably does not compute what you think it does.

In this day and age, when compilers inline with such great aplomb, you shouldn't use a macro for this computation. It's not only incorrect, it's almost certainly less efficient.

And I'll ditto here what I said in the comment: You should strongly consider test driven development. Write a predicate that checks the AVL conditions are met for a given tree, including that it's a valid BST. Now add items to an empty tree and run the predicate after each. When it reports the tree is not AVL, you'll be able to see what went wrong. When it doesn't, you'll have more confidence your code is working as intended.

Edit

Okay, expand the macro by hand (adding some whitespace):

(get_height((*t)->left)) > (get_height((*t)->right)) 
    ? (get_height((*t)->left))
    : (get_height((*t)->right)) + 1;

The + 1 is affecting only the else branch. You'd need an additional set of parentheses to get the right answer.

Moreover, the heights are being computed twice. With a function, it would only happen once. Admittedly an aggressive optimizer would probably eliminate the duplicate computations, too, but that optimization is considerably more elaborate and therefore fragile than merely inlining a max() function. Why use a macro to make the compiler's job harder ?

The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address.Any question please contact:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM