Update shape of a Tensorflow Variable (TF 2.x)

I am using Tensorflow 2.9.1.

I am creating a neural network which has a certain tf.Variable defined as:

tf.Variable(lam_r, dtype = DTYPE)

where

lam_r = tf.reshape(tf.repeat(1.0, 25000), shape = (25000,1))

This Variable is then trained during the training process without any problem. The problem comes when, at a certain point in training, I need to extend it by appending, say, 50 more values. That is, I need the same Variable but with shape (25050,1), where the first 25000 entries are the original values and the final 50 are new values initialized in the same way as "lam_r".

I have not been able to do this with the tf.concat() function. The Variable is stored in model.lambdas[0], and I am trying the following:

new_lam = tf.Variable(tf.reshape(tf.repeat(1.0, 50), shape = (50,1)), dtype = DTYPE)
model.lambdas[0] = tf.concat([model.lambdas[0], new_lam], axis = 0)

By doing this I get the following error:

AttributeError: Tensor.name is undefined when eager execution is enabled.

I even tried giving a name to both variables, but I get the same error. I just need to concatenate several points onto the original tf.Variable() so that the new variable, with its new shape, remains trainable.

Any help is appreciated.

Thanks.

concat returns a Tensor, not a Variable, so you would be adding a non-Variable to your lambdas. If this is where you store your variables for training, this can cause problems when the process expects them to be Variables (which have properties a Tensor lacks, such as a name). Instead you can do this:

new_lam = tf.reshape(tf.repeat(1.0, 50), shape = (50,1))
model.lambdas[0] = tf.Variable(tf.concat([model.lambdas[0], new_lam], axis = 0), dtype=DTYPE)

That is, you have to convert the result of concat to a Variable. Converting new_lam to a Variable before the concatenation is pointless.
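To see the type difference for yourself, here is a minimal check on toy shapes. The names v, t and wrapped are just placeholders for illustration and are not part of your model:

```python
import tensorflow as tf

# Small stand-in for model.lambdas[0]; shapes are reduced for illustration.
v = tf.Variable(tf.ones((3, 1)))

# Concatenating a Variable with a Tensor yields a plain Tensor.
t = tf.concat([v, tf.ones((2, 1))], axis=0)
print(isinstance(t, tf.Variable))        # False

# Wrapping the result gives a Variable again, now with the larger shape.
wrapped = tf.Variable(t)
print(isinstance(wrapped, tf.Variable))  # True
print(wrapped.shape)
```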

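Depending on DTYPE, you might need to cast new_lam to that dtype before concatenating: tf.repeat(1.0, 50) produces a float32 tensor, and concat requires all of its inputs to share one dtype.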
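If you need this more than once, the steps could be wrapped in a small helper. This is only a sketch: grow_variable is a made-up name, and DTYPE is assumed to be tf.float32 here because the question does not say what it is. The helper creates the new rows directly in the existing variable's dtype, so no separate cast is needed.

```python
import tensorflow as tf

DTYPE = tf.float32  # assumption: the question does not state the value of DTYPE


def grow_variable(var, n_new, fill_value=1.0):
    """Return a new trainable tf.Variable: var's values followed by n_new rows.

    Hypothetical helper for illustration; assumes var has shape (N, 1).
    """
    # Build the new rows in var's dtype so tf.concat sees matching dtypes.
    new_rows = tf.fill([n_new, 1], tf.cast(fill_value, var.dtype))
    # tf.concat returns a Tensor; wrap it so the result is a Variable again.
    return tf.Variable(tf.concat([var, new_rows], axis=0), trainable=True)


lam = tf.Variable(tf.ones((25000, 1), dtype=DTYPE))
lam.assign(lam * 2.0)  # stand-in for values learned during training

grown = grow_variable(lam, 50)
print(grown.shape, grown.dtype, grown.trainable)
```

Keep in mind that the result is a different Variable object from the original. Whatever computes and applies gradients has to be pointed at the new one, and any per-variable optimizer state for the old one (for example Adam's moment estimates) does not carry over.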
