
Problem with tf.train.Saver() and GPU - TensorFlow

My code is structured as follows:

with tf.device('/gpu:1'):
    ...
    model = get_model(input_pl)
    ...
    with tf.Session() as sess:
        saver = tf.train.Saver()
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

I want to save the model after the training phase, but when I run the code I get this error:

InvalidArgumentError (see above for traceback): Cannot assign a device for operation 'save/SaveV2': Could not satisfy explicit device specification '/device:GPU:1' because no supported kernel for GPU devices is available.
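Here is a minimal sketch of what triggers this first error, along with one common workaround. Because tf.train.Saver() is constructed inside the with tf.device('/gpu:1') block, its save ops (including save/SaveV2) inherit that explicit GPU request. As the message says, SaveV2 has no GPU kernel. Setting allow_soft_placement=True in the session's tf.ConfigProto lets TensorFlow fall back to the CPU for such ops. The sketch assumes the tf.compat.v1 API so it also runs on TF 2.x. The toy variable w is a hypothetical stand-in for get_model(input_pl), not the asker's model.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def train_and_save(save_dir):
    """Saver built inside the GPU scope, as in the question, plus soft placement."""
    graph = tf.Graph()
    with graph.as_default():
        with tf.device('/gpu:1'):
            # Hypothetical stand-in for get_model(input_pl).
            w = tf.Variable(tf.zeros([3]), name='w')
            train_op = w.assign_add(tf.ones([3]))
            # The Saver's save/SaveV2 op inherits the '/gpu:1' request here.
            saver = tf.train.Saver()
        # allow_soft_placement lets ops that cannot run on the requested device
        # (SaveV2 on a GPU, or anything on a missing GPU:1) fall back to the CPU.
        config = tf.ConfigProto(allow_soft_placement=True)
        with tf.Session(graph=graph, config=config) as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(2):  # stand-in for the epoch/batch loops
                sess.run(train_op)
            path = saver.save(sess, os.path.join(save_dir, 'model.ckpt'))
            return path, sess.run(w)
```

Soft placement is a session-wide setting. It also silently moves any other op that cannot run on GPU:1, so it trades an explicit error for less visible placement decisions.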

After reading this question, I changed the code as follows:

saver = tf.train.Saver()
with tf.device('/gpu:1'):
    ...
    model = get_model(pointcloud_pl)
    ...
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

But now I get this error:

ValueError: No variables to save
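This second error has a different cause. When no var_list is given, tf.train.Saver() collects the variables that already exist in the graph at the moment it is constructed. Building it before get_model(...) therefore leaves it with nothing to save. The minimal sketch below reproduces both construction orders. It assumes the tf.compat.v1 API and uses a single hypothetical variable w in place of the real model.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def saver_error(build_model_first):
    """Return the ValueError message from tf.train.Saver(), or None if it succeeds."""
    with tf.Graph().as_default():
        if build_model_first:
            # Hypothetical stand-in for get_model(...): creates one variable.
            tf.Variable(tf.zeros([2]), name='w')
        try:
            # With no var_list, the Saver collects the variables that exist
            # in the graph right now; none exist if the model is built later.
            tf.train.Saver()
        except ValueError as err:
            return str(err)
    return None
```

In this sketch, saver_error(build_model_first=False) should return the "No variables to save" message, and saver_error(build_model_first=True) should return None.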

I also tried it this way:

with tf.Session() as sess:
    saver = tf.train.Saver()
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)

I still get the same error, and it is always raised at the saver = tf.train.Saver() line.

How can I solve this problem?

I solved it by ordering the code like this:

  1. open the tf.Session()
  2. build the model
  3. create saver = tf.train.Saver()
  4. enter with tf.device(): for initialization and training

Here is example code:

with tf.Session() as sess:
    ...
    model = get_model(input_pl)
    saver = tf.train.Saver()
    ...
    with tf.device('/gpu:1'):
        sess.run(tf.global_variables_initializer())
        for epoch in range(num_epochs):
            ...
            for n in range(num_batches):
                ...
                sess.run(...)
            # eval epoch
        saver.save(sess, ...)
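Note that tf.device only affects ops created inside its block. In the example above, the model is built before the with tf.device('/gpu:1') block, so that block does not pin the model to GPU:1. Only ops created inside it, such as the initializer op, are affected. If the goal is to keep the model on GPU:1, an alternative ordering is to build the model inside the device block and create the Saver afterwards, outside that block. The Saver then sees the model's variables, and its save ops are not forced onto the GPU. This is a sketch under stated assumptions: it uses the tf.compat.v1 API, build_model is a hypothetical stand-in for get_model(input_pl), and allow_soft_placement=True is set only so the sketch also runs on machines without a second GPU.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Soft placement is enabled only so this sketch runs on machines without GPU:1.
CONFIG = tf.ConfigProto(allow_soft_placement=True)


def build_model():
    """Hypothetical stand-in for get_model(input_pl): one counter variable."""
    w = tf.Variable(0.0, name='w')
    return w, w.assign_add(1.0)


def train_and_save(save_dir, num_epochs=3):
    with tf.Graph().as_default() as graph:
        # tf.device only affects ops created inside the block, so the model
        # is built here to request placement on GPU:1.
        with tf.device('/gpu:1'):
            w, train_op = build_model()
        # Created after the model (so it sees its variables) and outside the
        # GPU scope (so save/SaveV2 is not explicitly pinned to the GPU).
        saver = tf.train.Saver()
        with tf.Session(graph=graph, config=CONFIG) as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(num_epochs):
                sess.run(train_op)
            return saver.save(sess, os.path.join(save_dir, 'model.ckpt'))


def restore(path):
    """Rebuild the same graph in a fresh Graph and load the checkpoint into it."""
    with tf.Graph().as_default() as graph:
        w, _ = build_model()
        saver = tf.train.Saver()
        with tf.Session(graph=graph, config=CONFIG) as sess:
            saver.restore(sess, path)
            return sess.run(w)
```

Restoring into a freshly built graph is a quick way to check that the checkpoint actually contains the model's variables. Here, the restored w should equal the number of training steps that were run.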
