
How to Record Variables in PyTorch Without Breaking Gradient Computation?

I am trying to implement policy gradient training, similar to this. However, I would like to manipulate the rewards (e.g. computing the discounted sum of future rewards, or other differentiable operations) before doing backpropagation.

Consider the manipulate function below, which computes the reward-to-go:

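import numpy as np
import torch as T

def manipulate(reward_pool):
    # Reward-to-go: R[i] = reward_pool[i] + reward_pool[i+1] + ... + reward_pool[n-1]
    n = len(reward_pool)
    R = np.zeros_like(reward_pool)  # NumPy buffer; autograd does not track NumPy arrays
    for i in reversed(range(n)):
        R[i] = reward_pool[i] + (R[i+1] if i+1 < n else 0)
    return T.as_tensor(R)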
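If the rewards are themselves tensors that require grad, the same reward-to-go can be computed without NumPy and without index assignment, so autograd can follow every step. This is a minimal sketch, assuming each reward arrives as a 0-dim tensor; the name reward_to_go and the discount factor gamma are illustrative additions, not part of the original post (gamma=1.0 reproduces the undiscounted sum computed by manipulate).

```python
import torch


def reward_to_go(rewards, gamma=1.0):
    """Return a 1-D tensor whose i-th entry is sum_{j>=i} gamma**(j-i) * rewards[j].

    Built out-of-place: every running sum is a new tensor, and torch.stack
    assembles them at the end, so no existing tensor is modified in place.
    """
    returns = []
    running = torch.zeros(())  # 0-dim starting value for the running sum
    for r in reversed(rewards):
        running = r + gamma * running  # new tensor, graph stays intact
        returns.append(running)
    returns.reverse()
    return torch.stack(returns)


# Usage: rewards that require grad keep their gradient history.
rewards = [torch.tensor(v, requires_grad=True) for v in (1.0, 2.0, 3.0)]
R = reward_to_go(rewards)
R.sum().backward()  # R is a vector, so reduce it to a scalar before backward()
```

With gamma=1.0, reward j contributes to entries 0 through j of R, so each reward's gradient after R.sum().backward() is its position plus one.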

I tried to store the rewards in a list:

#pseudocode
reward_pool = [0 for i in range(batch_size)]  # preallocated list, filled by index

for k in range(batch_size):
  act = net(state)
  state, reward = env.step(act)
  reward_pool[k] = reward  # overwrites the k-th slot

R = manipulate(reward_pool)
R.backward()
optimizer.step()

It seems that an in-place operation breaks the gradient computation; the code fails with the error: one of the variables needed for gradient computation has been modified by an inplace operation.
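For context, this error is raised when a tensor that autograd saved for the backward pass is modified in place before backward() runs. The following is a minimal, purely illustrative reproduction that is unrelated to the reward code itself; it uses torch.exp, whose backward reuses its own output.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)

# In-place version: exp() saves its output y for the backward pass,
# and y.add_(1) then modifies that saved tensor.
y = x.exp()
y.add_(1)
try:
    y.sum().backward()
    inplace_failed = False
except RuntimeError as err:
    inplace_failed = True
    print("in-place:", err)

# Out-of-place version: y + 1 creates a new tensor and leaves the saved one intact.
x.grad = None
y = x.exp()
z = y + 1
z.sum().backward()
print("grad:", x.grad)  # d/dx sum(exp(x) + 1) = exp(x)
```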

I also tried initializing an empty tensor first and storing the rewards in it, but the in-place operation is still the problem: a view of a leaf Variable that requires grad is being used in an in-place operation.
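The second message comes from writing into a preallocated tensor that was itself created with requires_grad=True. The sketch below assumes the buffer was created that way (the post does not show how the empty tensor was made), and contrasts it with collecting results in a list and calling torch.stack once at the end. The tensor w is a hypothetical stand-in for a network parameter.

```python
import torch

w = torch.tensor(2.0, requires_grad=True)  # hypothetical stand-in for a parameter

# Preallocated leaf buffer that requires grad: index assignment is an in-place write.
buf = torch.zeros(3, requires_grad=True)
try:
    buf[0] = w * 1.0
    leaf_write_failed = False
except RuntimeError as err:
    leaf_write_failed = True
    print("preallocated leaf:", err)

# Out-of-place alternative: collect tensors in a list, then stack once.
pool = []
for k in range(3):
    pool.append(w * (k + 1))  # each entry is a new tensor connected to w
stacked = torch.stack(pool)  # shape (3,)
stacked.sum().backward()
print("dL/dw:", w.grad)  # 1 + 2 + 3 = 6
```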

I am fairly new to PyTorch. What is the right way to record the rewards in this case?

Edit: Solution Found

The issue was caused by assigning into an existing object. Instead, initialize an empty pool (a list) at the start of each iteration and append each new reward as it is computed, i.e.:

reward_pool = []  # fresh, empty list for each iteration

for k in range(batch_size):
  act = net(state)
  state, reward = env.step(act)
  reward_pool.append(reward)  # append instead of assigning into a preallocated slot

R = manipulate(reward_pool)
R.backward()
optimizer.step()
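Note that R.backward() needs a scalar (or an explicit gradient argument), and gradients only reach net if R is connected to its outputs; a list of plain numbers returned by an environment is not. A common way to get gradients to the policy, which the post itself does not show, is a REINFORCE-style loss that weights the log-probabilities of the sampled actions by the reward-to-go. This is a minimal sketch under that assumption: the linear policy, toy_env_step, discounted_returns, and all sizes are hypothetical stand-ins for the question's net, env, and optimizer.

```python
import torch
import torch.nn as nn
from torch.distributions import Categorical

torch.manual_seed(0)

# Hypothetical stand-ins for the question's net, env and optimizer.
net = nn.Linear(4, 2)  # policy: 4-dim state -> logits for 2 actions
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)


def toy_env_step(action):
    """Hypothetical environment: random next state, reward 1.0 for action 1."""
    next_state = torch.randn(4)
    reward = 1.0 if action == 1 else 0.0
    return next_state, reward


def discounted_returns(rewards, gamma=0.99):
    """Reward-to-go as a plain tensor; the rewards are numbers, so no grad is needed."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    returns.reverse()
    return torch.tensor(returns)


batch_size = 5
state = torch.randn(4)
log_probs, reward_pool = [], []  # fresh lists for this iteration

for k in range(batch_size):
    dist = Categorical(logits=net(state))
    action = dist.sample()
    log_probs.append(dist.log_prob(action))  # connected to net's weights
    state, reward = toy_env_step(action.item())
    reward_pool.append(reward)

R = discounted_returns(reward_pool)
loss = -(torch.stack(log_probs) * R).sum()  # scalar, so backward() needs no argument
optimizer.zero_grad()
loss.backward()
optimizer.step()
```

Here the rewards never need gradients; the graph runs through log_probs, which are appended to a list and stacked once rather than written into a preallocated tensor.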

The technical post webpages of this site follow the CC BY-SA 4.0 license. If you need to reprint, please indicate the site URL or the original address.

 
Guangdong ICP Registration No. 18138465  © 2020-2024 STACKOOM.COM