
How does a gradient backpropagate through random samples?

I'm learning about policy gradients and I'm having a hard time understanding how the gradient passes through a random operation. From the PyTorch torch.distributions documentation: "It is not possible to directly backpropagate through random samples. However, there are two main methods for creating surrogate functions that can be backpropagated through."

The docs give an example of the score function:

probs = policy_network(state)
# Note that this is equivalent to what used to be called multinomial
m = Categorical(probs)
action = m.sample()
next_state, reward = env.step(action)
loss = -m.log_prob(action) * reward
loss.backward()

I tried to build my own example based on this:

import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions import Normal
import matplotlib.pyplot as plt
from tqdm import tqdm

softplus = torch.nn.Softplus()

class Model_RL(nn.Module):
    def __init__(self):
        super(Model_RL, self).__init__()
        self.fc1 = nn.Linear(1, 20)
        self.fc2 = nn.Linear(20, 30)
        self.fc3 = nn.Linear(30, 2)

    def forward(self, x):
        x1 = self.fc1(x)
        x = torch.relu(x1)
        x2 = self.fc2(x)
        x = torch.relu(x2)
        x3 = softplus(self.fc3(x))
        return x3, x2, x1

# basic 

net_RL = Model_RL()

features = torch.tensor([1.0]) 
x = torch.tensor([1.0]) 
y = torch.tensor(3.0)

baseline = 0
baseline_lr = 0.1

epochs = 3

opt_RL = optim.Adam(net_RL.parameters(), lr=1e-3)
losses = []
xs = []
for _ in tqdm(range(epochs)):
    out_RL = net_RL(x)
    mu, std = out_RL[0]
    dist = Normal(mu, std)
    print(dist)
    a = dist.sample()
    log_p = dist.log_prob(a)
    
    out = features * a
    reward = -torch.square((y - out))
    baseline = (1-baseline_lr)*baseline + baseline_lr*reward
    
    loss = -(reward-baseline)*log_p

    opt_RL.zero_grad()
    loss.backward()
    opt_RL.step()
    losses.append(loss.item())

This seems to work fine, almost magically, and again I don't understand how: they say the gradient can't pass through the random operation, but then somehow it does.

Now, since the gradient can't flow through the random operation, I tried to replace mu, std = out_RL[0] with mu, std = out_RL[0].detach(), and that caused the error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn. If the gradient doesn't pass through the random operation, I don't understand why detaching a tensor before that operation would matter.

It is indeed true that sampling is not a differentiable operation per se. However, there are two (broad) ways to mitigate this: [1] the REINFORCE way and [2] the reparameterization way. Since your example is related to [1], I will restrict my answer to REINFORCE.
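
(For contrast only, here is a rough sketch of what [2] would look like with your model, which conveniently already outputs mu and std. It uses rsample(), the reparameterized sampling method of torch.distributions, and reuses the names from your script; the rest of this answer is about [1].)

mu, std = net_RL(x)[0]
dist = Normal(mu, std)
a = dist.rsample()              # reparameterized sample: a = mu + std * eps, eps ~ N(0, 1)
out = features * a
loss = torch.square(y - out)    # ordinary loss, no log_prob surrogate needed

opt_RL.zero_grad()
loss.backward()                 # the gradient flows through a into mu and std
opt_RL.step()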

What REINFORCE does is remove the sampling operation from the computation graph entirely. The sampling still happens, but it happens outside the graph. So, your statement

.. how the gradient passes through a random operation ..

isn't correct. The gradient does not pass through any random operation. Let's look at your example:

mu, std = out_RL[0]
dist = Normal(mu, std)
a = dist.sample()
log_p = dist.log_prob(a)

The computation of a does not involve the computation graph at all. It is technically equivalent to plugging in some offline data from a dataset (as in supervised learning):

mu, std = out_RL[0]
dist = Normal(mu, std)
# a = dist.sample()
a = torch.tensor([1.23, 4.01, -1.2, ...], device='cuda')
log_p = dist.log_prob(a)

Since we don't have offline data beforehand, we create it on the fly, and that is all the .sample() method does.
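
You can verify this directly (a quick sanity check, reusing the names from your example): the sample carries no graph history, while log_p does.

out_RL = net_RL(x)
mu, std = out_RL[0]
dist = Normal(mu, std)
a = dist.sample()

print(a.requires_grad)       # False: the sample is plain data, not part of the graph
print(a.grad_fn)             # None

log_p = dist.log_prob(a)
print(log_p.requires_grad)   # True: log_p is a deterministic function of mu and std
print(log_p.grad_fn)         # a backward node, i.e. log_p is connected to the network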

So, there is no random operation on the graph. log_p depends on mu and std deterministically, just like any other node in a standard computation graph. If you cut that connection like this:

mu, std = out_RL[0].detach()

.. it is, of course, going to complain: after the detach, nothing upstream of loss requires a gradient anymore.
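
As an aside, the reason the surrogate loss -log_p * reward is useful at all is the score-function identity that REINFORCE is built on:

\nabla_\theta \, \mathbb{E}_{a \sim \pi_\theta}\left[ R(a) \right] = \mathbb{E}_{a \sim \pi_\theta}\left[ R(a) \, \nabla_\theta \log \pi_\theta(a) \right]

So the gradient of -R(a) * log pi_theta(a), averaged over sampled actions, is an unbiased estimate of the gradient of the expected reward, and subtracting a baseline that does not depend on the action leaves that expectation unchanged while reducing variance.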

Also, do not get confused by this operation

dist = Normal(mu, std)
log_p = dist.log_prob(a)

as it does not contain any randomness by itself. It is merely a shortcut for writing out the tedious log-likelihood formula of the Normal distribution.
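
If you want to see what that shortcut expands to, you can write the Normal log-density out by hand and compare (a small check, reusing mu, std and a from above):

import math

manual_log_p = (
    -0.5 * ((a - mu) / std) ** 2
    - torch.log(std)
    - 0.5 * math.log(2 * math.pi)
)
print(torch.allclose(manual_log_p, dist.log_prob(a)))   # True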
