I'm trying to implement the soft actor critic algorithm for discrete action space and I have trouble with the loss function.
Here is the description of SAC for continuous action spaces that I am working from: https://spinningup.openai.com/en/latest/algorithms/sac.html
I do not know what I'm doing wrong.
The problem is that the network does not learn anything on the CartPole environment.
The full code on github: https://github.com/tk2232/sac_discrete/blob/master/sac_discrete.py
Here is my idea for how to calculate the loss for discrete actions.
class ValueNet:
    """State-value network: a two-hidden-layer MLP with a single linear output.

    The whole graph is built in ``__init__`` inside the variable scope
    ``name``. Training minimises half the mean squared difference between
    the network output and the fed ``targets`` using Adam (step size 0.001).
    """

    def __init__(self, sess, state_size, hidden_dim, name):
        """Build placeholders, the MLP, the loss and the training op.

        Args:
            sess: session used by ``get_value`` and ``update`` to run the graph.
            state_size: width of one state vector (second placeholder dim).
            hidden_dim: number of units in each of the two ReLU layers.
            name: variable scope; also passed to ``_params`` to pick the
                variables to optimise (presumably those under this scope --
                ``_params`` is defined outside this snippet, confirm).
        """
        self.sess = sess
        with tf.variable_scope(name):
            self.states = tf.placeholder(
                dtype=tf.float32, shape=[None, state_size], name='value_states')
            self.targets = tf.placeholder(
                dtype=tf.float32, shape=[None, 1], name='value_targets')

            # Two identical ReLU hidden layers followed by a linear head.
            hidden = self.states
            for _ in range(2):
                hidden = Dense(units=hidden_dim, activation='relu')(hidden)
            self.values = Dense(units=1, activation=None)(hidden)

            optimizer = tf.train.AdamOptimizer(0.001)
            # Targets are treated as constants in the regression.
            residual = self.values - tf.stop_gradient(self.targets)
            loss = 0.5 * tf.reduce_mean(residual ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def get_value(self, s):
        """Run the network on the batch of states ``s`` and return its output."""
        feed = {self.states: s}
        return self.sess.run(self.values, feed_dict=feed)

    def update(self, s, targets):
        """Run one optimiser step on states ``s`` towards ``targets``."""
        feed = {self.states: s, self.targets: targets}
        self.sess.run(self.train_op, feed_dict=feed)
In the Q_Network I gather the Q-values at the indices of the collected actions:
# Illustrative values only: what the gather step is expected to produce.
# Network output: one row per state, one column per action.
q_out = [[0.5533, 0.4444], [0.2222, 0.6666]]
# Index of the action that was taken in each row.
collected_actions = [0, 1]
# Entry of each q_out row at that row's collected action, kept as a column.
gather = [[0.5533], [0.6666]]
def gather_tensor(params, idx):
    """Select from each row of ``params`` the entry named by ``idx``.

    Row ``i`` is paired with the column index ``idx[i, 0]`` and the pairs are
    looked up with ``tf.gather_nd``.

    Args:
        params: tensor to read from (assumed 2-D ``[batch, n]`` -- confirm
            against callers).
        idx: integer tensor indexed as ``idx[:, 0]``, i.e. at least 2-D with
            the wanted column in its first column.

    Returns:
        The gathered tensor, one lookup per row of ``idx``.
    """
    row_ids = tf.range(tf.shape(idx)[0])
    col_ids = idx[:, 0]
    coords = tf.stack([row_ids, col_ids], axis=-1)
    return tf.gather_nd(params, coords)
class SoftQNetwork:
    """Q-network with one output per discrete action.

    The MLP maps a state to ``action_size`` values; ``self.q`` is the value
    of the action fed through ``self.actions``, reshaped to a column.
    Training minimises half the mean squared difference between ``self.q``
    and the fed ``targets`` using Adam (step size 0.001).
    """

    def __init__(self, sess, state_size, action_size, hidden_dim, name):
        """Build placeholders, the MLP, the gathered Q output and the train op.

        Args:
            sess: session used by ``update`` and ``get_q``.
            state_size: width of one state vector.
            action_size: number of discrete actions (width of the output layer).
            hidden_dim: number of units in each of the two ReLU layers.
            name: variable scope; also passed to ``_params`` to pick the
                variables to optimise (``_params`` is defined outside this
                snippet -- confirm its behaviour).
        """
        self.sess = sess
        with tf.variable_scope(name):
            self.states = tf.placeholder(
                dtype=tf.float32, shape=[None, state_size], name='q_states')
            self.targets = tf.placeholder(
                dtype=tf.float32, shape=[None, 1], name='q_targets')
            self.actions = tf.placeholder(
                dtype=tf.int32, shape=[None, 1], name='q_actions')

            hidden = self.states
            for _ in range(2):
                hidden = Dense(units=hidden_dim, activation='relu')(hidden)
            all_q = Dense(units=action_size, activation=None)(hidden)

            # Keep only the value of the action supplied for each row.
            chosen_q = gather_tensor(all_q, self.actions)
            self.q = tf.reshape(chosen_q, shape=(-1, 1))

            optimizer = tf.train.AdamOptimizer(0.001)
            # Targets are treated as constants in the regression.
            residual = self.q - tf.stop_gradient(self.targets)
            loss = 0.5 * tf.reduce_mean(residual ** 2)
            self.train_op = optimizer.minimize(loss, var_list=_params(name))

    def update(self, s, a, target):
        """Run one optimiser step for states ``s``, actions ``a`` and ``target``."""
        feed = {self.states: s, self.actions: a, self.targets: target}
        self.sess.run(self.train_op, feed_dict=feed)

    def get_q(self, s, a):
        """Return the network's value for each (state, action) pair in ``s``, ``a``."""
        feed = {self.states: s, self.actions: a}
        return self.sess.run(self.q, feed_dict=feed)
class PolicyNet:
    """Categorical policy over discrete actions, built under scope 'policy_net'.

    The MLP produces one logit per action; actions are drawn from a
    ``Categorical`` distribution over those logits.
    """

    def __init__(self, sess, state_size, action_size, hidden_dim):
        """Build placeholders, the MLP, the sampling ops and the train op.

        Args:
            sess: session used by the ``get_*`` and ``update`` methods.
            state_size: width of one state vector.
            action_size: number of discrete actions (number of logits).
            hidden_dim: number of units in each of the two ReLU layers.
        """
        self.sess = sess
        with tf.variable_scope('policy_net'):
            self.states = tf.placeholder(
                dtype=tf.float32, shape=[None, state_size], name='policy_states')
            self.targets = tf.placeholder(
                dtype=tf.float32, shape=[None, 1], name='policy_targets')
            self.actions = tf.placeholder(
                dtype=tf.int32, shape=[None, 1], name='policy_actions')

            hidden = self.states
            for _ in range(2):
                hidden = Dense(units=hidden_dim, activation='relu')(hidden)
            self.logits = Dense(units=action_size, activation=None)(hidden)

            dist = Categorical(logits=self.logits)
            optimizer = tf.train.AdamOptimizer(0.001)

            # Ops for acting: a fresh sample and its log-probability.
            self.new_action = dist.sample()
            self.new_log_prob = dist.log_prob(self.new_action)

            # Ops for training: log-probability of the *fed* actions.
            fed_log_prob = dist.log_prob(tf.squeeze(self.actions))
            # NOTE(review): `targets` comes in through a placeholder, so only
            # the `fed_log_prob` term contributes gradients to the policy
            # variables -- confirm this is the intended policy objective.
            # 0.2 presumably mirrors `alpha = 0.2` in soft_q_update -- confirm.
            loss = tf.reduce_mean(tf.squeeze(self.targets) - 0.2 * fed_log_prob)
            self.train_op = optimizer.minimize(loss, var_list=_params('policy_net'))

    def get_action(self, s):
        """Sample one action for the single state ``s`` (a batch axis is added)."""
        batch_of_one = s[np.newaxis, :]
        sampled = self.sess.run(self.new_action, feed_dict={self.states: batch_of_one})
        return sampled[0]

    def get_next_action(self, s):
        """Sample actions for the batch ``s``; return them and their log-probs as columns."""
        sampled, sampled_log_prob = self.sess.run(
            [self.new_action, self.new_log_prob], feed_dict={self.states: s})
        return sampled.reshape((-1, 1)), sampled_log_prob.reshape((-1, 1))

    def update(self, s, a, target):
        """Run one optimiser step for states ``s``, actions ``a`` and ``target``."""
        feed = {self.states: s, self.actions: a, self.targets: target}
        self.sess.run(self.train_op, feed_dict=feed)
def soft_q_update(batch_size, frame_idx):
    """Sample a batch and compute the Q, V and policy targets.

    Uses the module-level ``replay_buffer``, ``value_net_target``,
    ``policy_net``, ``soft_q_net_1`` and ``soft_q_net_2``. In the code shown
    here the three targets are computed but neither returned nor applied, and
    ``frame_idx`` is not read.

    Args:
        batch_size: number of transitions requested from the replay buffer.
        frame_idx: unused in the visible code.
    """
    gamma = 0.99
    alpha = 0.2

    state, action, reward, next_state, done = replay_buffer.sample(batch_size)
    # Turn the per-transition vectors into columns.
    action, reward, done = (
        column.reshape((-1, 1)) for column in (action, reward, done))

    # Q target: reward plus the discounted target-network value of the next
    # state, masked out where `done` is set.
    next_value = value_net_target.get_value(next_state)
    q_target = reward + (1 - done) * gamma * next_value

    # V target: smaller of the two Q estimates for an action freshly sampled
    # at the *current* states, minus alpha times that action's log-prob.
    sampled_action, sampled_log_prob = policy_net.get_next_action(state)
    sampled_q = np.minimum(
        soft_q_net_1.get_q(state, sampled_action),
        soft_q_net_2.get_q(state, sampled_action))
    v_target = sampled_q - alpha * sampled_log_prob

    # Policy target: smaller of the two Q estimates for the stored actions.
    policy_target = np.minimum(
        soft_q_net_1.get_q(state, action),
        soft_q_net_2.get_q(state, action))
Since the algorithm is generic to both discrete and continuous policy, the key idea is that we need a discrete distribution that is reparametrizable. Then the extension should involve minimal code modification from the continuous SAC --- by just changing the policy distribution class.
There is one such distribution — the GumbelSoftmax distribution. PyTorch does not have this built-in, so I simply extend it from a close cousin which has the right rsample() and add a correct log prob calculation method. With the ability to calculate a reparametrized action and its log prob, SAC works beautifully for discrete actions with minimal extra code, as seen below.
def calc_log_prob_action(self, action_pd, reparam=False):
    '''
    Draw actions from action_pd and return their log probabilities (paper eq. 11).

    @param action_pd: distribution offering sample(), rsample() and log_prob()
    @param reparam: draw with rsample() when True, with sample() otherwise
    @returns (log_probs, actions)
    '''
    draw = action_pd.rsample if reparam else action_pd.sample
    samples = draw()
    if not self.body.is_discrete:
        # continuous: squash with tanh, then rescale
        actions = self.scale_action(torch.tanh(samples))
        raw_log_probs = action_pd.log_prob(samples)
        # paper Appendix C. Enforcing Action Bounds for continuous actions
        bound_correction = torch.log(1 - actions.pow(2) + 1e-6).sum(1)
        return raw_log_probs - bound_correction, actions
    # discrete: the samples are the actions (straightforward using GumbelSoftmax)
    return action_pd.log_prob(samples), samples
# ... for discrete action, GumbelSoftmax distribution
class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
    '''
    A differentiable Categorical distribution using reparametrization trick with Gumbel-Softmax
    Explanation http://amid.fish/assets/gumbel.html
    NOTE: use this in place of PyTorch's RelaxedOneHotCategorical distribution since its log_prob is not working right (returns positive values)
    Papers:
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al, 2017)
    [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017)
    '''

    def sample(self, sample_shape=torch.Size()):
        '''
        Gumbel-max sampling of discrete action indices. Note rsample is inherited from RelaxedOneHotCategorical

        @param sample_shape: extra leading dimensions to draw; the default draws one index per batch row
        @returns integer tensor of shape sample_shape + logits.shape[:-1]
        '''
        # Honour sample_shape by drawing independent noise for every requested sample.
        noise_shape = torch.Size(sample_shape) + self.logits.shape
        u = torch.empty(noise_shape, device=self.logits.device, dtype=self.logits.dtype).uniform_(0, 1)
        # -log(-log(u)) is Gumbel(0, 1) noise; argmax of the perturbed logits
        # is a draw from the categorical distribution.
        noisy_logits = self.logits - torch.log(-torch.log(u))
        return torch.argmax(noisy_logits, dim=-1)

    def log_prob(self, value):
        '''
        Log probability of value under softmax(logits).

        @param value: one-hot or relaxed tensor whose trailing dimensions equal
            logits.shape (extra leading sample dimensions are allowed for
            floating-point input), or a tensor of class indices
        @returns sum over classes of value * log_softmax(logits)
        @raises ValueError: if value cannot be matched against logits
        '''
        logits = self.logits
        batch_rank = logits.dim()
        same_shape = value.shape == logits.shape
        relaxed_with_sample_dims = (
            value.is_floating_point()
            and value.dim() > batch_rank
            and value.shape[-batch_rank:] == logits.shape
        )
        if not (same_shape or relaxed_with_sample_dims):
            # Anything else is interpreted as class indices.
            value = F.one_hot(value.long(), logits.shape[-1]).float()
        # Explicit check instead of `assert`, which disappears under `python -O`.
        if value.shape[-batch_rank:] != logits.shape:
            raise ValueError(
                f'value of shape {tuple(value.shape)} does not match '
                f'logits of shape {tuple(logits.shape)}')
        return torch.sum(value * F.log_softmax(logits, -1), -1)
And here's the LunarLander results. SAC learns to solve it really fast.
The full implementation code is in SLM Lab at https://github.com/kengz/SLM-Lab/blob/master/slm_lab/agent/algorithm/sac.py
The SAC benchmark results on Roboschool (continuous) and LunarLander (discrete) are shown here: https://github.com/kengz/SLM-Lab/pull/399
There is a paper about SAC with discrete action spaces. It says that SAC for discrete action spaces doesn't need re-parametrization tricks like Gumbel softmax. Instead, SAC needs some modifications. Please refer to the paper for more details.
Paper / Author's implementation (without codes for atari) / Reproduction (with codes for atari)
I hope it helps you.
PyTorch 1.8 has RelaxedOneHotCategorical, which supports re-parameterized sampling using Gumbel softmax.
import torch
import torch.nn as nn
from torch.distributions import RelaxedOneHotCategorical
class Policy(nn.Module):
    """MLP policy that returns a relaxed one-hot categorical action distribution.

    Args:
        input_dims: size of one state vector.
        hidden_dims: width of both hidden layers.
        actions: number of discrete actions (width of the output layer).
    """

    def __init__(self, input_dims, hidden_dims, actions):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(input_dims, hidden_dims), nn.SELU(inplace=True),
            nn.Linear(hidden_dims, hidden_dims), nn.SELU(inplace=True),
            # The output width is the `actions` argument; the previous name
            # `out_dims` was never defined and raised NameError.
            nn.Linear(hidden_dims, actions),
        )

    def forward(self, state):
        """Return a ``RelaxedOneHotCategorical`` (temperature 1) for ``state``.

        Args:
            state: tensor whose last dimension is ``input_dims``.

        Returns:
            Distribution whose ``rsample()`` yields differentiable relaxed
            one-hot actions.
        """
        logits = torch.log_softmax(self.mlp(state), dim=-1)
        # Create the temperature next to the logits so the module also works
        # after being moved to another device or dtype.
        temperature = torch.ones(1, device=logits.device, dtype=logits.dtype)
        return RelaxedOneHotCategorical(logits=logits, temperature=temperature)
>>> policy = Policy(4, 16, 2)
>>> a_dist = policy(torch.randn(8, 4))
>>> a_dist.rsample()
tensor([[0.0353, 0.9647],
[0.1348, 0.8652],
[0.1110, 0.8890],
[0.4956, 0.5044],
[0.6941, 0.3059],
[0.6126, 0.3874],
[0.2932, 0.7068],
[0.0498, 0.9502]], grad_fn=<ExpBackward>)
The technical post webpages of this site follow the CC BY-SA 4.0 protocol. If you need to reprint, please indicate the site URL or the original address. For any questions, please contact: yoyou2525@163.com.