Combine tensorflow low level API (tensors/placeholders) with Keras model
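According to the TensorFlow documentation, tf.keras.Input provides a placeholder and tf.keras.layers.Dense provides a tensor. So I would like to test this equivalence by building tensors and placeholders with the TensorFlow low-level API, and then train my model with the Keras high-level API. Here is my code: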
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, shape=[None, 32])
W_h = tf.get_variable(name="W_h", shape=[32, 64], initializer=tf.truncated_normal_initializer(stddev=0.01))
W_out = tf.get_variable(name="W_out", shape=[64, 10], initializer=tf.truncated_normal_initializer(stddev=0.01))
h = tf.nn.relu(tf.matmul(inputs, W_h, name="MatMul"), name='relu')
predictions = tf.nn.relu(tf.matmul(h, W_out, name="MatMul"), name='relu')
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='sgd') # sgd stands for stochastic gradient descent
model.fit(x_train, y_train, batch_size=32, epochs=5)
However, an error is raised when calling tf.keras.Model:
InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,32]
[[{{node Placeholder}}]]
I did supply a placeholder as the input, didn't I?
PS: the full error message:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-15-27f92d1f784d> in <module>
8 predictions = tf.nn.relu(tf.matmul(h, W_out, name="MatMul"), name='relu')
9
---> 10 model = tf.keras.Model(inputs=inputs, outputs=predictions)
11 model.compile(loss='mean_squared_error', optimizer='sgd') # sgd stands for stochastic gradient descent
12 model.fit(x_train, y_train, batch_size=32, epochs=5)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\training.py in __init__(self, *args, **kwargs)
127
128 def __init__(self, *args, **kwargs):
--> 129 super(Model, self).__init__(*args, **kwargs)
130 # initializing _distribution_strategy here since it is possible to call
131 # predict on a model without compiling it.
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\network.py in __init__(self, *args, **kwargs)
160 'inputs' in kwargs and 'outputs' in kwargs):
161 # Graph network
--> 162 self._init_graph_network(*args, **kwargs)
163 else:
164 # Subclassed network
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\network.py in _init_graph_network(self, inputs, outputs, name, **kwargs)
267
268 if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
--> 269 base_layer_utils.create_keras_history(self._nested_outputs)
270
271 self._base_init(name=name, **kwargs)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\base_layer_utils.py in create_keras_history(tensors)
198 keras_tensors: The Tensors found that came from a Keras Layer.
199 """
--> 200 _, created_layers = _create_keras_history_helper(tensors, set(), [])
201 return created_layers
202
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
242 constants[i] = op_input
243 else:
--> 244 constants[i] = backend.function([], op_input)([])
245 processed_ops, created_layers = _create_keras_history_helper(
246 layer_inputs, processed_ops, created_layers)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
3290
3291 fetched = self._callable_fn(*array_vals,
-> 3292 run_metadata=self.run_metadata)
3293 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3294 output_structure = nest.pack_sequence_as(
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,32]
[[{{node Placeholder}}]]
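You cannot mix tf.placeholder and tf.keras.Input. In other words, if you want to use the tf.keras API, use tf.keras.Input; if you want to use the native tf API, use tf.placeholder. That choice also determines how the rest of the code has to be written. Assuming you want to use the tf.keras API, this is the approach to take: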
import numpy as np
import tensorflow as tf  # written for TF 1.14

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        super(CustomLayer, self).__init__(**kwargs)
        self.output_dim = output_dim

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=tf.TensorShape((input_shape[1], self.output_dim)),
                                      initializer=tf.truncated_normal_initializer(stddev=0.01),
                                      trainable=True)
        super(CustomLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        # Bias-free dense layer followed by ReLU, as in the question.
        return tf.nn.relu(tf.matmul(x, self.kernel))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

inputs = tf.keras.Input(shape=(32,), dtype=tf.float32)
h = CustomLayer(output_dim=64)(inputs)
predictions = CustomLayer(output_dim=10)(h)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='sgd')  # sgd stands for stochastic gradient descent
model.summary()
# Random data stands in for x_train / y_train.
model.fit(np.random.rand(100, 32), np.random.rand(100, 10), batch_size=32, epochs=5)
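For comparison, the same two-layer network can probably be written without a custom layer. The sketch below is not part of the original answer: it assumes that the built-in tf.keras.layers.Dense with use_bias=False is an acceptable stand-in for CustomLayer (both compute relu(x · kernel)). The helper name build_dense_model and the random data are illustrative only.

```python
import numpy as np
import tensorflow as tf


def build_dense_model():
    """Hypothetical helper: bias-free Dense layers mirroring CustomLayer."""
    inputs = tf.keras.Input(shape=(32,), dtype="float32")
    h = tf.keras.layers.Dense(
        64, activation="relu", use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
    )(inputs)
    predictions = tf.keras.layers.Dense(
        10, activation="relu", use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.01),
    )(h)
    return tf.keras.Model(inputs=inputs, outputs=predictions)


model = build_dense_model()
model.compile(loss="mean_squared_error", optimizer="sgd")

# Random stand-in data, shaped like the data used in the answer.
x = np.random.rand(100, 32).astype("float32")
y = np.random.rand(100, 10).astype("float32")

history = model.fit(x, y, batch_size=32, epochs=2, verbose=0)
preds = model.predict(x[:4], verbose=0)
print(preds.shape)
```

Either way, every tensor between the model's inputs and outputs is produced by a Keras layer applied to a tf.keras.Input, which is the property the original code with tf.placeholder lacked.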
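Note that this approach matters if we assume you are using TF 1.14, as you mentioned in the comments. In TF 2.0 this is less complicated and more intuitive. On the other hand, if you want to stick with the TF 1.14 native API and use tf.placeholder, then you should build your graph with tf.placeholder as the node through which data is fed. As for your question whether tf.keras.Input returns a placeholder: it does return a placeholder node that you can use to feed data, but it does not return a tf.placeholder. tf.placeholder is used as follows: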
X = tf.placeholder(shape=None, dtype=tf.int32)
Y = tf.placeholder(shape=None, dtype=tf.int32)
add = X + Y
with tf.Session() as sess:
    print(sess.run(add, feed_dict={X: 2, Y: 3}))
    # 5
    print(sess.run(add, feed_dict={X: 10, Y: 9}))
    # 19
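As you can see, a static graph is created first, and its nodes are then executed in a tf.Session while data is fed into the graph through tf.placeholder. tf.keras.Input, on the other hand, serves the same purpose (which is why the documentation calls it a placeholder), but its use case belongs to the tf.keras API and not to the native tf API.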
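To make the native-API route concrete, here is a minimal sketch of how the network from the question might be trained with placeholders and a session instead of tf.keras.Model. It is not from the original answer. It assumes the tf.compat.v1 namespace (present in TF 1.14 and in TF 2.x), a hand-written mean squared error loss, and a learning rate of 0.01; the targets placeholder and the random data are stand-ins for the asker's y_train and x_train.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF 1.x-style graph API

graph = tf.Graph()
with graph.as_default():
    # Placeholders are the nodes through which data is fed.
    inputs = tf1.placeholder(tf.float32, shape=[None, 32], name="inputs")
    targets = tf1.placeholder(tf.float32, shape=[None, 10], name="targets")

    W_h = tf1.get_variable(
        "W_h", shape=[32, 64],
        initializer=tf1.truncated_normal_initializer(stddev=0.01))
    W_out = tf1.get_variable(
        "W_out", shape=[64, 10],
        initializer=tf1.truncated_normal_initializer(stddev=0.01))

    h = tf.nn.relu(tf.matmul(inputs, W_h))
    predictions = tf.nn.relu(tf.matmul(h, W_out))

    # Assumed loss and optimizer, chosen to mirror 'mean_squared_error' / 'sgd'.
    loss = tf.reduce_mean(tf.square(predictions - targets))
    train_op = tf1.train.GradientDescentOptimizer(0.01).minimize(loss)
    init_op = tf1.global_variables_initializer()

# Random stand-in data.
x_train = np.random.rand(100, 32).astype(np.float32)
y_train = np.random.rand(100, 10).astype(np.float32)

losses = []
with tf1.Session(graph=graph) as sess:
    sess.run(init_op)
    for epoch in range(5):
        for start in range(0, len(x_train), 32):
            batch = slice(start, start + 32)
            _, batch_loss = sess.run(
                [train_op, loss],
                feed_dict={inputs: x_train[batch], targets: y_train[batch]})
        losses.append(float(batch_loss))  # loss of the last batch in the epoch
    preds = sess.run(predictions, feed_dict={inputs: x_train[:4]})

print(preds.shape)
```

Here the training loop, batching, and feeding are all written by hand, which is the work that model.compile and model.fit take over in the tf.keras version.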