簡體   English   中英

如何在一個張量流概率層中混合多個分布?

[英]How to mix many distributions in one tensorflow probability layer?

我有幾個DistributionLambda層作為一個模型的輸出,我想將類似 Concatenate 的操作放入一個新層,以便只有一個輸出是所有分布的混合,假設它們是獨立的。 然后,我可以對模型的輸出應用對數似然損失。 否則,我無法將損失應用於Concatenate層,因為它丟失了log_prob方法。 我一直在嘗試Blockwise發行版,但到目前為止沒有運氣。

這是一個示例代碼:

from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow_probability import distributions
from tensorflow_probability import layers as tfp_layers


def likelihood_loss(y_true, y_pred):
    """Adding negative log likelihood loss."""
    return -y_pred.log_prob(y_true)


def distribution_fn(params):
    """Distribution function."""
    return distributions.Normal(
        params[:, 0], math.log(1.0 + math.exp(params[:, 1])))


output_steps = 3
...
lstm_layer = layers.LSTM(10, return_state=True)
last_layer, l_h, l_c = lstm_layer(last_layer)
lstm_state = [l_h, l_c]
dense_layer = layers.Dense(2)
last_layer = dense_layer(last_layer)
last_layer = tfp_layers.DistributionLambda(
    make_distribution_fn=distribution_fn)(last_layer)
output_layers = [last_layer]
# Get output sequence, re-injecting the output of each step
for number in range(1, output_steps):
    last_layer = layers.Reshape((1, 1))(last_layer)
    last_layer, l_h, l_c = lstm_layer(last_layer, initial_state=lstm_states)
    # Storing state for next time step
    lstm_states = [l_h, l_c]
    last_layer = tfp_layers.DistributionLambda(
        make_distribution_fn=distribution_fn)(dense_layer(last_layer))
    output_layers.append(last_layer)

# This does not work
# last_layer = distributions.Blockwise(output_layers)

# This works for the model but cannot compute loss
# last_layer = layers.Concatenate(axis=1)(output_layers)
the_model = models.Model(inputs=[input_layer], outputs=[last_layer])
the_model.compile(loss=likelihood_loss, optimizer=optimizers.Adam(lr=0.001))

問題是你的輸入,而不是你的輸出層;)

Input:0 在您的錯誤消息中被引用。 您能否嘗試更具體地說明您的意見?

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM