
How does tf.GradientTape record operations inside the with statement?

I don't understand how tf.GradientTape can record an operation such as y = x**2 that happens inside the "with" statement below.

x = tf.Variable(3.0)

with tf.GradientTape() as tape:
  y = x**2

What Python syntax makes this behavior possible?

Edit

Based on the GitHub source code of GradientTape, at line 897:

@tf_contextlib.contextmanager
def _ensure_recording(self):
    """Ensures that this tape is recording."""
    if not self._recording:
      try:
        self._push_tape()
        yield
      finally:
        self._pop_tape()
    else:
      yield

In case you are not familiar with it: a contextmanager runs its setup code (everything before the yield) as soon as the with block is entered, and its cleanup code when the block exits. This tells us where the tape starts and stops recording. (GradientTape also defines __enter__ and __exit__ methods that call _push_tape() and _pop_tape(), which is what makes "with tf.GradientTape() as tape:" work.)
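The syntax itself is ordinary Python: any object with __enter__ and __exit__ methods, or a generator function wrapped in contextlib.contextmanager, can be used in a with statement. Here is a minimal sketch using only the standard library (the names recording and events are mine for illustration, not TensorFlow's):

```python
import contextlib

events = []

@contextlib.contextmanager
def recording():
    # Runs on entry to the `with` block, analogous to _push_tape().
    events.append("push tape")
    try:
        yield  # control returns here; the body of the `with` block executes
    finally:
        # Runs on exit, even if the body raised, analogous to _pop_tape().
        events.append("pop tape")

with recording():
    events.append("inside block")

print(events)  # ['push tape', 'inside block', 'pop tape']
```

Note that the body of the with block is not intercepted line by line; only the entry and exit hooks run, and everything in between executes normally.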

self._pop_tape() is defined at line 891:

def _pop_tape(self):
    if not self._recording:
      raise ValueError("Tape is not recording.")
    tape.pop_tape(self._tape)
    self._recording = False

self._push_tape() is defined at line 878:

def _push_tape(self):
    """Pushes a new tape onto the tape stack."""
    if self._recording:
      raise ValueError("Tape is still recording, This can happen if you try to "
                       "re-enter an already-active tape.")
    if self._tape is None:
      self._tape = tape.push_new_tape(
          persistent=self._persistent,
          watch_accessed_variables=self._watch_accessed_variables)
    else:
      tape.push_tape(self._tape)
    self._recording = True

Here you can notice that _push_tape calls tape.push_new_tape, which can be found at line 43 of the tape source:

def push_new_tape(persistent=False, watch_accessed_variables=True):
  """Pushes a new tape onto the tape stack."""
  tape = pywrap_tfe.TFE_Py_TapeSetNew(persistent, watch_accessed_variables)
  return Tape(tape)

Just above it, at line 31, you can see the Tape class:

class Tape(object):
  """Represents a gradient propagation trace."""

  __slots__ = ["_tape"]

  def __init__(self, tape):
    self._tape = tape

  def watched_variables(self):
    return pywrap_tfe.TFE_Py_TapeWatchedVariables(self._tape)

Furthermore, I tried to trace pywrap_tfe.TFE_Py_TapeSetNew but could not find it in that file's source code (it is presumably implemented in C++ and exposed through the pywrap_tfe bindings).

Original answer

The documentation of GradientTape states:

By default, GradientTape will automatically watch any trainable variables that are accessed inside the context. If you want fine-grained control over which variables are watched, you can disable automatic tracking by passing watch_accessed_variables=False to the tape constructor.

Consider the following code:

x = tf.Variable(2.0)
w = tf.Variable(5.0)
with tf.GradientTape(
    watch_accessed_variables=False, persistent=True) as tape:
  tape.watch(x)
  y = x ** 2  # Gradients will be available for `x`.
  z = w ** 3  # No gradients will be available as `w` isn't being watched.
dy_dx = tape.gradient(y, x)

print(dy_dx)
>>> tf.Tensor(4.0, shape=(), dtype=float32)

# No gradients will be available as `w` isn't being watched.
dz_dw = tape.gradient(z, w)

print(dz_dw)
>>> None

