
Getting the next layer of nodes in a binary tree using Python list comprehension

I'm working on the LeetCode problem Maximum Depth of Binary Tree, and would like to define a helper function get_next_nodes which, given a list of nodes in one 'layer' of the tree, returns the list of nodes in the next layer. I've tried the following:

# Definition for a binary tree node.
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes(nodes):
    return [next_node for next_node in (node.left, node.right) for node in nodes if next_node]


def test_get_next_nodes():
    tree = TreeNode(1)
    tree.left = TreeNode(2)
    assert get_next_nodes([tree]) == [tree.left]

However, this test fails because node is not defined in the list comprehension:

Kurts-MacBook-Pro:LeetCode kurtpeek$ pytest maximum_depth_of_binary_tree.py::test_get_next_nodes
============================= test session starts ==============================
platform darwin -- Python 3.7.0, pytest-3.6.4, py-1.5.4, pluggy-0.6.0
rootdir: /Users/kurtpeek/GoogleDrive/LeetCode, inifile:
plugins: timeout-1.3.2
collected 1 item                                                               

maximum_depth_of_binary_tree.py F                                        [100%]

=================================== FAILURES ===================================
_____________________________ test_get_next_nodes ______________________________

    def test_get_next_nodes():
        tree = TreeNode(1)
        tree.left = TreeNode(2)
>       assert get_next_nodes([tree]) == [tree.left]

maximum_depth_of_binary_tree.py:41: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

nodes = [TreeNode(1)]

    def get_next_nodes(nodes):
>       return [next_node for next_node in (node.left, node.right) for node in nodes if next_node]
E       NameError: name 'node' is not defined

maximum_depth_of_binary_tree.py:35: NameError
=========================== 1 failed in 0.04 seconds ===========================

I've tried comparing it with the examples in https://docs.python.org/3/tutorial/datastructures.html#list-comprehensions, but so far I don't see how to refactor the list comprehension to make it work. How can I fix this helper function?
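The NameError comes from the clause order: Python evaluates the iterable of the first `for` clause before any of the comprehension's loop variables are bound. A minimal sketch that reproduces this outside the tree code (the names `xs` and `pairs` are illustrative only):

```python
xs = [1, 2]

# Wrong order: the first clause's iterable refers to `x`,
# which is only bound later by the second clause.
try:
    pairs = [y for y in (x, x * 10) for x in xs]
except NameError as exc:
    print("NameError:", exc)

# Correct order: bind `x` first, then iterate over values derived from it.
pairs = [y for x in xs for y in (x, x * 10)]
print(pairs)  # [1, 10, 2, 20]
```

If a variable with the same name happened to exist in the enclosing scope, the reversed version would not raise at all but would silently read that outer variable, which makes this mistake harder to spot.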

Update

I got the helper function to work when written as follows:

def get_next_nodes(nodes):
    next_nodes = []
    for node in nodes:
        next_nodes += [child for child in (node.left, node.right) if child]
    return next_nodes

However, this seems like the kind of pattern that could be rewritten as a single list comprehension.
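Besides a nested comprehension, the standard library's `itertools.chain.from_iterable` can flatten the per-node `(left, right)` pairs. This alternative is not taken from the answers below; it is a sketch assuming the question's `TreeNode` class, redefined here so the block runs on its own:

```python
from itertools import chain


class TreeNode:
    """Same minimal node class as in the question."""
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes(nodes):
    # Produce (left, right) for each node, flatten the pairs into one
    # stream, then drop the missing (None) children.
    children = chain.from_iterable((node.left, node.right) for node in nodes)
    return [child for child in children if child]


tree = TreeNode(1)
tree.left = TreeNode(2)
tree.right = TreeNode(3)
print([n.val for n in get_next_nodes([tree])])  # [2, 3]
```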

You have the order of the for clauses in the nested list comprehension reversed. The correct expression is:

return [next_node 
            for node in nodes 
                for next_node in (node.left, node.right) 
                    if next_node]

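Think of a nested list comprehension as nested loops written in the same order: the first for clause (for node in nodes) is the outer loop and runs first, so node is already bound when the inner clause evaluates (node.left, node.right).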
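Putting the corrected clause order back into the question's helper gives the self-contained sketch below. It reuses the question's `TreeNode` and `test_get_next_nodes`; the extra assertions on a two-child root are added here for illustration:

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes(nodes):
    # Outer loop first (for node in nodes), then the inner loop over
    # that node's children, then the filter that skips missing children.
    return [next_node
            for node in nodes
            for next_node in (node.left, node.right)
            if next_node]


def test_get_next_nodes():
    tree = TreeNode(1)
    tree.left = TreeNode(2)
    assert get_next_nodes([tree]) == [tree.left]

    # Extra check: children come out left to right, node by node.
    tree.right = TreeNode(3)
    tree.left.right = TreeNode(4)
    assert get_next_nodes([tree]) == [tree.left, tree.right]
    assert get_next_nodes([tree.left, tree.right]) == [tree.left.right]


test_get_next_nodes()
```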

You should write the for clauses in your list comprehension in the opposite order:

[next_node for node in nodes for next_node in (node.left, node.right) if next_node]

The docs for list comprehensions have an example of such a multi-level list comprehension, with an explanation of how it is evaluated.
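A small illustration modeled on that tutorial example: the comprehension and the explicit nested loops below build the same list, and the `for` clauses appear in the same order in both.

```python
xs = [1, 2, 3]
ys = [3, 1, 4]

# Comprehension: clauses read left to right, outermost loop first.
combs = [(x, y) for x in xs for y in ys if x != y]

# Equivalent nested loops, in the same order as the clauses above.
combs_loop = []
for x in xs:
    for y in ys:
        if x != y:
            combs_loop.append((x, y))

assert combs == combs_loop
print(combs)  # [(1, 3), (1, 4), (2, 3), (2, 1), (2, 4), (3, 1), (3, 4)]
```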

In your case it would be:

def get_next_nodes(nodes):
    result = []
    for node in nodes:  # outer loop: the first for clause
        for next_node in (node.left, node.right):  # inner loop
            if next_node:  # skip missing children
                result.append(next_node)
    return result

Note that swapping the two for loops would make no sense, because the inner loop would then use node before the outer loop defines it. That is exactly what happens in your code.
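Since the original goal was Maximum Depth of Binary Tree, here is a sketch of how the fixed helper could drive a level-by-level traversal: the depth is the number of non-empty layers visited. The function name `max_depth` is hypothetical (it is not LeetCode's method signature), and an empty tree is assumed to be represented by `None`.

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes(nodes):
    return [next_node for node in nodes
            for next_node in (node.left, node.right) if next_node]


def max_depth(root):
    # Hypothetical helper: count layers until the next one comes back empty.
    depth = 0
    layer = [root] if root else []
    while layer:
        depth += 1
        layer = get_next_nodes(layer)
    return depth


root = TreeNode(3)
root.left = TreeNode(9)
root.right = TreeNode(20)
root.right.left = TreeNode(15)
root.right.right = TreeNode(7)
print(max_depth(root))  # 3
```

Each pass replaces the current layer with its children, so the loop runs once per level of the tree.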

