I'm working on the LeetCode problem Maximum Depth of Binary Tree, and would like to define a helper function get_next_nodes which, given a list of nodes in one 'layer' of the tree, returns the list of nodes in the next layer. I've tried the following:
# Definition for a binary tree node.
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None

def get_next_nodes(nodes):
    return [next_node for next_node in (node.left, node.right) for node in nodes if next_node]

def test_get_next_nodes():
    tree = TreeNode(1)
    tree.left = TreeNode(2)
    assert get_next_nodes([tree]) == [tree.left]
However, this test fails because node is not defined in the list comprehension:
Kurts-MacBook-Pro:LeetCode kurtpeek$ pytest maximum_depth_of_binary_tree.py::test_get_next_nodes
============================= test session starts ==============================
platform darwin -- Python 3.7.0, pytest-3.6.4, py-1.5.4, pluggy-0.6.0
rootdir: /Users/kurtpeek/GoogleDrive/LeetCode, inifile:
plugins: timeout-1.3.2
collected 1 item
maximum_depth_of_binary_tree.py F [100%]
=================================== FAILURES ===================================
_____________________________ test_get_next_nodes ______________________________
def test_get_next_nodes():
tree = TreeNode(1)
tree.left = TreeNode(2)
> assert get_next_nodes([tree]) == [tree.left]
maximum_depth_of_binary_tree.py:41:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
nodes = [TreeNode(1)]
def get_next_nodes(nodes):
> return [next_node for next_node in (node.left, node.right) for node in nodes if next_node]
E NameError: name 'node' is not defined
maximum_depth_of_binary_tree.py:35: NameError
=========================== 1 failed in 0.04 seconds ===========================
I've tried to compare with examples in https://docs.python.org/3/tutorial/datastructures.html#list-comprehensions but so far don't see how to refactor the list comprehension to get it to work. How can I fix this helper function?
Update
I got the helper function to work when written as follows:
def get_next_nodes(nodes):
    next_nodes = []
    for node in nodes:
        next_nodes += [child for child in (node.left, node.right) if child]
    return next_nodes
However, this seems like the type of pattern that could be refactored into a single list comprehension.
You have the for clauses of the nested list comprehension in the wrong order. The correct expression is:
return [next_node
        for node in nodes
        for next_node in (node.left, node.right)
        if next_node]
Think of a nested list comprehension as a nested loop: the clauses are read left to right, so the outer loop (for node in nodes) must come first, followed by the inner loop that uses node.
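To make the loop analogy concrete, here is a minimal self-contained sketch that reuses the TreeNode class from the question with the corrected clause order. The extra nodes in the sample tree and the max_depth helper are assumptions added for illustration; they are not part of the original question or answer.

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes(nodes):
    # Clauses read left to right like nested loops: the outer loop
    # (for node in nodes) comes first, so `node` is bound before it is used.
    return [child
            for node in nodes
            for child in (node.left, node.right)
            if child]


def max_depth(root):
    # Hypothetical helper: count layers until the next layer is empty.
    depth = 0
    layer = [root] if root else []
    while layer:
        depth += 1
        layer = get_next_nodes(layer)
    return depth


# Sample tree (illustrative):
#       1
#      / \
#     2   3
#    /
#   4
tree = TreeNode(1)
tree.left = TreeNode(2)
tree.right = TreeNode(3)
tree.left.left = TreeNode(4)
```

With this tree, get_next_nodes([tree]) yields the two children of the root in left-to-right order, and repeatedly calling it until the result is empty counts the layers of the tree.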
You should write the for clauses in the list comprehension in the opposite order:
[next_node for node in nodes for next_node in (node.left, node.right) if next_node]
The docs for list comprehensions have an example of such a multi-level list comprehension, with an explanation of how it is evaluated.
In your case it would be equivalent to:
def get_next_nodes(nodes):
    result = []
    for node in nodes:
        for next_node in (node.left, node.right):
            if next_node:
                result.append(next_node)
    return result
Note that swapping the two for loops would make no sense, because node would be undefined when the outer loop tries to read node.left and node.right. This is exactly what happens in your code.
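The failure can be reproduced in isolation. The sketch below keeps the clause order from the question and checks that calling it raises NameError. The names get_next_nodes_wrong_order and raises_name_error are made up for this example, and it assumes no variable called node exists in an enclosing or global scope (if one did, the lookup would silently pick it up instead of failing).

```python
class TreeNode:
    def __init__(self, x):
        self.val = x
        self.left = None
        self.right = None


def get_next_nodes_wrong_order(nodes):
    # Same clause order as in the question: the leftmost iterable,
    # (node.left, node.right), is evaluated before `node` is ever bound.
    return [next_node
            for next_node in (node.left, node.right)
            for node in nodes
            if next_node]


def raises_name_error(func, *args):
    # Hypothetical helper: report whether calling func(*args) raises NameError.
    try:
        func(*args)
    except NameError:
        return True
    return False


tree = TreeNode(1)
tree.left = TreeNode(2)
failed = raises_name_error(get_next_nodes_wrong_order, [tree])
```

Here failed should come out true, matching the NameError in the pytest output from the question.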