I need to load weights that were saved in TensorFlow 1.x into an identical model in TensorFlow 2.x

I have an old model written in TensorFlow 1.x. The code includes a lot of things I don't need; all I need is the model itself. I re-created the model in a way that I'm almost certain is identical to the original (I checked a number of things).

I have the .data, .index and .meta files. I tried many different things, and the typical result is a message saying that "a few things weren't saved", followed by a list of all the weights (not in full, because large weights are abbreviated with three dots (...)).

I would LOVE to have someone tell me how I can use that in my new model

I tried:

model.load_weights

I tried:

tf.compat.v1.disable_eager_execution()

sess = tf.compat.v1.Session()

saver = tf.compat.v1.train.import_meta_graph('checkpoints/pix2pix-60.meta')

saver.restore(sess, "checkpoints/pix2pix-60")

I tried:

tf.compat.v1.disable_eager_execution()

sess = tf.compat.v1.Session()

saver = tf.compat.v1.train.Checkpoint(model=gen)

saver.restore(tf.train.latest_checkpoint('checkpoints')).assert_consumed()

I tried:

ck_path = tf.train.latest_checkpoint('checkpoints')

gen.load_weights(ck_path)

I tried:

from tensorflow.python.training import checkpoint_utils as cp

ckpt = cp.load_checkpoint('checkpoints/pix2pix--60')

and then tried to see what I could do with the result.

I also tried a number of other things.

I honestly wouldn't mind if someone could just tell me how to read the .index or .data files, so that I can copy the weights myself and deal with it from there.

I would again really love some help,

Thanks!

It seems that your TF 1.x model was saved in the checkpoint (ckpt) format. To restore a checkpoint, you need to have the graph before you load the weights.

To convert it to a TF 2.x model, you can instantiate the original model and then save it in the recommended SavedModel format using the 2.x API.

You can continue from your second attempt: use compat.v1 to create a Session, load the graph from the .meta file, and then restore the weights. After that, your Session contains both the graph and the loaded weights.

To convert it to a 2.x model, you need to get the input and output tensors from the graph:

# you have already loaded the graph and the weights into sess
g = sess.graph
# assuming that your input and output tensor names are "input:0" and "output:0"
input_tensor = g.get_tensor_by_name("input:0")
output_tensor = g.get_tensor_by_name("output:0")

# then use the TF 2.x API to save the model in the SavedModel format
# (note: tf.keras.Model normally expects tensors created by tf.keras.Input,
# so wrapping raw graph tensors like this may raise an error in some versions)
model = tf.keras.Model(input_tensor, output_tensor, name="tf2_model")
model.save("your_saved_dir")

A model in the SavedModel format stores both the graph and the weights, so you can simply use

model = tf.saved_model.load("your_saved_dir")

to load the model for use.
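
If wrapping the graph tensors in tf.keras.Model raises an error, one alternative is to export the restored session directly with the compat.v1 SavedModel API. This is only a minimal sketch, not something tested against this particular model: the tensor names "input:0" and "output:0" are the same assumptions as above, the helper names placeholder_names and export_checkpoint are made up for illustration, and the example paths are the ones from the question.

```python
def placeholder_names(operations):
    """Return the output-tensor names of all Placeholder ops (likely graph inputs)."""
    return [op.name + ":0" for op in operations if op.type == "Placeholder"]


def export_checkpoint(meta_path, ckpt_prefix, export_dir,
                      input_name="input:0", output_name="output:0"):
    """Restore a TF1 checkpoint into a fresh graph and export it as a SavedModel.

    input_name and output_name are assumptions; replace them with the real
    tensor names of your graph. export_dir must not exist yet.
    """
    # Imported here so that placeholder_names() can be used without TensorFlow.
    import tensorflow as tf

    graph = tf.Graph()
    with graph.as_default(), tf.compat.v1.Session(graph=graph) as sess:
        # Rebuild the graph from the .meta file, then load the variable values.
        saver = tf.compat.v1.train.import_meta_graph(meta_path)
        saver.restore(sess, ckpt_prefix)

        # Print candidate inputs to help find the real tensor names.
        print("placeholders:", placeholder_names(graph.get_operations()))

        tf.compat.v1.saved_model.simple_save(
            sess,
            export_dir,
            inputs={"input": graph.get_tensor_by_name(input_name)},
            outputs={"output": graph.get_tensor_by_name(output_name)},
        )


# Example call, using the paths from the question:
# export_checkpoint("checkpoints/pix2pix-60.meta",
#                   "checkpoints/pix2pix-60",
#                   "your_saved_dir")
```

A model exported this way can be loaded in TF 2.x with tf.saved_model.load("your_saved_dir") and called through its "serving_default" signature.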

OK, so I think I figured it out, although it was quite tedious.

In the TensorFlow 1.x model, all variables were created under tf.name_scope, and my TensorFlow 2.x model does not produce the same scoped names, so the variable names did not match. I pretty much had to change the names manually so that they would fit, and after that the weights really did load, like this:

checkpoint = tf.train.Checkpoint(model=gen)
checkpoint.restore('checkpoints/pix2pix--60').assert_consumed()

this also seemed to work:

gen.load_weights('checkpoints/pix2pix--60')
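
Renaming by hand is easier with a list of what the checkpoint actually contains. The sketch below reads the variable names and shapes from the checkpoint with tf.train.list_variables and compares them with the variables of the new model. It assumes a Keras model whose variables expose a .name such as "conv2d/kernel:0"; the helper names are made up for illustration.

```python
def compare_names(ckpt_shapes, model_shapes):
    """Compare two {variable_name: shape} dicts.

    Returns (only_in_ckpt, only_in_model, shape_mismatch), each sorted by name.
    """
    ckpt_names, model_names = set(ckpt_shapes), set(model_shapes)
    only_in_ckpt = sorted(ckpt_names - model_names)
    only_in_model = sorted(model_names - ckpt_names)
    shape_mismatch = sorted(
        name for name in ckpt_names & model_names
        if list(ckpt_shapes[name]) != list(model_shapes[name])
    )
    return only_in_ckpt, only_in_model, shape_mismatch


def checkpoint_shapes(ckpt_prefix):
    """Read {name: shape} from a checkpoint prefix such as 'checkpoints/pix2pix--60'."""
    # Imported lazily; compare_names() itself is plain Python.
    import tensorflow as tf
    return {name: list(shape) for name, shape in tf.train.list_variables(ckpt_prefix)}


def keras_model_shapes(model):
    """Read {name: shape} from a model's variables, dropping the ':0' suffix."""
    return {v.name.split(":")[0]: list(v.shape) for v in model.variables}


# Example usage (the path and `gen` are the ones from the question):
# only_ckpt, only_model, mismatched = compare_names(
#     checkpoint_shapes("checkpoints/pix2pix--60"), keras_model_shapes(gen))
```

The values themselves can be read with tf.train.load_checkpoint(prefix).get_tensor(name), which returns a NumPy array, so the weights can also be copied across manually if the names cannot be made to match.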

However, something is still not working correctly, since the output is not what I expect (that is, it does not match the output of the TensorFlow 1.x model).

It may have something to do with the batch_normalization weights that aren't being loaded, but I checked, and in my current TF 2.x model they are non-trainable and exactly equal to the weights that aren't being loaded.

Another weird thing is that gen.predict(x) gives me a different result each time, so I guess the weights aren't being frozen or something...

I have yet to understand what went wrong previously, but I do know that the API changed a lot between TF1 and TF2, including default parameter values. What I eventually did, which worked perfectly, was this:

tf_upgrade_v2 \
  --intree my_project/ \
  --outtree my_project_v2/ \
  --reportfile report.txt

as explained here

You just put all the code you want to convert in the folder my_project, and it creates a folder named my_project_v2 containing the TF1 code converted to TF2.
