
Using tf.data.Dataset makes saved model bigger

I recently ran into an issue with my saved model becoming much bigger. I am using TensorFlow 1.4.

Before, I used

tf.train.string_input_producer() and tf.train.batch()

to load images from a text file. And in the training,

tf.train.start_queue_runners() and tf.train.Coordinator()

were used to provide data to the network. In this case, every time I saved the model using

saver.save(sess, checkpoint_path, global_step=iters)

it only gave me a small file, i.e. a file named model.ckpt-1000.data-00000-of-00001 of about 1.6 MB.

Now, I use

tf.data.Dataset.from_tensor_slices()

to supply images to an input placeholder, and the saved model has become 290 MB. I don't know why. I suspect the TensorFlow saver saves the dataset into the model as well. If so, how do I remove it to make the checkpoint smaller, so that only the weights of the network are saved?

This is not network-dependent, because I tried it with two networks and both behaved like that.

I have googled but unfortunately didn't find anything related to this issue. (Or maybe this is not an issue at all and I just don't know what to do?)

Thank you very much for any ideas and help!

Edit

The way I initialised the dataset is:

1. First, I generated the numpy.array dataset:

self.train_hr, self.train_lr = cifar10.load_dataset(sess)

The initial dataset is a numpy.array, for example of shape [8000,32,32,3]. I passed sess into this function because, inside it, I call tf.image.resize_images() and use sess.run() to generate the numpy.array. The returned self.train_hr and self.train_lr are numpy.arrays of shape [8000,64,64,3].

2. Then I created the dataset:

self.img_hr = tf.placeholder(tf.float32)
self.img_lr = tf.placeholder(tf.float32)
# Build the dataset from placeholders, so the data is fed at
# initialization time rather than baked into the graph.
dataset = tf.data.Dataset.from_tensor_slices((self.img_hr, self.img_lr))
dataset = dataset.repeat(conf.num_epoch).shuffle(buffer_size=conf.shuffle_size).batch(conf.batch_size)
self.iterator = dataset.make_initializable_iterator()
self.next_batch = self.iterator.get_next()

3. Then I initialised the network and the dataset, did the training, and saved the model:

self.labels = tf.placeholder(tf.float32,
                             shape=[conf.batch_size, conf.hr_size, conf.hr_size, conf.img_channel])
self.inputs = tf.placeholder(tf.float32,
                             shape=[conf.batch_size, conf.lr_size, conf.lr_size, conf.img_channel])
self.net = Net(self.labels, self.inputs, mask_type=conf.mask_type,
               is_linear_only=conf.linear_mapping_only, scope='sr_spc')

sess.run(self.iterator.initializer,
         feed_dict={self.img_hr: self.train_hr, self.img_lr: self.train_lr})
while True:
    hr_img, lr_img = sess.run(self.next_batch)
    _, loss, summary_str = sess.run([train_op, self.net.loss, summary_op],
                                    feed_dict={self.labels: hr_img, self.inputs: lr_img})
    ...
    ...
    checkpoint_path = os.path.join(conf.model_dir, 'model.ckpt')
    saver.save(sess, checkpoint_path, global_step=iters)

All the sess references are the same session instance.

I suspect you created a TensorFlow constant (tf.constant) out of your dataset, which would explain why the dataset gets stored with the graph. There is an initializable dataset which lets you feed in the data using feed_dict at runtime. It takes a few extra lines of code to configure, but it's probably what you want to use.

https://www.tensorflow.org/programmers_guide/datasets
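
Here is a minimal sketch of the difference (the array shape is just illustrative). Building the dataset directly from a numpy array bakes the data into the graph as a constant, while building it from a placeholder and feeding the data at initialization time keeps the graph small:

import numpy as np
import tensorflow as tf

data = np.zeros([8000, 64, 64, 3], dtype=np.float32)  # ~393 MB of float32

# Variant A: the numpy array is converted to a tf.constant inside the graph.
dataset_a = tf.data.Dataset.from_tensor_slices(data)
print(tf.get_default_graph().as_graph_def().ByteSize())  # hundreds of MB

# Variant B: only a placeholder lives in the graph; the actual data is
# supplied to the iterator's initializer at runtime via feed_dict.
tf.reset_default_graph()
ph = tf.placeholder(tf.float32, shape=[None, 64, 64, 3])
dataset_b = tf.data.Dataset.from_tensor_slices(ph)
iterator = dataset_b.make_initializable_iterator()
print(tf.get_default_graph().as_graph_def().ByteSize())  # a few KB

with tf.Session() as sess:
    sess.run(iterator.initializer, feed_dict={ph: data})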

Note that constants get created for you automatically in the Python wrapper. The following statements are equivalent:

tf.Variable(42)
tf.Variable(tf.constant(42))

TensorFlow indeed saves your dataset. To solve this, let's understand why.

How does TensorFlow work, and what does it save?

In short, the TensorFlow API lets you build a computation graph via code and then optimize it. Every op/variable/constant you define in the graph works on tensors and is part of that graph. This framework is convenient because TensorFlow just builds a graph, and then the framework decides (or you specify) where to compute the graph in order to get maximum speed out of your hardware, for instance by computing on your GPU.

The GPU is a good example for your issue. Sending data from HDD/RAM/CPU to the GPU is expensive time-wise. Therefore, TensorFlow also allows you to create input producers that pretty much automatically manage the data transferred between all peripheral units, by queuing it and managing threads. However, I haven't seen much gain from that approach. Note that the inputs produced by datasets are also tensors, specifically constants/variables that are used as input to the network. Therefore, they are part of the graph.

When saving a graph, we save several things:

  1. Metadata - which defines the graph and its structure.
  2. Values - of each variable/constant in the graph, in order to load it and reuse the network.

When you use datasets, the values of the non-trainable variables are saved too, and therefore your checkpoint file is larger.
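
You can check what a checkpoint actually contains by listing the saved tensors. A small sketch (the checkpoint path is just an example; point it at your own files):

import tensorflow as tf

# Print every saved tensor's name and shape; anything that is not a
# weight or bias of the network is being saved unnecessarily.
reader = tf.train.NewCheckpointReader('checkpoint/model.ckpt-1000')
for name, shape in reader.get_variable_to_shape_map().items():
    print(name, shape)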

To better understand datasets, see their implementation in the package files.

TL;DR - How do I fix my problem?

  1. If it does not hurt performance, use a feed dictionary to feed placeholders. Do not use tensors to store your data. This way those variables will not be saved.

  2. Save only the tensors that you would like to load (weights, biases, etc.). You can use the .eval() method to find their values, save them as JSON or similar, and load them later by reconstructing the graph. A sketch of this idea follows below.
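
For point 2, a minimal sketch of dumping trainable variables to JSON (the toy variable and the file name weights.json are only for illustration):

import json
import tensorflow as tf

# A toy variable standing in for a trained weight tensor.
w = tf.Variable(tf.random_normal([784, 200]), name="w")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Evaluate each trainable variable to a numpy array; .tolist()
    # makes it JSON-serializable.
    weights = {v.name: v.eval(session=sess).tolist()
               for v in tf.trainable_variables()}
    with open('weights.json', 'w') as f:
        json.dump(weights, f)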

Good luck!

I solved this issue (not perfectly, as I still don't know where exactly the problem happens). Instead, I made a workaround to avoid saving a large amount of data.

I defined a saver that is fed a specific list of variables. That list only contains the nodes of my graph that I need. Here is a small example of my workaround:

import tensorflow as tf

v1 = tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2 = tf.Variable(tf.zeros([200]), name="v2")

# Only v2 is passed to the saver, so only v2 ends up in the checkpoint.
saver = tf.train.Saver([v2])
# saver = tf.train.Saver()  # this would save every variable, including v1
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess, "checkpoint/model_test", global_step=1)

Here [v2] is the variable list. Alternatively, you can use variables_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='net') to collect all the variables under a given scope.
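
Applied to the question above, a sketch of this would be (using the 'sr_spc' scope from the question's Net construction):

# Collect only the network's variables and save just those.
net_vars = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='sr_spc')
saver = tf.train.Saver(var_list=net_vars)
saver.save(sess, checkpoint_path, global_step=iters)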
