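I am trying to replicate a (much smaller) version of the AlphaGo Zero system. However, I am having a problem with the network model. The loss function I am supposed to implement is the following:
$$ l = (z - v)^2 - \pi^{\top} \log p + c \lVert \theta \rVert^2 $$
Where z is the game outcome (who won), v is the value predicted by the network, π is the move probability distribution produced by MCTS, p is the move probability distribution predicted by the network, c is the L2 regularization constant, and θ are the network parameters.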
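To make the terms concrete, here is a minimal plain-Python sketch of the AlphaGo Zero loss for a single sample: the squared value error, the policy cross-entropy, and the L2 penalty. The function name `alphazero_loss`, the default constant `c`, the `eps` guard, and the sample numbers are all hypothetical and chosen only for illustration.

```python
import math


def alphazero_loss(z, v, pi, p, weights, c=1e-4, eps=1e-12):
    """Loss for one sample: (z - v)^2 - sum(pi * log p) + c * sum(w^2).

    z: game outcome label, v: predicted value,
    pi: target move distribution, p: predicted move distribution,
    weights: flat list of network parameters (illustrative only).
    """
    value_term = (z - v) ** 2
    # Cross-entropy between the target policy pi and the prediction p;
    # eps (an assumption) guards against log(0).
    policy_term = -sum(t * math.log(q + eps) for t, q in zip(pi, p))
    l2_term = c * sum(w * w for w in weights)
    return value_term + policy_term + l2_term


if __name__ == "__main__":
    # Hypothetical two-move example.
    print(alphazero_loss(z=1.0, v=0.5, pi=[1.0, 0.0], p=[0.5, 0.5],
                         weights=[1.0, 2.0], c=0.1))
```

The two label quantities (z and π) and the two predictions (v and p) each appear in exactly one term, which is why the loss can be split per output head.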
I pass the network a list of channels (representing the game state) and an array (the same size as π and p) indicating which actions are valid (1 if valid, 0 otherwise).
As you can see, the loss function uses both the targets and the network predictions. But after an extensive search, it seems that a custom loss function can only take y_true and y_pred as parameters, even though I have two "y_true"s and two "y_pred"s. I have tried indexing to get those values, but I'm pretty sure it is not working.
The network model and the custom loss function are in the code below:
from keras import backend as K
from keras import regularizers
from keras.layers import Conv2D, Dense, Flatten, Input, MaxPooling2D, concatenate
from keras.models import Model


def custom_loss(y_true, y_pred):
    # I am pretty sure this does not work
    output_prob_dist = y_pred[0]
    output_value = y_pred[1]
    label_prob_dist = y_true[0]
    label_value = y_true[1]
    mse_loss = K.mean(K.square(label_value - output_value), axis=-1)
    cross_entropy_loss = K.dot(K.transpose(label_prob_dist), output_prob_dist)
    return mse_loss - cross_entropy_loss
def define_model():
    """Neural Network model implementation using Keras + Tensorflow."""
    state_channels = Input(shape=(5, 5, 6), name='States_Channels_Input')
    valid_actions_dist = Input(shape=(32,), name='Valid_Actions_Input')
    conv = Conv2D(filters=10, kernel_size=2, kernel_regularizer=regularizers.l2(0.0001), activation='relu', name='Conv_Layer')(state_channels)
    pool = MaxPooling2D(pool_size=(2, 2), name='Pooling_Layer')(conv)
    flat = Flatten(name='Flatten_Layer')(pool)
    # Merge of the flattened channels (after pooling) and the valid action
    # distribution. Used only as input in the probability distribution head.
    merge = concatenate([flat, valid_actions_dist])
    # Probability distribution over actions
    hidden_fc_prob_dist_1 = Dense(100, kernel_regularizer=regularizers.l2(0.0001), activation='relu', name='FC_Prob_1')(merge)
    hidden_fc_prob_dist_2 = Dense(100, kernel_regularizer=regularizers.l2(0.0001), activation='relu', name='FC_Prob_2')(hidden_fc_prob_dist_1)
    output_prob_dist = Dense(32, kernel_regularizer=regularizers.l2(0.0001), activation='softmax', name='Output_Dist')(hidden_fc_prob_dist_2)
    # Value of a state
    hidden_fc_value_1 = Dense(100, kernel_regularizer=regularizers.l2(0.0001), activation='relu', name='FC_Value_1')(flat)
    hidden_fc_value_2 = Dense(100, kernel_regularizer=regularizers.l2(0.0001), activation='relu', name='FC_Value_2')(hidden_fc_value_1)
    output_value = Dense(1, kernel_regularizer=regularizers.l2(0.0001), activation='tanh', name='Output_Value')(hidden_fc_value_2)
    model = Model(inputs=[state_channels, valid_actions_dist], outputs=[output_prob_dist, output_value])
    model.compile(loss=custom_loss, optimizer='adam', metrics=['accuracy'])
    return model
# In the main method
model = define_model()
# ...
# MCTS routine to collect the data for the network input
# ...
x_train = [channels_input, valid_actions_dist_input]
y_train = [dist_probs_label, who_won_label]
model.fit(x_train, y_train, epochs=10)
In short, my question is: how do I correctly implement this custom loss function that uses both the network outputs and label values of the network?
I checked their Git repository and there is a lot going on. As shown in the equation, the final loss is the combination of three different losses, and the three networks are minimizing this final loss. Their loss code is below:
# train ops
policy_cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits, labels=tf.stop_gradient(labels['pi_tensor'])))

value_cost = params['value_cost_weight'] * tf.reduce_mean(
    tf.square(value_output - labels['value_tensor']))

reg_vars = [v for v in tf.trainable_variables()
            if 'bias' not in v.name and 'beta' not in v.name]
l2_cost = params['l2_strength'] * \
    tf.add_n([tf.nn.l2_loss(v) for v in reg_vars])

combined_cost = policy_cost + value_cost + l2_cost
You can refer to this and adapt your code accordingly.
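Carrying the same idea over to Keras: a loss function passed to `compile` is called once per output, so inside it `y_true` and `y_pred` belong to a single head, and indexing them with `[0]` or `[1]` selects samples of the batch rather than heads. Because the combined loss is a sum of a policy term and a value term, one option is to give each output its own built-in loss and let Keras add them up. The sketch below assumes `tensorflow.keras` is available; `build_model`, `l2_strength`, the equal loss weights, and the random dummy data are hypothetical and only loosely follow the model in the question.

```python
import numpy as np
from tensorflow import keras
from tensorflow.keras import layers, regularizers


def build_model(l2_strength=1e-4):
    """Two-headed network with one loss per output (illustrative sketch)."""
    # kernel_regularizer adds the L2 penalty to the total loss automatically.
    reg = regularizers.l2(l2_strength)
    states = keras.Input(shape=(5, 5, 6), name="States_Channels_Input")
    valid = keras.Input(shape=(32,), name="Valid_Actions_Input")

    x = layers.Conv2D(10, 2, activation="relu", kernel_regularizer=reg)(states)
    x = layers.MaxPooling2D(pool_size=(2, 2))(x)
    flat = layers.Flatten()(x)

    # Policy head: softmax output, so categorical cross-entropy on
    # probabilities gives -pi^T log p.
    p = layers.concatenate([flat, valid])
    p = layers.Dense(100, activation="relu", kernel_regularizer=reg)(p)
    p = layers.Dense(32, activation="softmax", kernel_regularizer=reg,
                     name="Output_Dist")(p)

    # Value head: tanh output, trained with squared error (z - v)^2.
    v = layers.Dense(100, activation="relu", kernel_regularizer=reg)(flat)
    v = layers.Dense(1, activation="tanh", kernel_regularizer=reg,
                     name="Output_Value")(v)

    model = keras.Model(inputs=[states, valid], outputs=[p, v])
    model.compile(
        optimizer="adam",
        # Losses are listed in the same order as the model outputs.
        loss=["categorical_crossentropy", "mean_squared_error"],
        loss_weights=[1.0, 1.0],  # assumed equal weighting
    )
    return model


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    n = 8  # dummy batch, for demonstration only
    x_train = [rng.random((n, 5, 5, 6)).astype("float32"),
               np.ones((n, 32), dtype="float32")]
    y_train = [rng.dirichlet(np.ones(32), size=n).astype("float32"),
               rng.choice([-1.0, 1.0], size=(n, 1)).astype("float32")]
    build_model().fit(x_train, y_train, epochs=1, verbose=0)
```

With this setup `y_train` stays a two-element list, as in the question, and no indexing inside a custom loss is needed. The TensorFlow snippet above feeds logits to its cross-entropy, whereas here the policy head already applies softmax, so the Keras loss is used with its default of expecting probabilities.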