
While recursively finding max path sum, append left or right direction of binary tree

I am writing Python code to find the path with the largest sum of node values in a binary tree. While recursing through the tree, I'd like to append the path direction (either "l" or "r" for left and right respectively) to a list that can be used later in the code.

So far I've managed to correctly get the largest path sum (max sum of nodes) and the first path direction, but not the full path.

I feel like I'm close to getting this done; I just need a hint in the right direction.

def sum_max_path(root):

    if not root:
        return

    if root.left is None:
        l_sum = 0
    else:
        l_sum = sum_max_path(root.left)

    if root.right is None:
        r_sum = 0
    else:
        r_sum = sum_max_path(root.right)

    if l_sum > r_sum:
        root.list.append("l")
        return root.value + l_sum
    elif l_sum < r_sum:
        root.list.append("r")
        return root.value + r_sum
    else:
        root.list.append("l")
        return root.value + l_sum

    return root.value + max(l_sum, r_sum)

return sum_max_path(root), root.list

The output of this is:

The total value in this path is: 8
The largest value path is: ['l']

What I'd like is for the output to be:

The largest value path is ['l', 'r', 'l'] 

(Obviously depending on how long the path is based on the generated tree).

Do not store the path statically; instead, pass it to and return it from each recursive invocation:

def max_sum(node, path):
    ls = rs = 0
    lp = rp = path
    if node.left:
        ls, lp = max_sum(node.left, path + ['l'])
    if node.right:
        rs, rp = max_sum(node.right, path + ['r'])
    ls += node.value
    rs += node.value
    return (ls, lp) if ls > rs else (rs, rp)
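
As a variation on the function above, the path can also be assembled while the recursion unwinds, so no path argument is needed. This is only a sketch under assumptions: the name max_sum_path and the minimal Node class are hypothetical and not part of the original post, and ties are broken toward the left child (as in the question's code), whereas max_sum above prefers the right child on ties.

```python
class Node:
    """Minimal node type assumed for this sketch (value, left, right)."""
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def max_sum_path(node):
    """Return (best_sum, directions) for the best root-to-leaf path from node."""
    if node is None:
        return 0, []
    candidates = []
    if node.left:
        s, p = max_sum_path(node.left)
        candidates.append((s, ['l'] + p))  # prepend the step taken to reach the child
    if node.right:
        s, p = max_sum_path(node.right)
        candidates.append((s, ['r'] + p))
    if not candidates:
        # Leaf: the path ends here, so there are no further directions.
        return node.value, []
    # max() returns the first maximal element, so the left child wins ties.
    best_sum, best_path = max(candidates, key=lambda c: c[0])
    return node.value + best_sum, best_path


small = Node(1, Node(2), Node(3, Node(4)))
print(max_sum_path(small))  # (8, ['r', 'l'])
```

Each call returns only the directions below its own node, and the caller prepends the single step it took, which is why the full path survives all the way up to the root.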

Complete example using max_sum:

class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


tree = Node(
    1,
    Node(
        9,
        Node(2),
        Node(3)
    ),
    Node(
        8,
        Node(2),
        Node(
            5,
            Node(3),
            Node(2)
        )
    )
)


def max_sum(node, path):
    ls = rs = 0
    lp = rp = path
    if node.left:
        ls, lp = max_sum(node.left, path + ['l'])
    if node.right:
        rs, rp = max_sum(node.right, path + ['r'])
    ls += node.value
    rs += node.value
    return (ls, lp) if ls > rs else (rs, rp)


print(max_sum(tree, []))
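
For the tree above, this prints (17, ['r', 'r', 'l']). One way to sanity-check a result like this is to walk the returned directions from the root and sum the values visited. The helper below is a hypothetical addition (the name follow is not from the original post) and assumes the same Node class and tree as the complete example.

```python
class Node:
    def __init__(self, value, left=None, right=None):
        self.value = value
        self.left = left
        self.right = right


def follow(root, directions):
    """Return the values of the nodes visited when following directions from root."""
    values = [root.value]
    node = root
    for d in directions:
        # 'l' steps to the left child; anything else is treated as 'r'.
        node = node.left if d == 'l' else node.right
        values.append(node.value)
    return values


# Same tree as in the complete example.
tree = Node(1, Node(9, Node(2), Node(3)), Node(8, Node(2), Node(5, Node(3), Node(2))))

values = follow(tree, ['r', 'r', 'l'])
print(values, sum(values))  # [1, 8, 5, 3] 17
```

The sum of the visited values matches the total returned by max_sum, which confirms that the direction list describes the same path.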


 