簡體   English   中英

TF 2.0 @ tf.function示例

[英]TF 2.0 @tf.function example

簽名部分的tensorflow文檔中,我們具有以下代碼片段

@tf.function
def train(model, optimizer):
  train_ds = mnist_dataset()
  step = 0
  loss = 0.0
  accuracy = 0.0
  for x, y in train_ds:
    step += 1
    loss = train_one_step(model, optimizer, x, y)
    if tf.equal(step % 10, 0):
      tf.print('Step', step, ': loss', loss, '; accuracy', compute_accuracy.result())
  return step, loss, accuracy

step, loss, accuracy = train(model, optimizer)
print('Final step', step, ': loss', loss, '; accuracy', compute_accuracy.result())

我對step變量有一個小問題,它是一個整數而不是張量,簽名支持內置的python類型,例如integer。 因此,可以將tf.equal(step%10,0)更改為簡單的step%10 == 0對嗎?

你是對的。 即使將整數變量step轉換為其圖形表示形式,它仍然是Python變量。 您可以通過調用tf.autograph.to_code(train.python_function)來查看轉換結果。

如果不報告所有代碼,而僅報告與step變量相關的部分,您將看到

  def loop_body(loop_vars, loss_1, step_1):
    with ag__.function_scope('loop_body'):
      x, y = loop_vars
      step_1 += 1

仍然是python操作(否則,如果步驟1是tf.Tensor則它將為step_1.assign_add(1) )。

有關簽名和tf.function的更多信息,我建議閱讀文章https://pgaleone.eu/tensorflow/tf.function/2019/03/21/dissecting-tf-function-part-1/輕松解釋何時發生的情況函數被轉換。

盡管這在生成的代碼中不可見,但step變量實際上將通過for循環自動裝箱到Tensor,該循環被轉換為TF while_loop。

您可以通過添加打印語句來驗證這一點:

    loss = train_one_step(model, optimizer, x, y)
    print(step)
    if tf.equal(step % 10, 0):

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM