
Getting values of trainable parameters in TensorFlow

I am trying to extract all trainable weights from a model. In PyTorch the equivalent would be a single line, `[p.grad.data for p in model.parameters() if p.requires_grad]`; however, I'm struggling to come up with a similarly simple solution in TF.

My current attempt looks like this:

sess = tf.Session()

... #model initialization and training here

p = model.trainable_weights
p_vals = sess.run(p)

the last line, however, produces an error:

  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable conv1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv1/bias)
     [[{{node conv1/bias/Read/ReadVariableOp}}]]

What am I doing wrong here? I'm assuming the session/graph doesn't link to the model properly? Or is it indeed an initialization problem (but then how is the model capable of successful training)?
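Not from the original post, but a sketch of the likely cause: in graph-mode TF, `sess.run` can only read variables that were initialized in that same session, and a freshly created `tf.Session()` shares no state with the session Keras used for training. The minimal demonstration below (with a hypothetical variable standing in for `conv1/bias`) shows the initialization requirement; with an already trained Keras model you would instead reuse the session Keras created (`tf.keras.backend.get_session()` in TF 1.x), since re-initializing would discard the trained values.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Hypothetical variable standing in for the model's conv1/bias.
w = tf.get_variable('conv1_bias', shape=[4], initializer=tf.zeros_initializer())

sess = tf.Session()
# Reading w before initialization raises the FailedPreconditionError
# shown in the traceback; initializing in *this* session fixes it.
sess.run(tf.global_variables_initializer())
vals = sess.run(w)
print(vals.tolist())  # [0.0, 0.0, 0.0, 0.0]
```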

It is easier to do this with a custom callback, and to perform your custom actions there:

class custom_callback(tf.keras.callbacks.Callback):
    def __init__(self):
        super().__init__()
        self._writers = {}

    def _val_writer(self):
        # Create the validation summary writer lazily; val_dir is your log directory.
        if 'val' not in self._writers:
            self._writers['val'] = tf.summary.create_file_writer(val_dir)
        return self._writers['val']

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        print('weights: ' + str(self.model.get_weights()))

        if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
            step = int(self.model.optimizer.iterations.read_value())
            with self._val_writer().as_default():
                for name, value in logs.items():
                    tf.summary.scalar(
                        'evaluation_' + name + '_vs_iterations',
                        value,
                        step=step,
                    )
            print('step: ' + str(step))

        # Stop training once accuracy passes 0.90.
        if logs.get('accuracy') is not None and logs['accuracy'] > 0.90:
            self.model.stop_training = True

callback = custom_callback()

history = model_highscores.fit(
    batched_features,
    epochs=99,
    validation_data=dataset.shuffle(len(list_image)),
    callbacks=[callback],
)
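In TF 2.x eager mode (not covered by the answer above), trainable weights can also be read directly, with no session at all. A minimal sketch, using a hypothetical tiny model in place of the poster's:

```python
import tensorflow as tf

# Hypothetical tiny model standing in for the poster's model.
model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
model.build((None, 3))

# Each entry of model.trainable_weights is a tf.Variable; .numpy()
# returns its current value, the TF analogue of the PyTorch one-liner.
p_vals = [w.numpy() for w in model.trainable_weights]
print([v.shape for v in p_vals])  # [(3, 2), (2,)]
```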

Weight values from the custom callback
