[英]Getting values of trainable parameters in tensorflow
我正在嘗試從 model 中提取所有可訓練的權重。 在 pytorch p.grad.data for p in model.parameters() if p.requires_grad
可以完成類似的事情,但是我正在努力在 TF 中想出一個簡單的解決方案。
我目前的嘗試如下所示:
sess = tf.Session()
... #model initialization and training here
p = model.trainable_weights
p_vals = sess.run(p)
但是,最后一行會產生錯誤:
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1356, in _do_call
return fn(*args)
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1341, in _run_fn
options, feed_dict, fetch_list, target_list, run_metadata)
File "/.../lib/python3.5/site-packages/tensorflow/python/client/session.py", line 1429, in _call_tf_sessionrun
run_metadata)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable conv1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Container localhost does not exist. (Could not find resource: localhost/conv1/bias)
[[{{node conv1/bias/Read/ReadVariableOp}}]]
我在這里做錯了什么? 我假設會話/圖表沒有正確鏈接到 model? 或者它確實是一個初始化問題(但是 model 能夠成功訓練)?
使用自定義回調 Fn 會更容易,然后使用自定義操作進行交易!
class custom_callback(tf.keras.callbacks.Callback):
tf.summary.create_file_writer(val_dir)
def _val_writer(self):
if 'val' not in self._writers:
self._writers['val'] = tf.summary.create_file_writer(val_dir)
return self._writers['val']
def on_epoch_end(self, epoch, logs={}):
print('weights: ' + str(self.model.get_weights()))
if self.model.optimizer and hasattr(self.model.optimizer, 'iterations'):
with tf.summary.record_if(True): # self._val_writer.as_default():
step = ''
for name, value in logs.items():
tf.summary.scalar(
'evaluation_' + name + '_vs_iterations',
value,
step=self.model.optimizer.iterations.read_value(),
)
print('step :' + str(self.model.optimizer.iterations.read_value()))
if(logs['accuracy'] == None) : pass
else:
if(logs['accuracy']> 0.90):
self.model.stop_training = True
custom_callback = custom_callback()
history = model_highscores.fit(batched_features, epochs=99 ,validation_data=(dataset.shuffle(len(list_image))), callbacks=[custom_callback])
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.