繁体   English   中英

tf.GradientTape 与外积返回无

[英]tf.GradientTape with outer product returns None

我试图在计算损失函数之前对模型的预测进行后处理,因为我的真实数据 (y_train) 是 NN 输出的外积。 我已按照以下步骤操作:

  1. 我知道我尝试使用 numpy 执行的操作是:
nX = 201
nT = 101
nNNout = nX+nT
nBatch = 32

NNout = np.random.rand(nBatch, nNNout)

f = NNout[:, :nX]
g = NNout[:,nX:]

test = np.empty([nBatch, nX*nT])

for i in range(nBatch):
    test[i,:] = np.outer(f[i,:], g[i,:]).flatten('F')

其中 NN 输出包含 f 和 g。 我真正需要的是每个批处理实例的 f 和 g 外积的矢量化版本。

  1. 我在一个紧凑的 tensorflow 操作中将其翻译为:
test2 = tf.Variable([tf.reshape(tf.transpose(tf.tensordot(f[i,:],g[i,:], axes=0)),[nX*nT]) for i in range(nBatch)])

我已经检查过它是正确的,并且输出的值与步骤 1 中的值相同。

  1. 然后,我只是想在我的模型预测之后添加这个操作:
    n_epochs = 20
    batch_size = 32
    n_steps = len(x_train) // batch_size
    optimizer = keras.optimizers.Nadam(learning_rate=0.01)
    loss_fn = keras.losses.mean_squared_error
    mean_loss = keras.metrics.Mean()
    metrics = [keras.metrics.MeanAbsoluteError()]

    # ------------ Training ------------
    for epoch in range(1, n_epochs + 1):
        print("Epoch {}/{}".format(epoch, n_epochs))
        for step in range(1, n_steps + 1):
            X_batch, y_batch = random_batch(x_train, np.array(y_train))
            with tf.GradientTape() as tape:
                y_pred = model(X_batch, training=True)
                u_pred = tf.Variable([tf.reshape(tf.transpose(tf.tensordot(y_pred[i, :nX], y_pred[i, nX:], axes=0)), [nX * nT]) for i in
                             range(batch_size)])
                main_loss = tf.reduce_mean(loss_fn(y_batch, u_pred))
                loss = tf.add_n([main_loss] + model.losses)
            gradients = tape.gradient(loss, model.trainable_variables) 

我的主要问题是,当我添加操作时,梯度变成了一个 None 列表。 如果我简单地使用模型的预测 (y_pred) 计算损失函数,则代码能够计算梯度。

你能帮我找出我在这里犯的错误吗?

您正在 u_pred 中创建一个新的(可训练的)变量,从而打破 u_pred 对 y_pred 的任何依赖。 为什么值相匹配的原因是因为你与预测初始化你的新的变量,但对对方没有功能依赖关系了,有没有流动梯度。

我猜你这样做是因为你需要一个 tf.Tensor 而不是一个列表,你最终遇到了类型错误。 您可能希望在tf.concatenate行中使用某些内容,而不是tf.Variable

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM