繁体   English   中英

更新 Tensorflow 变量的形状 (TF 2.x)

[英]Update shape of a Tensorflow Variable (TF 2.x)

我正在使用 Tensorflow 2.9.1。

我正在创建一个神经网络,它有一个特定的 tf.Variable 定义为:

tf.Variable(lam_r, dtype = DTYPE)

在哪里

lam_r = tf.reshape(tf.repeat(1.0, 25000), shape = (25000,1))

然后通过培训过程对这个变量进行培训,我对此没有任何问题。 它工作正常。 问题来了,在训练的某个时刻,我需要通过在这个变量的末尾添加 50 个以上的值来更新这个变量,也就是说,我需要获得相同的变量但形状为 (25050, 1),第一个是 25000 个原始变量,最后一个是 50 个新值,它们的初始化方式与“lam_r”相同。

我无法通过使用 tf.concat() function 来做到这一点。 此变量存储在 model.lambdas[0] 中,我正在尝试执行以下操作:

new_lam = tf.Variable(tf.reshape(tf.repeat(1.0, 50), shape = (50,1)), dtype = DTYPE)
model.lamdbas[0] = tf.concat([model.lambdas[0], new_lam], axis = 0)

通过这样做,我得到以下错误:

AttributeError: Tensor.name is undefined when eager execution is enabled.

我什至尝试为这两个变量命名,但我得到了同样的错误。 我只需要将几个点“连接”到原始 tf.Variable(),新形状的新变量仍然可以训练。

任何帮助表示赞赏。

谢谢。

concat返回一个Tensor ,而不是一个Variable ,因此您将向您的lambdas添加一个非Variable 如果这是您存储用于训练等的变量的地方,那么如果过程期望这些变量是Variable类型(因为这些变量具有除Tensor之外的其他属性......例如name ),这可能会导致问题。 相反,您可以这样做

new_lam = tf.reshape(tf.repeat(1.0, 50), shape = (50,1))
model.lambdas[0] = tf.Variable(tf.concat([model.lambdas[0], new_lam], axis = 0), dtype=DTYPE)

也就是说,您必须将concat的结果转换为Variable 在连接之前转换new_lam是没有意义的。

根据DTYPE ,您可能需要在连接之前将new_lam转换为该 dtype。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM