If-else in @tf.function
I want to define a custom LearningRateSchedule, but AutoGraph seems unable to convert it. The following code works fine without @tf.function, but it raises an error when @tf.function is used.
    import tensorflow as tf
    from tensorflow.keras import layers

    def linear_interpolation(l, r, alpha):
        return l + alpha * (r - l)

    class TFPiecewiseSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
        # This class currently cannot be used in @tf.function,
        # Since tf.cond See the following link for details
        def __init__(self, endpoints, end_learning_rate=None, name=None):
            """Piecewise schedule.
            endpoints: [(int, int)]
                list of pairs `(time, value)` meaning that the schedule should output
                `value` when `t==time`. All the values for time must be sorted in
                increasing order. When t is between two times, e.g. `(time_a, value_a)`
                and `(time_b, value_b)`, such that `time_a <= t < time_b`, the output is
                `interpolation(value_a, value_b, alpha)` where alpha is the fraction of
                time passed between `time_a` and `time_b` at time `t`.
            end_learning_rate: float
                value returned when t lies beyond the last interval specified in
                `endpoints`. If None, the value of the last endpoint is used.
            """
            super().__init__()
            idxes = [e[0] for e in endpoints]
            assert idxes == sorted(idxes)
            self.end_learning_rate = end_learning_rate or endpoints[-1][1]
            self.endpoints = endpoints
            self.name = name

        def __call__(self, step):
            if step < self.endpoints[0][0]:
                return self.endpoints[0][1]
            else:
                for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
                    if l_t <= step < r_t:
                        alpha = float(step - l_t) / (r_t - l_t)
                        return linear_interpolation(l, r, alpha)
                # t does not belong to any of the pieces, so doom.
                assert self.end_learning_rate is not None
                return self.end_learning_rate

        def get_config(self):
            return dict(
                endpoints=self.endpoints,
                end_learning_rate=self.end_learning_rate,
                name=self.name,
            )

    lr = TFPiecewiseSchedule([[10, 1e-3], [20, 1e-4]])

    @tf.function
    def f(x):
        l = layers.Dense(10)
        with tf.GradientTape() as tape:
            y = l(x)
            loss = tf.reduce_mean(y**2)
        grads = tape.gradient(loss, l.trainable_variables)
        opt = tf.keras.optimizers.Adam(lr)
        opt.apply_gradients(zip(grads, l.trainable_variables))

    f(tf.random.normal((2, 3)))
The error message says:
    :10 f *
        opt.apply_gradients(zip(grads, l.trainable_variables))
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:437 apply_gradients
        apply_state = self._prepare(var_list)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:614 _prepare
        self._prepare_local(var_device, var_dtype, apply_state)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/adam.py:154 _prepare_local
        super(Adam, self)._prepare_local(var_device, var_dtype, apply_state)
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:620 _prepare_local
        lr_t = array_ops.identity(self._decayed_lr(var_dtype))
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py:672 _decayed_lr
        lr_t = math_ops.cast(lr_t(local_step), var_dtype)
    :32 __call__
        if step < self.endpoints[0][0]:
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:765 __bool__
        self._disallow_bool_casting()
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:531 _disallow_bool_casting
        "using a `tf.Tensor` as a Python `bool`")
    /Users/aptx4869/anaconda3/envs/drl/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py:518 _disallow_when_autograph_enabled
        " Try decorating it directly with @tf.function.".format(task))

    OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed: AutoGraph did not convert this function. Try decorating it directly with @tf.function.
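I thought the error was caused by the if statement, so I replaced the body of the __call__ function with the following code. But almost the same error occurred.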
    def compute_lr(step):
        for (l_t, l), (r_t, r) in zip(self.endpoints[:-1], self.endpoints[1:]):
            if l_t <= step < r_t:
                alpha = float(step - l_t) / (r_t - l_t)
                return linear_interpolation(l, r, alpha)
        # t does not belong to any of the pieces, so doom.
        assert self.end_learning_rate is not None
        return self.end_learning_rate

    return tf.cond(tf.less(step, self.endpoints[0][0]), lambda: self.endpoints[0][1], lambda: compute_lr(step))
What should I do to make the code work the way I want?
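The error message got garbled by the markdown formatter, but it looks like the __call__ function itself was not processed by AutoGraph. In the error message, converted functions are marked with an asterisk, and only f carries one here. This is a bug in the Adam optimizer. In any case, you can annotate __call__ with tf.function directly and it will be picked up: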
@tf.function
def __call__(self, step):
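That said, there are a few things in the code that AutoGraph doesn't like: zip, returning from inside a loop, and chained inequalities. It is safer to stick to basic constructs whenever possible. Sadly, the errors you get are still a bit confusing. Rewriting it like this should work: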
    @tf.function
    def __call__(self, step):
        if step < self.endpoints[0][0]:
            return self.endpoints[0][1]
        else:
            # Can't return from a loop
            lr = self.end_learning_rate
            # Since it needs to break based on the value of a tensor, loop
            # needs to be a tf.while_loop
            for pair in tf.stack([self.endpoints[:-1], self.endpoints[1:]], axis=1):
                left, right = tf.unstack(pair)
                l_t, l = tf.unstack(left)
                r_t, r = tf.unstack(right)
                # Chained inequalities not supported yet
                if l_t <= step and step < r_t:
                    alpha = float(step - l_t) / (r_t - l_t)
                    lr = linear_interpolation(l, r, alpha)
                    break
            return lr
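To sanity-check a rewrite like this, it can help to have the intended semantics written down in plain Python and compare a few values. The sketch below is only such a reference: the name piecewise_lr_reference is hypothetical (not from the original post), it mirrors the question's eager logic, and it uses no TensorFlow, so it is not something AutoGraph can run on tensors.

```python
def linear_interpolation(l, r, alpha):
    return l + alpha * (r - l)


def piecewise_lr_reference(step, endpoints, end_learning_rate=None):
    """Plain-Python reference for the piecewise schedule (hypothetical helper)."""
    # Assumption: fall back to the last endpoint's value, as the question's class does.
    if end_learning_rate is None:
        end_learning_rate = endpoints[-1][1]
    # Before the first endpoint: return the first value.
    if step < endpoints[0][0]:
        return endpoints[0][1]
    # Inside an interval: interpolate linearly between its two values.
    for (l_t, l), (r_t, r) in zip(endpoints[:-1], endpoints[1:]):
        if l_t <= step < r_t:
            alpha = float(step - l_t) / (r_t - l_t)
            return linear_interpolation(l, r, alpha)
    # Past the last endpoint.
    return end_learning_rate


endpoints = [[10, 1e-3], [20, 1e-4]]
for step in (0, 10, 15, 20, 100):
    print(step, piecewise_lr_reference(step, endpoints))
```

Evaluating the tf.function version at the same steps and comparing against these values is a quick way to confirm that the rewrite did not change the schedule's behavior.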
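There is one last problem: tf.function doesn't like it when variables are created inside it, so you need to move the creation of the layer and the optimizer outside: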
lr = TFPiecewiseSchedule([[10, 1e-3], [20, 1e-4]])
l = layers.Dense(10)
opt = tf.keras.optimizers.Adam(lr)
@tf.function
def f(x):
    ...
I hope this helps!