
Given a binary tree and a sum, find all root-to-leaf paths where each path's sum equals the given sum

Why is "path.remove(path.size()-1)" used in the end?

This code finds all root-to-leaf paths whose sum equals the given sum.

public List<List<Integer>> pathSum(TreeNode root, int sum) {
    List<List<Integer>> res = new ArrayList<List<Integer>>();
    ArrayList<Integer> path = new ArrayList<Integer>();
    pathSumRe(root, sum, res, path);
    return res;
}
public void pathSumRe(TreeNode root, int sum, List<List<Integer>> res,
        ArrayList<Integer> path) {

    if (root == null)
        return;

    path.add(root.val);

    if (root.left == null && root.right == null && root.val == sum) {

        ArrayList<Integer> tmp = new ArrayList<Integer>(path);
        res.add(tmp);
    }

    pathSumRe(root.left, sum - root.val, res, path);
    pathSumRe(root.right, sum - root.val, res, path);
    path.remove(path.size() - 1);
}

Removing "path.remove(path.size() - 1);" from the code will give the following output.

Input: [0,1,1], 1

Output: [[0,1],[0,1,1]] ==> This is the wrong output

Expected Output: [[0,1],[0,1]]

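path.remove(path.size() - 1) removes the most recently added node from the path list. It is needed because the same list is reused across all recursive calls, and each call appends its current node with path.add(root.val);. Without the removal, a node's value stays in the list after its subtree has been fully explored, so the next sibling's path is built on top of it. In the example above, the right child 1 is appended after the left child 1, which produces [0,1,1].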
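To see the effect, here is a small self-contained sketch that runs the same recursion on the tree [0,1,1] with target 1, once with the backtracking step and once without. The class name BacktrackDemo, the Node class, and the backtrack flag are illustrative assumptions, not part of the original code.

```java
import java.util.ArrayList;
import java.util.List;

public class BacktrackDemo {
    // Hypothetical minimal node class standing in for TreeNode.
    static final class Node {
        final int val;
        final Node left, right;
        Node(int val, Node left, Node right) { this.val = val; this.left = left; this.right = right; }
    }

    static void walk(Node root, int sum, List<List<Integer>> res,
                     List<Integer> path, boolean backtrack) {
        if (root == null) return;
        path.add(root.val);                           // enter the node
        if (root.left == null && root.right == null && root.val == sum) {
            res.add(new ArrayList<>(path));           // snapshot, not the shared list
        }
        walk(root.left, sum - root.val, res, path, backtrack);
        walk(root.right, sum - root.val, res, path, backtrack);
        if (backtrack) {
            path.remove(path.size() - 1);             // leave the node: undo the add
        }
    }

    static List<List<Integer>> collect(boolean backtrack) {
        Node root = new Node(0, new Node(1, null, null), new Node(1, null, null));
        List<List<Integer>> res = new ArrayList<>();
        walk(root, 1, res, new ArrayList<>(), backtrack);
        return res;
    }

    public static void main(String[] args) {
        System.out.println("with remove:    " + collect(true));   // [[0, 1], [0, 1]]
        System.out.println("without remove: " + collect(false));  // [[0, 1], [0, 1, 1]]
    }
}
```

Without the remove, the left leaf's 1 is still in the list when the right leaf is visited, which is exactly the [0,1,1] in the wrong output.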


The following is equivalent but does not reuse the same list (it creates a new one for each recursive call):

public void pathSumRe(TreeNode root, int sum, List<List<Integer>> res,
        ArrayList<Integer> path) {
    if (root == null) {
        return;
    }
    path.add(root.val);
    if (root.left == null && root.right == null && root.val == sum) {
        res.add(new ArrayList<Integer>(path));
    }
    pathSumRe(root.left, sum - root.val, res, new ArrayList<Integer>(path));
    pathSumRe(root.right, sum - root.val, res, new ArrayList<Integer>(path));
}

This is easier to understand, but it creates many more ArrayList instances (two copies of the current path per non-null node, so the cost depends on the tree structure). Regardless of your edit, both versions work correctly for a TreeNode like this:

class TreeNode {
    public final int val;
    public final TreeNode left;
    public final TreeNode right;

    public TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}
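As a quick check of that claim, the sketch below wires both pathSumRe variants to a TreeNode shaped like the one above and runs them on the same tree that the buildTree() test later in this thread uses (target 14). The class name CompareVersions and the method names shared and copying are illustrative only.

```java
import java.util.ArrayList;
import java.util.List;

public class CompareVersions {
    // Same shape as the immutable TreeNode class above.
    static class TreeNode {
        final int val;
        final TreeNode left, right;
        TreeNode(int val, TreeNode left, TreeNode right) {
            this.val = val; this.left = left; this.right = right;
        }
    }

    // Version 1: one shared list, undo the add before returning.
    static void shared(TreeNode root, int sum, List<List<Integer>> res, List<Integer> path) {
        if (root == null) return;
        path.add(root.val);
        if (root.left == null && root.right == null && root.val == sum) {
            res.add(new ArrayList<>(path));
        }
        shared(root.left, sum - root.val, res, path);
        shared(root.right, sum - root.val, res, path);
        path.remove(path.size() - 1);
    }

    // Version 2: each child call gets its own copy, so nothing needs undoing.
    static void copying(TreeNode root, int sum, List<List<Integer>> res, List<Integer> path) {
        if (root == null) return;
        path.add(root.val);
        if (root.left == null && root.right == null && root.val == sum) {
            res.add(new ArrayList<>(path));
        }
        copying(root.left, sum - root.val, res, new ArrayList<>(path));
        copying(root.right, sum - root.val, res, new ArrayList<>(path));
    }

    // Tree: 1 -> (2 -> (4, 5 -> (-, 6)), 3 -> (7 -> (-, 3), -))
    static TreeNode sampleTree() {
        TreeNode n5 = new TreeNode(5, null, new TreeNode(6, null, null));
        TreeNode n2 = new TreeNode(2, new TreeNode(4, null, null), n5);
        TreeNode n7 = new TreeNode(7, null, new TreeNode(3, null, null));
        TreeNode n3 = new TreeNode(3, n7, null);
        return new TreeNode(1, n2, n3);
    }

    static List<List<Integer>> run(boolean useShared) {
        List<List<Integer>> res = new ArrayList<>();
        if (useShared) shared(sampleTree(), 14, res, new ArrayList<>());
        else copying(sampleTree(), 14, res, new ArrayList<>());
        return res;
    }

    public static void main(String[] args) {
        System.out.println(run(true));   // [[1, 2, 5, 6], [1, 3, 7, 3]]
        System.out.println(run(false));  // [[1, 2, 5, 6], [1, 3, 7, 3]]
    }
}
```

The shared-list version allocates only the result snapshots; the copying version trades that economy for not having to reason about undoing mutations.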

Here is a clean Java implementation:

public static List<List<Integer>> rootToLeafPathsForSum(BinaryTreeNode<Integer> node, int requiredSum) {
    List <List<Integer>> paths = new ArrayList<List<Integer>>();
    doFindRootToLeafPathsForSum(node, 0, requiredSum, new ArrayList<Integer>(), paths);
    return paths;
}

private static void doFindRootToLeafPathsForSum(BinaryTreeNode<Integer> node, int sum, int requiredSum,
        List<Integer> path, List<List<Integer>> paths) {
    if(node == null) {
        return ;
    } 
    path.add(node.getData());
    sum +=node.getData();
    if (node.isLeafNode()) {
        if (sum == requiredSum) {
            paths.add(new ArrayList<Integer>(path));
        }           
    } else {
        doFindRootToLeafPathsForSum(node.getLeft(), sum,  requiredSum, path, paths);
        doFindRootToLeafPathsForSum(node.getRight(), sum, requiredSum, path, paths);

    }
    path.remove(path.size() - 1); // remove by index (the last element), not by value
}
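One Java detail matters when undoing the add: on a List<Integer>, an argument of type Integer (such as node.getData()) selects the remove(Object) overload, which deletes the first element equal to that value, while an int index selects remove(int). When a value repeats along a path, the two differ. The short sketch below (the class name RemoveOverloads is illustrative) shows the difference on the path [3, 1, 3].

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class RemoveOverloads {
    // Removes by object equality: the FIRST matching element goes.
    static List<Integer> removeByValue(List<Integer> path, Integer value) {
        List<Integer> copy = new ArrayList<>(path);
        copy.remove(value);               // resolves to remove(Object)
        return copy;
    }

    // Removes by index: the LAST element goes, which is what backtracking needs.
    static List<Integer> removeLast(List<Integer> path) {
        List<Integer> copy = new ArrayList<>(path);
        copy.remove(copy.size() - 1);     // resolves to remove(int)
        return copy;
    }

    public static void main(String[] args) {
        List<Integer> path = Arrays.asList(3, 1, 3);
        System.out.println(removeByValue(path, 3)); // [1, 3]
        System.out.println(removeLast(path));       // [3, 1]
    }
}
```

For a tree 3 -> 1 -> (3, 5), value-based removal leaves [1, 3] after the first leaf, so the second leaf would be recorded as [1, 3, 5] instead of [3, 1, 5]. The test below happens to pass either way because no value repeats between a removed node and its ancestors before a sibling is visited.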

Here is the test case

@Test
public void allRoot2LeafPathsForGivenSum() {
    BinaryTreeNode<Integer> bt = buildTree();
    List <List<Integer>> paths = BinaryTreeUtil.rootToLeafPathsForSum(bt, 14);

    assertThat(paths.size(), is(2));

    assertThat(paths.get(0).toArray(new Integer[0]), equalTo(new Integer[]{1,2,5,6}));
    assertThat(paths.get(1).toArray(new Integer[0]), equalTo(new Integer[]{1,3,7,3}));

    for (List<Integer> list : paths) {          
        for (Integer integer : list) {
            System.out.print(String.format(" %d", integer));
        }
        System.out.println();
    }
}

private BinaryTreeNode<Integer> buildTree() {
    BinaryTreeNode<Integer> n4 = new BinaryTreeNode<Integer>(4);
    BinaryTreeNode<Integer> n6 = new BinaryTreeNode<Integer>(6);
    BinaryTreeNode<Integer> n5 = new BinaryTreeNode<Integer>(5, null, n6);
    BinaryTreeNode<Integer> n2= new BinaryTreeNode<Integer>(2, n4, n5);
    BinaryTreeNode<Integer> n31 = new BinaryTreeNode<Integer>(3);
    BinaryTreeNode<Integer> n7 = new BinaryTreeNode<Integer>(7, null, n31);
    BinaryTreeNode<Integer> n3 = new BinaryTreeNode<Integer>(3, n7, null);
    BinaryTreeNode<Integer> root = new BinaryTreeNode<Integer>(1, n2, n3);

    return root;
}

What you need is a way to track the path traversed so far. When you reach a leaf and the path sums to the target, add a copy of that list to the result. Either way, before returning you need to backtrack by removing the current node from the list, which is the most important step in this case! I hope the code makes it clearer:

void util(TreeNode root, int sum, ArrayList<Integer> log, ArrayList<ArrayList<Integer>> result)
{
    if(root == null)    
        return;
    log.add(root.val);
    if(root.left == null && root.right == null && sum - root.val == 0)
        result.add(new ArrayList<Integer>(log));
    util(root.left, sum-root.val, log, result);
    util(root.right, sum-root.val, log, result);
    log.remove(log.size()-1);
}
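The same backtracking can be written with an ArrayDeque as the path, where addLast and removeLast make the "choose" and "undo" steps symmetric and sidestep the remove(int) versus remove(Object) question entirely. This is a sketch assuming a simple node class (Node below is hypothetical); it is not the code from any answer above.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.List;

public class DequePathSum {
    // Hypothetical node class for the sketch.
    static final class Node {
        final int val;
        final Node left, right;
        Node(int val, Node left, Node right) { this.val = val; this.left = left; this.right = right; }
        Node(int val) { this(val, null, null); }
    }

    static List<List<Integer>> pathSum(Node root, int target) {
        List<List<Integer>> result = new ArrayList<>();
        dfs(root, target, new ArrayDeque<>(), result);
        return result;
    }

    // remaining = target minus the values of the ancestors already on the path
    private static void dfs(Node node, int remaining, Deque<Integer> path, List<List<Integer>> result) {
        if (node == null) return;
        path.addLast(node.val);                           // choose: extend the path
        int rest = remaining - node.val;
        if (node.left == null && node.right == null) {
            if (rest == 0) result.add(new ArrayList<>(path)); // copy, iterated root-to-leaf
        } else {
            dfs(node.left, rest, path, result);           // explore both subtrees
            dfs(node.right, rest, path, result);
        }
        path.removeLast();                                // un-choose: backtrack
    }

    public static void main(String[] args) {
        Node root = new Node(0, new Node(1), new Node(1));
        System.out.println(pathSum(root, 1)); // [[0, 1], [0, 1]]
    }
}
```

Copying the deque into an ArrayList iterates from head to tail, so each recorded path is in root-to-leaf order.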

