簡體   English   中英

TF Gradient Tape 有交叉產品的問題?

[英]TF Gradient Tape has issues with cross products?

我正在嘗試使用 TF 漸變帶作為自動梯度工具,通過牛頓法進行根查找。 但是當我試圖計算雅可比矩陣時,似乎 tf.GradientTape.jacobian 無法處理叉積:

x = tf.convert_to_tensor(np.array([1., 2., 3.]))
Wx = np.ones((3))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.linalg.cross(x, Wx)
print(tape.jacobian(y, x))

給出以下錯誤:

StagingError:在轉換后的代碼中:相對於/Users/xinzhang/anaconda3/lib/python3.7/site-packages:

tensorflow_core/python/ops/parallel_for/control_flow_ops.py:184 f  *
    return _pfor_impl(loop_fn, iters, parallel_iterations=parallel_iterations)
tensorflow_core/python/ops/parallel_for/control_flow_ops.py:257 _pfor_impl
    outputs.append(converter.convert(loop_fn_output))
tensorflow_core/python/ops/parallel_for/pfor.py:1231 convert
    output = self._convert_helper(y)
tensorflow_core/python/ops/parallel_for/pfor.py:1395 _convert_helper
    if flags.FLAGS.op_conversion_fallback_to_while_loop:
tensorflow_core/python/platform/flags.py:84 __getattr__
    wrapped(_sys.argv)
absl/flags/_flagvalues.py:633 __call__
    name, value, suggestions=suggestions)

UnrecognizedFlagError: Unknown command line flag 'f'

而如果我將對 jacobian 的調用切換為簡單的漸變:

x = tf.convert_to_tensor(np.array([1., 2., 3.]))
Wx = np.ones((3))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.linalg.cross(x, Wx)
print(tape.gradient(y, x))

給出預期的結果:

tf.Tensor([0. 0. 0.], shape=(3,), dtype=float64)

這是bug嗎?? 還是我對 tape.jacobian 方法做錯了什么?

ps python 版本 3.7.4; tf 版本 2.0.0 使用 conda 安裝的所有內容。

這可能是Tensorflow Version 2.0中的一個錯誤,但它已在Tensorflow Version 2.1中修復。

因此,請將您的 Tensorflow 版本升級到2.12.2 ,問題將得到解決。

工作代碼如下:

!pip install tensorflow==2.2

import tensorflow as tf
import numpy as np

print(tf.__version__)

x = tf.convert_to_tensor(np.array([1., 2., 3.]))
Wx = np.ones((3))
with tf.GradientTape() as tape:
    tape.watch(x)
    y = tf.linalg.cross(x, Wx)
print(tape.jacobian(y, x))

Output 如下圖所示:

2.2.0

tf.Tensor(
[[ 0.  1. -1.]
 [-1.  0.  1.]
 [ 1. -1.  0.]], shape=(3, 3), dtype=float64)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM