
How does updating a Python for loop within a list work?

In Michael Nielsen's tutorial on neural networks he has the following code:

def update_mini_batch(self, mini_batch, eta):
    """The ``mini_batch`` is a list of tuples ``(x, y)``, and ``eta``
    is the learning rate."""
    nabla_b = [np.zeros(b.shape) for b in self.biases]
    nabla_w = [np.zeros(w.shape) for w in self.weights]
    for x, y in mini_batch:
        delta_nabla_b, delta_nabla_w = self.backprop(x, y)
        nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
        nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]
    self.weights = [w-(eta/len(mini_batch))*nw
                    for w, nw in zip(self.weights, nabla_w)]
    self.biases = [b-(eta/len(mini_batch))*nb
                   for b, nb in zip(self.biases, nabla_b)]

I understand what tuples and lists are, and I understand what the zip function is doing, but I don't understand how the variables nb, dnb, nw, and dnw are updated in these two lines of code:

        nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]
        nabla_w = [nw+dnw for nw, dnw in zip(nabla_w, delta_nabla_w)]

Can anyone help explain the magic going on in these 2 lines?

The zip function sticks the two lists together element by element, so that if you gave it:

a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]

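zip(a, b) would give you these pairs (in Python 3 it returns a lazy iterator, so wrap it as list(zip(a, b)) to see them all at once):

[(1, "a"), (2, "b"), ...]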

(each element being a tuple)
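A quick check you can run yourself. The second call also shows that zip() stops at the shorter input, which matters if the two lists ever differ in length.

```python
a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]

# list() collects every pair that zip() produces.
pairs = list(zip(a, b))
print(pairs)  # [(1, 'a'), (2, 'b'), (3, 'c'), (4, 'd')]

# zip() stops as soon as the shortest input runs out.
short = list(zip([1, 2, 3], ["x", "y"]))
print(short)  # [(1, 'x'), (2, 'y')]
```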

You can unpack elements of lists that are tuples (or lists) by putting a comma between the variable names, one name per item in each tuple:

for elem_a, elem_b in zip(a, b):
    print(elem_a, elem_b)

This would print:

1 a
2 b
3 c
4 d
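For comparison, a small sketch showing that unpacking in the loop header does the same thing as indexing into each tuple by hand, and that the number of names has to match the tuple length.

```python
a = [1, 2, 3, 4]
b = ["a", "b", "c", "d"]

# Unpacking in the loop header: each tuple is split into two names.
unpacked = []
for elem_a, elem_b in zip(a, b):
    unpacked.append(f"{elem_a} {elem_b}")

# The same thing written out by hand with indexing.
indexed = []
for pair in zip(a, b):
    elem_a = pair[0]
    elem_b = pair[1]
    indexed.append(f"{elem_a} {elem_b}")

print(unpacked == indexed)  # True

# The number of names must match the length of each tuple.
try:
    for elem_a, elem_b in [(1, "a", "extra")]:
        pass
except ValueError:
    print("ValueError: too many values to unpack")
```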

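So in your case, it's adding the two lists nabla_b and delta_nabla_b elementwise, so you get one list with each element being the sum of the corresponding elements in the zipped lists. Nothing is updated in place: on each pass nb and dnb are simply bound to the next pair from zip, and once all pairs are used up, the freshly built list is assigned back to the name nabla_b.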
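A small sketch with plain integers (standing in for the NumPy arrays in the real code) shows that the comprehension builds a brand-new list and rebinds the name nabla_b, while the original list object is left untouched.

```python
nabla_b = [1, 2, 3]
delta_nabla_b = [10, 20, 30]

original = nabla_b  # a second name for the original list object

nabla_b = [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]

print(nabla_b)              # [11, 22, 33]
print(original)             # [1, 2, 3]: the old list was not modified
print(nabla_b is original)  # False: nabla_b now refers to a new list
```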

It might look a bit strange because the for loop is all on one line, but this is called a "list comprehension". Simple list comprehensions read like English.

These 2 lines are typical examples of Python list comprehensions.

In essence, for your first list:

nabla_b = [nb+dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]

this means:

  1. Take the 1st pair from zip(nabla_b, delta_nabla_b); name them nb and dnb
  2. Add them (nb+dnb)
  3. Make the result the 1st element of a new list
  4. Go back to step 1 for the 2nd pair, and so on, appending each result to the new list until all pairs from zip(nabla_b, delta_nabla_b) have been exhausted
  5. Assign the finished list to the name nabla_b, replacing the old list
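The steps above can be written out as an ordinary loop. This sketch uses small integer lists as placeholder data and checks that the loop and the comprehension give the same result; the helper names with_comprehension and with_loop are made up for illustration.

```python
def with_comprehension(nabla_b, delta_nabla_b):
    return [nb + dnb for nb, dnb in zip(nabla_b, delta_nabla_b)]

def with_loop(nabla_b, delta_nabla_b):
    new_list = []
    # Take each pair from zip() in turn, naming them nb and dnb.
    for nb, dnb in zip(nabla_b, delta_nabla_b):
        # Add them and append the sum to the new list.
        new_list.append(nb + dnb)
    # Once every pair is used up, the finished list is handed back.
    return new_list

nabla_b = [5, 6, 7]
delta_nabla_b = [1, 1, 1]

assert with_comprehension(nabla_b, delta_nabla_b) == with_loop(nabla_b, delta_nabla_b)

# The caller then rebinds the name to the new list.
nabla_b = with_loop(nabla_b, delta_nabla_b)
print(nabla_b)  # [6, 7, 8]
```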

As a simple example, the list comprehension below:

squares = [x**2 for x in range(10)]
print(squares)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

is equivalent to the following for loop:

squares = []

for x in range(10):
    squares.append(x**2)

print(squares)
# [0, 1, 4, 9, 16, 25, 36, 49, 64, 81]
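One difference between the two forms bears on the original question about nb and dnb: in Python 3 a comprehension's loop variable is local to the comprehension, while a plain for loop's variable stays bound afterwards. A minimal check:

```python
x = "untouched"

squares = [x**2 for x in range(10)]
after_comprehension = x
print(after_comprehension)  # untouched: the comprehension's x is separate

for x in range(10):
    pass
print(x)  # 9: the for loop rebound x, and it keeps its last value
```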

See here for more examples and a quick introduction.
