
Lambda Layer for element-wise multiplication gives NaNs after the first model update (Keras)

My model is written in Keras. It has multiple inputs, one of which is to be multiplied with the outputs of the penultimate Dense layer before the logits are fed into the softmax. This element-wise multiplication is carried out by means of a Lambda layer.

def mul(x, mask):
    output = x*mask
    return output

logits = Lambda(lambda x: mul(x, input_2))(dense_output) # gives nan after first update
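One likely culprit here (a sketch under assumed layer sizes, not a confirmed diagnosis) is that `input_2` is captured by closure inside the Lambda rather than passed to it as a layer input, so Keras does not track it as part of the graph. Passing both tensors to the Lambda as a list avoids the free variable:

```python
# Minimal sketch (shapes and num_actions are assumptions): pass both
# tensors to the Lambda layer as a list instead of closing over
# input_2, so Keras tracks the mask as a proper graph input.
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Lambda
from tensorflow.keras.models import Model

num_actions = 10  # assumed size of the output/mask vectors

input_1 = Input(shape=(4,))
input_2 = Input(shape=(num_actions,))  # the mask input
dense_output = Dense(num_actions)(input_1)

# Both tensors arrive in the lambda as a list; no free variables.
logits = Lambda(lambda t: t[0] * t[1])([dense_output, input_2])

model = Model(inputs=[input_1, input_2], outputs=logits)
out = model.predict([np.ones((2, 4)), np.ones((2, num_actions))])
```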

After the model is updated for the first time, the Lambda layer starts giving NaN as output.
This does not happen if a constant vector defined within the computational graph is multiplied with the output of the Dense layer:

logits = dense_output * [1, 1, 1, 1, 1, -100, 1, -100, 1, 1] # does not give nan

I have tried using the Multiply layer provided by Keras as well, but this too gives NaN after the first update. Here is a snippet:

logits = Multiply()([dense_output, input_2]) # gives nan after first update

I basically want to mask certain output states by means of this multiplication with the input, but can't do so if the layer keeps giving NaN as output.
Is there any way to solve this? Any and all help will be appreciated!
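As an aside, masking logits before a softmax is commonly done additively rather than multiplicatively: adding a large negative value to the masked positions drives their softmax probability to effectively zero, and is numerically better behaved than multiplying logits. A minimal NumPy sketch (function and variable names are illustrative, not from the question):

```python
import numpy as np

def masked_softmax(logits, mask):
    """Softmax over logits where mask == 0 positions are suppressed.

    Adds a large negative value to masked positions so their
    probability is effectively zero after normalisation.
    """
    masked_logits = logits + (1.0 - mask) * -1e9
    # Subtract the row max for numerical stability before exponentiating.
    exps = np.exp(masked_logits - masked_logits.max(axis=-1, keepdims=True))
    return exps / exps.sum(axis=-1, keepdims=True)

logits = np.array([[2.0, 1.0, 0.5, 3.0]])
mask = np.array([[1.0, 0.0, 1.0, 1.0]])  # suppress the second action
probs = masked_softmax(logits, mask)
```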

I am not sure how this works, but I have found a solution. It seems that adding a Dense layer in front of the Multiply layer solves the problem. It doesn't matter whether the Dense layer is trainable or not. Here is the code:

logits = Multiply()([dense_output, input_2])
initializer = tf.keras.initializers.Identity()
masked_actions = Dense(num_actions, use_bias=False,
                       kernel_initializer=initializer,
                       trainable=False)(logits)  # returns the same logits

The model updates now work as expected without producing any NaNs.
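For reference, an identity-initialised square Dense layer with no bias really is a pass-through, since it computes `x @ I = x`. A quick standalone check (the size is an illustrative assumption):

```python
# Sketch verifying that the identity-initialised Dense layer passes
# its input through unchanged (num_actions is an assumed example size).
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Dense

num_actions = 5
initializer = tf.keras.initializers.Identity()
identity_dense = Dense(num_actions, use_bias=False,
                       kernel_initializer=initializer,
                       trainable=False)

x = tf.constant(np.random.rand(3, num_actions), dtype=tf.float32)
y = identity_dense(x)  # kernel is the identity matrix, so y == x
```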
