How does tf.GradientTape record operations inside the with statement?

I don't understand how tf.GradientTape records an operation such as y = x**2 inside the `with` statement below.
x = tf.Variable(3.0)
with tf.GradientTape() as tape:
    y = x**2
What Python syntax makes this behavior possible?

EDIT:

According to the GitHub source code of GradientTape, at Line 897:
@tf_contextlib.contextmanager
def _ensure_recording(self):
    """Ensures that this tape is recording."""
    if not self._recording:
        try:
            self._push_tape()
            yield
        finally:
            self._pop_tape()
    else:
        yield
In case you are not familiar with it, a contextmanager runs its setup code whenever the `with` keyword enters the block. This tells us that the tape starts tracking at that point.
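The `with` statement invokes the context manager's `__enter__` method on entry and `__exit__` on exit; `@contextlib.contextmanager` builds such an object from a generator, running the code before `yield` on entry and the code after it (here, in the `finally` clause) on exit. A minimal sketch of the same pattern, with illustrative names rather than TensorFlow's:

```python
from contextlib import contextmanager

events = []

@contextmanager
def recording():
    events.append("push_tape")      # runs when the `with` block is entered
    try:
        yield
    finally:
        events.append("pop_tape")   # runs when the `with` block exits

with recording():
    events.append("body")

print(events)  # ['push_tape', 'body', 'pop_tape']
```

This is exactly how `_ensure_recording` brackets the body of the `with` block between `_push_tape()` and `_pop_tape()`.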
self._pop_tape() is defined at Line 891:
def _pop_tape(self):
    if not self._recording:
        raise ValueError("Tape is not recording.")
    tape.pop_tape(self._tape)
    self._recording = False
self._push_tape() is defined at Line 878:
def _push_tape(self):
    """Pushes a new tape onto the tape stack."""
    if self._recording:
        raise ValueError("Tape is still recording, This can happen if you try to "
                         "re-enter an already-active tape.")
    if self._tape is None:
        self._tape = tape.push_new_tape(
            persistent=self._persistent,
            watch_accessed_variables=self._watch_accessed_variables)
    else:
        tape.push_tape(self._tape)
    self._recording = True
Here you can see that tape.push_new_tape is called; it can be found at Line 43 of its source file:
def push_new_tape(persistent=False, watch_accessed_variables=True):
    """Pushes a new tape onto the tape stack."""
    tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
    return Tape(tape)
Just above, at Line 31, you can see the Tape class:
class Tape(object):
    """Represents a gradient propagation trace."""

    __slots__ = ["_tape"]

    def __init__(self, tape):
        self._tape = tape

    def watched_variables(self):
        return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)
Additionally, I tried to trace pywrap_tfe.TFE_Py_TapeSetNew but could not find it anywhere in the Python source files.
Original answer:

The documentation of GradientTape states:

By default, GradientTape will automatically watch any trainable variables that are accessed inside the context. If you want fine-grained control over which variables are watched, you can disable automatic tracking by passing watch_accessed_variables=False to the tape constructor.

Using the following code:
x = tf.Variable(2.0)
w = tf.Variable(5.0)
with tf.GradientTape(
        watch_accessed_variables=False, persistent=True) as tape:
    tape.watch(x)
    y = x ** 2  # Gradients will be available for `x`.
    z = w ** 3  # No gradients will be available as `w` isn't being watched.

dy_dx = tape.gradient(y, x)
print(dy_dx)
>>> tf.Tensor(4.0, shape=(), dtype=float32)

# No gradients will be available as `w` isn't being watched.
dz_dw = tape.gradient(z, w)
print(dz_dw)
>>> None
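The watch mechanism can be imitated with a pure-Python toy (illustrative only; none of these names are TensorFlow's): an op records a gradient entry only when its input is in the tape's watched set, so asking for the gradient of an unwatched variable yields None, just as in the output above.

```python
class Tape:
    """Toy tape: records (output, input, local_gradient) for watched vars."""
    def __init__(self):
        self.watched = set()
        self.ops = []

    def watch(self, var):
        self.watched.add(var)

class Variable:
    def __init__(self, value):
        self.value = value

def power(tape, var, exponent):
    """A toy op: records its local derivative only if `var` is watched."""
    out = Variable(var.value ** exponent)
    if var in tape.watched:
        # d(var**n)/d(var) = n * var**(n-1)
        tape.ops.append((out, var, exponent * var.value ** (exponent - 1)))
    return out

def gradient(tape, target, source):
    for out, inp, grad in tape.ops:
        if out is target and inp is source:
            return grad
    return None  # mirrors TF returning None for unwatched variables

x = Variable(2.0)
w = Variable(5.0)
tape = Tape()
tape.watch(x)                  # only x is watched
y = power(tape, x, 2)
z = power(tape, w, 3)
print(gradient(tape, y, x))    # 4.0
print(gradient(tape, z, w))    # None
```

With watch_accessed_variables=True (the default), the real tape effectively calls watch for you on every trainable variable an op touches inside the context.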