
Binary search tree cumsum

Problem: Given a binary search tree whose keys are numbers, we define the operation 'cumsum' (short for cumulative sum), which replaces the key of every node in the tree with the sum of all keys that are smaller than or equal to it.
For example, consider the tree built and printed further below, with keys 5, 2, 9, 3, 8, 12 and 6.

In this example:
The key 5 in the root is replaced by 10: the original key of the root (5) plus all keys smaller than it (2 and 3).
The key 3 is replaced by 5: the original key of this node (3) plus all keys smaller than it (2).
The key 12 in the rightmost node is replaced by 45: the original key of this node (12) plus all keys smaller than it (2, 3, 5, 6, 8 and 9).
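In other words, once the keys are listed in ascending order, cumsum replaces each one with a running (prefix) sum. This short sketch checks the three values above on the example's keys with itertools.accumulate; it works on a plain list rather than on the tree, and the helper name expected_cumsum is hypothetical.

```python
from itertools import accumulate


def expected_cumsum(keys):
    """Map each key to the sum of all keys <= it (hypothetical helper; keys assumed distinct)."""
    ordered = sorted(keys)
    # accumulate yields running totals: 2, 2+3, 2+3+5, ...
    return dict(zip(ordered, accumulate(ordered)))


result = expected_cumsum([5, 2, 9, 3, 8, 12, 6])
print(result[5], result[3], result[12])  # 10 5 45
```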

Note that cumsum must be a wrapper (envelope) function that calls an inner recursive function. Also note that cumsum does not return a new tree; it updates the existing tree in place.

My attempt:

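def cumsum(T):
    def cumsum_rec(node, L):
        # Record this node's original key before it is updated (pre-order).
        L.append(node.key)
        if node.left is not None:
            cumsum_rec(node.left, L)
        if node.right is not None:
            cumsum_rec(node.right, L)
        # By now L holds every key smaller than node.key,
        # so add them all to the current (still original) key.
        node.key += sum(val for val in L if val < node.key)

    L = []
    if T.root is not None:  # an empty tree has nothing to update
        cumsum_rec(T.root, L)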

Explanation: I traverse the tree in pre-order and save each node's original key in an auxiliary list, L. After both subtrees of a node have been traversed, L already contains every key smaller than that node's key, so I add all of those keys to the current key. The node's key then becomes the sum of all keys smaller than or equal to its original key.

My implementation works, for example, on the tree from the example above:

t = Binary_search_tree()
t.insert(5,'A')
t.insert(2,'A')
t.insert(9,'A')
t.insert(3,'A')
t.insert(8,'A')
t.insert(12,'A')
t.insert(6,'A')

which looks like this:

>>> print(t)
          5                   
   ______/ |__________        
  2                   9       
 / |__             __/ |__    
#     3           8       12  
     / |       __/ |     /  | 
    #   #     6     #   #    #
             / |              
            #   #  

And after performing the cumsum operation on it:

>>> cumsum(t)
>>> print(t)
          10                      
   ______/  |____________         
  2                      33       
 / |__                __/  |__    
#     5             24        45  
     / |         __/  |      /  | 
    #   #      16      #    #    #
              /  |                
             #    #  
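The attempt above rescans the whole list L at every node, so in the worst case the work grows quadratically with the number of nodes. One way to keep the same idea (record the original keys, then update) while avoiding the rescans is to sort the recorded keys once and look up each node's prefix sum in a dictionary. This is a sketch, not the asker's code: the Node class below is a hypothetical minimal stand-in for Tree_node, cumsum_sorted is a hypothetical name, and keys are assumed to be distinct (as the addendum's insert guarantees).

```python
from itertools import accumulate


class Node:
    """Hypothetical minimal stand-in for Tree_node (only key, left, right)."""

    def __init__(self, key, left=None, right=None):
        self.key = key
        self.left = left
        self.right = right


def cumsum_sorted(root):
    """Replace every key with the sum of all keys <= it, in place (distinct keys assumed)."""
    keys = []

    def collect(node):
        # First pass: record every original key (any traversal order works).
        if node is not None:
            keys.append(node.key)
            collect(node.left)
            collect(node.right)

    def update(node, totals):
        # Second pass: overwrite each key with its precomputed running total.
        # A node is looked up before it is modified, so totals is keyed by original keys.
        if node is not None:
            node.key = totals[node.key]
            update(node.left, totals)
            update(node.right, totals)

    collect(root)
    keys.sort()
    # Map each original key to the sum of all keys up to and including it.
    totals = dict(zip(keys, accumulate(keys)))
    update(root, totals)


# The example tree: 5 at the root, 2 (with right child 3) on the left,
# 9 (with left child 8 -> 6, and right child 12) on the right.
root = Node(5, Node(2, None, Node(3)), Node(9, Node(8, Node(6)), Node(12)))
cumsum_sorted(root)
print(root.key, root.left.right.key, root.right.right.key)  # 10 5 45
```

Sorting costs O(n log n); after that each node needs only one dictionary lookup. It still uses a list, though, so it does not address the question below.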

My question:
Although my implementation works, I'm interested in seeing other possible implementations for the sake of learning. Is there an alternative implementation, ideally one that doesn't require passing a list to the recursive function?

Addendum (implementation of the Binary_search_tree and Tree_node classes, if you're interested):

def printree(t, bykey = True):
        """Return the rows of a textual representation of t
        bykey=True: show keys instead of values"""
        return trepr(t, bykey)

def trepr(t, bykey = False):
        """Return a list of textual representations of the levels in t
        bykey=True: show keys instead of values"""
        if t is None:
                return ["#"]

        thistr = str(t.key) if bykey else str(t.val)

        return conc(trepr(t.left,bykey), thistr, trepr(t.right,bykey))

def conc(left,root,right):
        """Return a concatenation of textual representations of
        a root node, its left node, and its right node
        root is a string, and left and right are lists of strings"""
        
        lwid = len(left[-1])
        rwid = len(right[-1])
        rootwid = len(root)
        
        result = [(lwid+1)*" " + root + (rwid+1)*" "]
        
        ls = leftspace(left[0])
        rs = rightspace(right[0])
        result.append(ls*" " + (lwid-ls)*"_" + "/" + rootwid*" " + "|" + rs*"_" + (rwid-rs)*" ")
        
        for i in range(max(len(left),len(right))):
                row = ""
                if i<len(left):
                        row += left[i]
                else:
                        row += lwid*" "

                row += (rootwid+2)*" "
                
                if i<len(right):
                        row += right[i]
                else:
                        row += rwid*" "
                        
                result.append(row)
                
        return result

def leftspace(row):
        """helper for conc"""
        #row is the first row of a left node
        #returns the index where the trailing whitespace starts
        i = len(row)-1
        while row[i]==" ":
                i-=1
        return i+1

def rightspace(row):
        """helper for conc"""
        #row is the first row of a right node
        #returns the index where the leading whitespace ends
        i = 0
        while row[i]==" ":
                i+=1
        return i



#######################################################################

class Tree_node():
    def __init__(self, key, val):
        self.key = key
        self.val = val
        self.left = None
        self.right = None

    def __repr__(self):
        return "(" + str(self.key) + ":" + str(self.val) + ")"
    
    
    
class Binary_search_tree():

    def __init__(self):
        self.root = None


    def __repr__(self): #no need to understand the implementation of this one
        out = ""
        for row in printree(self.root): #need printree.py file
            out = out + row + "\n"
        return out


    def lookup(self, key):
        ''' return node with key, uses recursion '''

        def lookup_rec(node, key):
            if node is None:
                return None
            elif key == node.key:
                return node
            elif key < node.key:
                return lookup_rec(node.left, key)
            else:
                return lookup_rec(node.right, key)

        return lookup_rec(self.root, key)



    def insert(self, key, val):
        ''' insert node with key,val into tree, uses recursion '''

        def insert_rec(node, key, val):
            if key == node.key:
                node.val = val     # update the val for this key
            elif key < node.key:
                if node.left is None:
                    node.left = Tree_node(key, val)
                else:
                    insert_rec(node.left, key, val)
            else: #key > node.key:
                if node.right is None:
                    node.right = Tree_node(key, val)
                else:
                    insert_rec(node.right, key, val)
            return
        
        if self.root is None: #empty tree
            self.root = Tree_node(key, val)
        else:
            insert_rec(self.root, key, val)

Thanks in advance for any help!

Here's one implementation that doesn't need the extra list: it visits the nodes in order (left subtree, node, right subtree) and carries the running total along as it goes.

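def cumsum(T):
    def cumsum_rec(node, initial):
        # initial: sum of all keys smaller than every key in this subtree
        if node is None:
            return initial
        # In-order: finish the left subtree (smaller keys) first
        left = cumsum_rec(node.left, initial)
        node.key = left + node.key
        # Larger keys in the right subtree continue from this node's new key
        right = cumsum_rec(node.right, node.key)
        # Running total after this whole subtree
        return right
    cumsum_rec(T.root, 0)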

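Note that there is no need to compare values (my code has no <), because an in-order traversal of a binary search tree visits the keys in ascending order: that information is already contained in the structure of the tree. Each node is therefore visited exactly once, so the running time is linear in the number of nodes rather than quadratic, as it is when the list is rescanned at every node.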
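To try the answer's approach without the addendum's printing helpers, here is a self-contained sketch. The Node class and insert function are hypothetical minimal stand-ins for Tree_node and Binary_search_tree.insert, and both cumsum functions here take the root node directly instead of a tree object. cumsum_nonlocal is a variant that is not from the answer: it keeps the running total in a nonlocal variable instead of passing it through arguments and return values. Both traverse the tree in order.

```python
class Node:
    """Hypothetical minimal stand-in for Tree_node."""

    def __init__(self, key):
        self.key = key
        self.left = None
        self.right = None


def insert(root, key):
    """Insert key into the BST rooted at root; return the (possibly new) root."""
    if root is None:
        return Node(key)
    if key < root.key:
        root.left = insert(root.left, key)
    elif key > root.key:
        root.right = insert(root.right, key)
    return root  # duplicate keys are ignored


def cumsum(root):
    """The answer's approach: thread the running total through the recursion."""
    def cumsum_rec(node, initial):
        if node is None:
            return initial
        left = cumsum_rec(node.left, initial)
        node.key = left + node.key
        return cumsum_rec(node.right, node.key)

    cumsum_rec(root, 0)


def cumsum_nonlocal(root):
    """Variant: keep the running total in an enclosing variable."""
    total = 0

    def visit(node):
        nonlocal total
        if node is None:
            return
        visit(node.left)   # all smaller keys are added first
        total += node.key
        node.key = total
        visit(node.right)  # larger keys continue from this total

    visit(root)


def inorder_keys(node):
    """Return the keys in in-order sequence."""
    if node is None:
        return []
    return inorder_keys(node.left) + [node.key] + inorder_keys(node.right)


def build_example():
    root = None
    for key in (5, 2, 9, 3, 8, 12, 6):
        root = insert(root, key)
    return root


a, b = build_example(), build_example()
cumsum(a)
cumsum_nonlocal(b)
print(inorder_keys(a))                     # [2, 5, 10, 16, 24, 33, 45]
print(inorder_keys(a) == inorder_keys(b))  # True
```

Both versions do the same O(n) in-order walk; the choice between returning the total and using nonlocal is mostly a matter of taste.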

The technical post webpages of this site follow the CC BY-SA 4.0 protocol.