簡體   English   中英

獲取 tensorflow 中可訓練參數的值

[英]Getting values of trainable parameters in tensorflow

我正在嘗試從 model 中提取所有可訓練的權重。 在 pytorch p.grad.data for p in model.parameters() if p.requires_grad可以完成類似的事情,但是我正在努力在 TF 中想出一個簡單的解決方案。

我目前的嘗試如下所示:

sess = tf.Session()

... #model initialization and training here

p = model.trainable_weights
p_vals = sess.run(p)

但是,最后一行會產生錯誤:

  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
    return fn(*args)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
    options, feed_dict, fetch_list, target_list, run_metadata)
  File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
    run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable conv1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv1/bias)
     [[{{node conv1/bias/Read/ReadVariableOp}}]]

我在這里做錯了什么? 我假設會話/圖表沒有正確鏈接到 model? 或者它確實是一個初始化問題(但是 model 能夠成功訓練)?

使用自定義回調 Fn 會更容易,然后使用自定義操作進行交易!

class custom_callback(tf.keras.callbacks.Callback): 
tf.summary.create_file_writer(val_dir)

def _val_writer(self):
    if 'val' not in self._writers:
        self._writers['val'] = tf.summary.create_file_writer(val_dir)
    return self._writers['val']

def on_epoch_end(self, epoch, logs={}):
    print('weights: ' + str(self.model.get_weights()))
    
    if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
        with tf.summary.record_if(True): # self._val_writer.as_default():
            step = ''
            for name, value in logs.items():
                tf.summary.scalar(
                'evaluation_' + name + '_vs_iterations',
                value,
                step=self.model.optimizer.iterations.read_value(),
                )
            print('step :' + str(self.model.optimizer.iterations.read_value()))

    if(logs['accuracy'] == None) : pass
    else:
        if(logs['accuracy']> 0.90):
            self.model.stop_training = True

custom_callback = custom_callback()

history = model_highscores.fit(batched_features, epochs=99 ,validation_data=(dataset.shuffle(len(list_image))), callbacks=[custom_callback]) 

來自客戶回調 Fn 的權重值

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM