繁体   English   中英

在检查点中更改 tensorflow 变量的形状

[英]Changing the shape of a tensorflow variable in a checkpoint

我有一个带有不同变量和权重的预训练 model(检查点,tensorflow v1)。 我不知道所有的变量,但我知道两个我想改变它们的形状:v1 的形状是 [4,768],v2 的形状是 [4]。 我想分别增加到 [5,768] 和 [5] 并再次保存检查点以进行微调。 为了填补缺失的数据,我想取变量的平均值。

这是我的代码:

# The vars I want to change
v1 = tf.get_variable("v1", shape=[4, 768], initializer=utils.classification_initializer())

v2 = tf.get_variable("v2", shape=[4], initializer=tf.zeros_initializer())

checkpoint = {}
saver = tf.train.Saver()

with tf.Session() as sess:

 # Restore checkpoint from source location (path).
 saver.restore(sess, source)

 # Get the vars values
 checkpoint[v1.name] = v1.eval()
 checkpoint[v2.name] = v2.eval()

 new_data = {}
 # Calc v1 average and reshape
 avg = numpy.average(checkpoint[v1.name], axis=0)
 new_data[v1.name] = numpy.vstack((checkpoint[v1.name], avg))

 # Calc v2 average and reshape
 avg = numpy.average(checkpoint[v2.name], axis=0)
 new_data[v2.name] = numpy.append(checkpoint[v2.name], avg)

 # Assign the new data and shape
 sess.run(tf.assign(v1, new_data[v1.name], validate_shape=False))
 sess.run(tf.assign(v2, new_data[v2.name], validate_shape=False))

 # Save the checkpoint to target location (path).
 saver.save(sess, target)

我期待看到类似大小的 model(源检查点大约 1GB),但我得到一个小得多的文件(目标检查点大约 15KB)。 似乎只保存了我更改的变量而不是整个检查点(其他变量、权重等)。

1 - 这是实现我的目标的方法(在检查点重塑和填充 2 个变量)?

2 - 如果是这样,我如何保存整个 model(其他变量、权重等)而不仅仅是加载的变量?

更新

model 最初是在 TPU 机器上(由其他人)训练的。 因此加载元图在 GPU 机器(我的机器)上不起作用。 但是,使用 tf.estimator.tpu.TPUEstimator 我可以训练和预测这个 model。因此 TPUEstimator 有办法加载所有内容,更改变量并保存 model。

model: https://storage.googleapis.com/tapas_models/2020_10_07/tapas_wikisql_sqa_inter_masklm_base_reset.zip

要更改的变量示例:output_weights_agg 是 [4, 768],output_bias_agg 是 [4]。

完整代码示例:

https://colab.research.google.com/drive/1yoyZ-45So5pEIGmZp85ut38lW653KHXL?usp=sharing

在您的代码中,图表中只有两个变量(v1 和 v2),然后保护程序只会从检查点恢复它们。
您可以先从检查点导入图形,然后执行您想执行的操作。

您的示例代码,tensorflow 版本==1.15.0

checkpoint = {}
# import graph from checkoutpoint meta like "/tmp/model.ckpt.meta"
saver = tf.train.import_meta_graph("path of meta")

with tf.Session() as sess:

 # Restore checkpoint from source location (path) like "/tmp/model.ckpt".
 saver.restore(sess, source)

 # print(tf.global_variables())
 # you can get the name of variable from graph through tf.global_variables()
 v1 = [v for v in tf.global_variables() if v.name == "v1:0"][0]
 v2 = [v for v in tf.global_variables() if v.name == "v2:0"][0]

 # Get the vars values
 checkpoint[v1.name] = v1.eval()
 checkpoint[v2.name] = v2.eval()

 new_data = {}
 # Calc v1 average and reshape
 avg = numpy.average(checkpoint[v1.name], axis=0)
 new_data[v1.name] = numpy.vstack((checkpoint[v1.name], avg))

 # Calc v2 average and reshape
 avg = numpy.average(checkpoint[v2.name], axis=0)
 new_data[v2.name] = numpy.append(checkpoint[v2.name], avg)

 # Assign the new data and shape
 sess.run(tf.assign(v1, new_data[v1.name], validate_shape=False))
 sess.run(tf.assign(v2, new_data[v2.name], validate_shape=False))

 # Save the checkpoint to target location (path).
 saver.save(sess, target)

更新

解决方法:从检查点打印所有张量,获取它们的名称和形状,然后使用 tf.get_variable 在图中构建变量。

from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file

# List ALL tensors 
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='', all_tensors=True)

替代方法:

# List ALL tensors
vars_list = tf.train.list_variables(checkpoint_path)
print(vars_list)

PS:当import graph from meta data来自其他平台时,可能需要自己建图,节点之间的关系可以从graph_def中找到

tf.get_default_graph().as_graph_def().node

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM