简体   繁体   English

更新 Tensorflow 变量的形状 (TF 2.x)

[英]Update shape of a Tensorflow Variable (TF 2.x)

I am using Tensorflow 2.9.1.我正在使用 Tensorflow 2.9.1。

I am creating a neural network which has a certain tf.Variable defined as:我正在创建一个神经网络,它有一个特定的 tf.Variable 定义为:

tf.Variable(lam_r, dtype = DTYPE)

where在哪里

lam_r = tf.reshape(tf.repeat(1.0, 25000), shape = (25000,1))

This Variable is then trained through the training process and I have no problem with it.然后通过培训过程对这个变量进行培训,我对此没有任何问题。 It works fine.它工作正常。 The problem comes when, at a certain point of the training I need to update this variable by adding, let's say, 50 more values at the end of this variable, that is, I need to get the same Variable but with shape (25050,1), being the 25000 first ones the original variable and the 50 final ones the new values that are initialized in the same way as "lam_r".问题来了,在训练的某个时刻,我需要通过在这个变量的末尾添加 50 个以上的值来更新这个变量,也就是说,我需要获得相同的变量但形状为 (25050, 1),第一个是 25000 个原始变量,最后一个是 50 个新值,它们的初始化方式与“lam_r”相同。

I am not being able to do it by using the tf.concat() function.我无法通过使用 tf.concat() function 来做到这一点。 This Variable is stored in model.lambdas[0], and I am trying to do the following:此变量存储在 model.lambdas[0] 中,我正在尝试执行以下操作:

new_lam = tf.Variable(tf.reshape(tf.repeat(1.0, 50), shape = (50,1)), dtype = DTYPE)
model.lamdbas[0] = tf.concat([model.lambdas[0], new_lam], axis = 0)

By doing this I get the following error:通过这样做,我得到以下错误:

AttributeError: Tensor.name is undefined when eager execution is enabled.

I tried even to give a name to both variables but I get the same error.我什至尝试为这两个变量命名,但我得到了同样的错误。 I just need to "concat" several points to the original tf.Variable() in a way the new variable of new shape still remains trainable.我只需要将几个点“连接”到原始 tf.Variable(),新形状的新变量仍然可以训练。

Any help is appreciated.任何帮助表示赞赏。

Thanks.谢谢。

concat returns a Tensor , not a Variable , so you will be adding a non- Variable to your lambdas . concat返回一个Tensor ,而不是一个Variable ,因此您将向您的lambdas添加一个非Variable If this is where you are storing your variables for training etc., this can lead to issues if the process expects these to be Variable type (since these have other properties than a Tensor ... such as a name ).如果这是您存储用于训练等的变量的地方,那么如果过程期望这些变量是Variable类型(因为这些变量具有除Tensor之外的其他属性......例如name ),这可能会导致问题。 Instead you can do this相反,您可以这样做

new_lam = tf.reshape(tf.repeat(1.0, 50), shape = (50,1))
model.lambdas[0] = tf.Variable(tf.concat([model.lambdas[0], new_lam], axis = 0), dtype=DTYPE)

That is, you have to convert the result of concat to a Variable .也就是说,您必须将concat的结果转换为Variable Converting new_lam before the concatenation is pointless.在连接之前转换new_lam是没有意义的。

Depending on DTYPE , you might need to cast new_lam to that dtype before concatenating.根据DTYPE ,您可能需要在连接之前将new_lam转换为该 dtype。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM