[英]Changing the shape of a tensorflow variable in a checkpoint
我有一个带有不同变量和权重的预训练 model(检查点,tensorflow v1)。 我不知道所有的变量,但我知道两个我想改变它们的形状:v1 的形状是 [4,768],v2 的形状是 [4]。 我想分别增加到 [5,768] 和 [5] 并再次保存检查点以进行微调。 为了填补缺失的数据,我想取变量的平均值。
这是我的代码:
# The vars I want to change
v1 = tf.get_variable("v1", shape=[4, 768], initializer=utils.classification_initializer())
v2 = tf.get_variable("v2", shape=[4], initializer=tf.zeros_initializer())
checkpoint = {}
saver = tf.train.Saver()
with tf.Session() as sess:
# Restore checkpoint from source location (path).
saver.restore(sess, source)
# Get the vars values
checkpoint[v1.name] = v1.eval()
checkpoint[v2.name] = v2.eval()
new_data = {}
# Calc v1 average and reshape
avg = numpy.average(checkpoint[v1.name], axis=0)
new_data[v1.name] = numpy.vstack((checkpoint[v1.name], avg))
# Calc v2 average and reshape
avg = numpy.average(checkpoint[v2.name], axis=0)
new_data[v2.name] = numpy.append(checkpoint[v2.name], avg)
# Assign the new data and shape
sess.run(tf.assign(v1, new_data[v1.name], validate_shape=False))
sess.run(tf.assign(v2, new_data[v2.name], validate_shape=False))
# Save the checkpoint to target location (path).
saver.save(sess, target)
我期待看到类似大小的 model(源检查点大约 1GB),但我得到一个小得多的文件(目标检查点大约 15KB)。 似乎只保存了我更改的变量而不是整个检查点(其他变量、权重等)。
1 - 这是实现我的目标的方法(在检查点重塑和填充 2 个变量)?
2 - 如果是这样,我如何保存整个 model(其他变量、权重等)而不仅仅是加载的变量?
更新
model 最初是在 TPU 机器上(由其他人)训练的。 因此加载元图在 GPU 机器(我的机器)上不起作用。 但是,使用 tf.estimator.tpu.TPUEstimator 我可以训练和预测这个 model。因此 TPUEstimator 有办法加载所有内容,更改变量并保存 model。
model: https://storage.googleapis.com/tapas_models/2020_10_07/tapas_wikisql_sqa_inter_masklm_base_reset.zip
要更改的变量示例:output_weights_agg 是 [4, 768],output_bias_agg 是 [4]。
完整代码示例:
https://colab.research.google.com/drive/1yoyZ-45So5pEIGmZp85ut38lW653KHXL?usp=sharing
在您的代码中,图表中只有两个变量(v1 和 v2),然后保护程序只会从检查点恢复它们。
您可以先从检查点导入图形,然后执行您想执行的操作。
您的示例代码,tensorflow 版本==1.15.0
checkpoint = {}
# import graph from checkoutpoint meta like "/tmp/model.ckpt.meta"
saver = tf.train.import_meta_graph("path of meta")
with tf.Session() as sess:
# Restore checkpoint from source location (path) like "/tmp/model.ckpt".
saver.restore(sess, source)
# print(tf.global_variables())
# you can get the name of variable from graph through tf.global_variables()
v1 = [v for v in tf.global_variables() if v.name == "v1:0"][0]
v2 = [v for v in tf.global_variables() if v.name == "v2:0"][0]
# Get the vars values
checkpoint[v1.name] = v1.eval()
checkpoint[v2.name] = v2.eval()
new_data = {}
# Calc v1 average and reshape
avg = numpy.average(checkpoint[v1.name], axis=0)
new_data[v1.name] = numpy.vstack((checkpoint[v1.name], avg))
# Calc v2 average and reshape
avg = numpy.average(checkpoint[v2.name], axis=0)
new_data[v2.name] = numpy.append(checkpoint[v2.name], avg)
# Assign the new data and shape
sess.run(tf.assign(v1, new_data[v1.name], validate_shape=False))
sess.run(tf.assign(v2, new_data[v2.name], validate_shape=False))
# Save the checkpoint to target location (path).
saver.save(sess, target)
更新
解决方法:从检查点打印所有张量,获取它们的名称和形状,然后使用 tf.get_variable 在图中构建变量。
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
# List ALL tensors
print_tensors_in_checkpoint_file(file_name=checkpoint_path, tensor_name='', all_tensors=True)
替代方法:
# List ALL tensors
vars_list = tf.train.list_variables(checkpoint_path)
print(vars_list)
PS:当import graph from meta data来自其他平台时,可能需要自己建图,节点之间的关系可以从graph_def中找到
tf.get_default_graph().as_graph_def().node
声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.