I'm trying to create an extremely simple network with a GRUCell layer to perform the following task: a cue is given in one of two locations. After T time-steps, the agent must learn to take a particular action at the opposite location.
I get the following error when trying to compute backward gradients:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation.
One problem is that I don't fully understand which piece of my code is performing the inplace operation.
I have read other posts on Stack Overflow and on the PyTorch forum, which all suggest using the .clone() operation. I've peppered it all over my code, anywhere I think it could conceivably make a difference, but I haven't had success.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gru = nn.GRUCell(2,50) # GRU layer taking 2 inputs (L or R), has 50 units
        self.actor = nn.Linear(50,2) # Linear Actor layer with 2 outputs, takes GRU as input
        self.critic = nn.Linear(50,1) # Linear Critic layer with 1 output, takes GRU as input

    def forward(self, s, h):
        h = self.gru(s,h) # give the input and previous hidden state to the GRU layer
        c = self.critic(h) # estimate the value of the current state
        pi = F.softmax(self.actor(h),dim=1) # calculate the policy
        return (h,c,pi)

    def backward_rollout(self, gamma, R_t, c_t, t):
        # note: Delta_t and Value_l are read from the enclosing script's globals
        R_t[0,t] = gamma*R_t[0,t+1].clone()
        # calculate the reward prediction error
        Delta_t[0,t] = c_t[0,t].clone() - R_t[0,t].clone()
        # calculate the loss for the critic
        crit = c_t[0,t].clone()
        ret = R_t[0,t].clone()
        Value_l[0,t] = F.smooth_l1_loss(crit,ret)
###################################
# Run a trial
# parameters
N = 1 # number of trials to run
T = 10 # number of time-steps in a trial
gamma = 0.98 # temporal discount factor
# for each trial
for n in range(N):
    sample = np.random.choice([0,1],1)[0] # pick the sample input for this trial
    s_t = torch.zeros((1,2,T)) # state at each time step
    h_0 = torch.zeros((1,50)) # initial hidden state
    h_t = torch.zeros((1,50,T)) # hidden state at each time step
    c_t = torch.zeros((1,T)) # critic at each time step
    pi_t = torch.zeros((1,2,T)) # policy at each time step
    R_t = torch.zeros((1,T)) # return at each time step
    Delta_t = torch.zeros((1,T)) # difference between critic and true return at each step
    Value_l = torch.zeros((1,T)) # value loss
    # set the input (state) vector/tensor
    s_t[0,sample,0] = 1.0 # set first time-step stimulus
    s_t[0,0,-1] = 1.0 # set last time-step stimulus
    s_t[0,1,-1] = 1.0 # set last time-step stimulus
    # step through the trial
    for t in range(T):
        # run a forward step
        state = s_t[:,:,t].clone()
        if t == 0: # compare by value; "is" tests object identity
            (hidden_state, critic, policy) = net(state, h_0)
        else:
            (hidden_state, critic, policy) = net(state, h_t[:,:,t-1])
        h_t[:,:,t] = hidden_state.clone()
        c_t[:,t] = critic.clone()
        pi_t[:,:,t] = policy.clone()
    # select an action using the policy
    action = np.random.choice([0,1], 1, p = policy[0,:].detach().numpy())
    #action = int(np.random.uniform() < pi[0,1])
    # compare the action to the sample (action is a length-1 array, so index it)
    if action[0] == sample:
        r = 0
        print("WRONG!")
    else:
        r = 1
        print("RIGHT!")
    #h_t_old = h_t
    #s_t_old = s_t
    # step backwards through the trial to calculate gradients
    R_t[0,-1] = r
    Delta_t[0,-1] = c_t[0,-1].clone() - r
    Value_l[0,-1] = F.smooth_l1_loss(c_t[0,-1],R_t[0,-1]).clone()
    for t in np.arange(T-2,-1,-1): # backwards rollout
        net.backward_rollout(gamma, R_t, c_t, t)
    Vl = Value_l.clone().sum() # calculate total loss
    Vl.backward() # calculate the derivatives
    opt.step() # update the weights
    opt.zero_grad() # zero gradients before next trial
You can try anomaly detection (torch.autograd.set_detect_anomaly(True)) to pinpoint the exact offending in-place operation: https://github.com/pytorch/pytorch/issues/15803
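As a rough illustration, the sketch below reproduces the same class of error on a toy tensor and runs it under anomaly detection. The names w, store, and run_with_anomaly_detection are made up for this example and do not come from the question. With anomaly detection on, PyTorch additionally prints a traceback of the forward call whose saved tensor was later overwritten.

```python
import torch

def run_with_anomaly_detection():
    """Trigger the in-place error on a toy graph and return its message."""
    w = torch.ones(2, requires_grad=True)
    store = torch.zeros(2, 3)  # preallocated buffer, like h_t in the question
    with torch.autograd.detect_anomaly():
        store[:, 0] = w * 2          # in-place write into the buffer
        y = store[:, 0] * w          # mul saves the view store[:, 0] for backward
        store[:, 1] = y              # another in-place write changes the buffer's version
        try:
            y.sum().backward()       # saved view no longer matches: RuntimeError
        except RuntimeError as err:
            return str(err)
    return None

message = run_with_anomaly_detection()
print(message)
```

The pattern mirrors the question's loop: a slice of a buffer is fed into a layer, and the same buffer is written to again before backward() runs.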
Value_l[0,-1] = and similar assignments are in-place operations. You can sidestep the check by doing Value_l.data[0,-1] = instead, but that write is not stored in the computational graph, so it may be a bad idea. A relevant discussion is here: https://discuss.pytorch.org/t/how-to-get-around-in-place-operation-error-if-index-leaf-variable-for-gradient-update/14554
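One way to avoid the indexed writes altogether is to collect per-step outputs in a Python list and combine them with torch.stack. The following is only a minimal sketch under a few assumptions: it trains the critic alone, the standalone gru and critic layers stand in for the question's Net, and the reward r is fixed to 1.0 for illustration.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
T, gamma = 10, 0.98
gru = nn.GRUCell(2, 50)      # stand-in for Net.gru
critic = nn.Linear(50, 1)    # stand-in for Net.critic

# Inputs do not require grad and are filled before use, so these writes are safe.
s_t = torch.zeros(1, 2, T)
s_t[0, 0, 0] = 1.0
s_t[0, :, -1] = 1.0

h = torch.zeros(1, 50)
values = []
for t in range(T):
    h = gru(s_t[:, :, t], h)               # carry the hidden state directly
    values.append(critic(h).squeeze(1))    # shape (1,), no buffer write
c_t = torch.stack(values, dim=1)           # shape (1, T)

# Discounted returns R_t = gamma * R_{t+1}, with R_{T-1} = r (assumed reward).
r = 1.0
returns = torch.tensor([[r * gamma ** (T - 1 - t) for t in range(T)]])

# Sum of per-step smooth L1 losses, as in the question's Value_l.sum().
loss = F.smooth_l1_loss(c_t, returns, reduction="sum")
loss.backward()
```

Because no tensor that autograd has saved is modified afterwards, backward() runs without .clone() calls or .data workarounds.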