InaccessibleTensorError - When using `tf.keras.layers.Layer` output in loop condition of another layer
When I use the output of one layer (`tf.keras.layers.Layer`) in the loop condition of another layer, I get an `InaccessibleTensorError`:
InaccessibleTensorError: The tensor 'Tensor("looper/while/sub:0", shape=(None, 1), dtype=float32)'
cannot be accessed here: it is defined in another function or code block. Use return values,
explicit Python locals or TensorFlow collections to access it. Defined in:
FuncGraph(name=looper_while_body_483, id=2098967820416); accessed from:
FuncGraph(name=looper_scratch_graph, id=2098808987904).
Minimal code to reproduce the error:
import tensorflow as tf
import numpy as np

class Looper(tf.keras.layers.Layer):
    # custom layer
    def __init__(self, units, **kwargs):
        super(Looper, self).__init__(**kwargs)
        self.units = units

    def call(self, input):
        output = []
        while input > 0:
            input = input - 0.01
            output.append(input)
        return tf.stack(output, axis=1)

input_label = tf.keras.Input((1, 3))
lstm1 = tf.keras.layers.LSTM(1)
looper = Looper(10)
output = lstm1(input_label)
output = looper(output)
model = tf.keras.Model(input_label, output)
adam = tf.keras.optimizers.Adam(0.01)
model.compile(adam, 'mse')
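I think the problem is the Python list you use in your custom layer. Because the `while` condition depends on a tensor, AutoGraph turns the loop into a `tf.while_loop`, and each `output.append(input)` stores a tensor that lives inside the loop body's graph (`looper_while_body_...` in the error message), which cannot be accessed after the loop. For your use case you should use a TensorFlow collection such as `tf.TensorArray`, which is carried through the loop as a loop variable: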
import tensorflow as tf
import numpy as np

class Looper(tf.keras.layers.Layer):
    # custom layer
    def __init__(self, units, **kwargs):
        super(Looper, self).__init__(**kwargs)
        self.units = units

    def call(self, input):
        # A growable TensorArray replaces the Python list
        output = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
        while input > 0:
            input = input - 0.01
            # write() returns a new TensorArray, so reassign it on every iteration
            output = output.write(output.size(), input)
        # stack() concatenates the written elements along a new leading axis
        return output.stack()

input_label = tf.keras.Input((1, 3))
lstm1 = tf.keras.layers.LSTM(1)
looper = Looper(10)
output = lstm1(input_label)
output = looper(output)
model = tf.keras.Model(input_label, output)
adam = tf.keras.optimizers.Adam(0.01)
model.compile(adam, 'mse')
print(model(tf.random.normal((1, 1, 3))))
tf.Tensor(
[[[ 0.01392288]]
 [[ 0.00392288]]
 [[-0.00607712]]], shape=(3, 1, 1), dtype=float32)
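To see the `tf.TensorArray` pattern in isolation, here is a minimal sketch outside Keras, assuming TensorFlow 2.x with AutoGraph enabled. The function name `count_down` and the step size 0.25 are illustrative choices, not part of the original question.

```python
import tensorflow as tf

@tf.function
def count_down(x):
    # Growable TensorArray; reassigning it each iteration makes AutoGraph
    # carry it as a tf.while_loop variable instead of leaking body-graph tensors.
    values = tf.TensorArray(dtype=tf.float32, size=0, dynamic_size=True)
    while x > 0:  # x is a scalar tensor, so AutoGraph emits a tf.while_loop
        x = x - 0.25
        values = values.write(values.size(), x)
    # Shape (steps,) because each written element is a scalar
    return values.stack()

result = count_down(tf.constant(1.0))
print(result.numpy())  # values 0.75, 0.5, 0.25, 0.0
```

The step 0.25 is exactly representable in float32, so the loop runs exactly four times.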
Depending on what you want to do, you may have to reshape the output of the `Looper` layer.
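For example, `output.stack()` returns a tensor of shape `(steps, batch, 1)`, as in the printed `(3, 1, 1)` result. If you want one row per sample instead, a transpose followed by a squeeze is one option. The helper `to_batch_major` is hypothetical, and the `(batch, steps)` target layout is only an assumption about what you need.

```python
import tensorflow as tf

def to_batch_major(stacked):
    """Hypothetical helper: (steps, batch, 1) -> (batch, steps)."""
    # Swap the step and batch axes: (steps, batch, 1) -> (batch, steps, 1)
    batch_major = tf.transpose(stacked, perm=[1, 0, 2])
    # Drop the trailing unit axis: (batch, steps, 1) -> (batch, steps)
    return tf.squeeze(batch_major, axis=-1)

# Same layout as the printed Looper output: 3 steps, batch of 1
stacked = tf.constant([[[0.3]], [[0.2]], [[0.1]]])
print(to_batch_major(stacked).shape)  # (1, 3)
```

Inside `Looper.call` you could apply the same two operations to the result of `output.stack()` before returning it.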