Training a tf.keras model with a basic low-level TensorFlow training loop doesn't work
Combine tensorflow low level API (tensors/placeholders) with Keras model
According to TensorFlow's Keras guide, tf.keras.Input provides placeholders and tf.keras.layers.Dense provides tensors. So I wanted to test the equivalence: build the model with tensors and placeholders from the low-level TensorFlow API, then train it with the high-level Keras API. Here is my code:
tf.reset_default_graph()
inputs = tf.placeholder(tf.float32, shape=[None, 32])
W_h = tf.get_variable(name="W_h", shape=[32, 64], initializer=tf.truncated_normal_initializer(stddev=0.01))
W_out = tf.get_variable(name="W_out", shape=[64, 10], initializer=tf.truncated_normal_initializer(stddev=0.01))
h = tf.nn.relu(tf.matmul(inputs, W_h, name="MatMul"), name='relu')
predictions = tf.nn.relu(tf.matmul(h, W_out, name="MatMul"), name='relu')
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='sgd') # sgd stands for stochastic gradient descent
model.fit(x_train, y_train, batch_size=32, epochs=5)
However, the call to tf.keras.Model raises an error:

InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,32]
[[{{node Placeholder}}]]

Didn't I already provide the input through the placeholder?
PS: full error message:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-15-27f92d1f784d> in <module>
8 predictions = tf.nn.relu(tf.matmul(h, W_out, name="MatMul"), name='relu')
9
---> 10 model = tf.keras.Model(inputs=inputs, outputs=predictions)
11 model.compile(loss='mean_squared_error', optimizer='sgd') # sgd stands for stochastic gradient descent
12 model.fit(x_train, y_train, batch_size=32, epochs=5)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\training.py in __init__(self, *args, **kwargs)
127
128 def __init__(self, *args, **kwargs):
--> 129 super(Model, self).__init__(*args, **kwargs)
130 # initializing _distribution_strategy here since it is possible to call
131 # predict on a model without compiling it.
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\network.py in __init__(self, *args, **kwargs)
160 'inputs' in kwargs and 'outputs' in kwargs):
161 # Graph network
--> 162 self._init_graph_network(*args, **kwargs)
163 else:
164 # Subclassed network
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\training\tracking\base.py in _method_wrapper(self, *args, **kwargs)
455 self._self_setattr_tracking = False # pylint: disable=protected-access
456 try:
--> 457 result = method(self, *args, **kwargs)
458 finally:
459 self._self_setattr_tracking = previous_value # pylint: disable=protected-access
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\network.py in _init_graph_network(self, inputs, outputs, name, **kwargs)
267
268 if any(not hasattr(tensor, '_keras_history') for tensor in self.outputs):
--> 269 base_layer_utils.create_keras_history(self._nested_outputs)
270
271 self._base_init(name=name, **kwargs)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\base_layer_utils.py in create_keras_history(tensors)
198 keras_tensors: The Tensors found that came from a Keras Layer.
199 """
--> 200 _, created_layers = _create_keras_history_helper(tensors, set(), [])
201 return created_layers
202
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\engine\base_layer_utils.py in _create_keras_history_helper(tensors, processed_ops, created_layers)
242 constants[i] = op_input
243 else:
--> 244 constants[i] = backend.function([], op_input)([])
245 processed_ops, created_layers = _create_keras_history_helper(
246 layer_inputs, processed_ops, created_layers)
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\keras\backend.py in __call__(self, inputs)
3290
3291 fetched = self._callable_fn(*array_vals,
-> 3292 run_metadata=self.run_metadata)
3293 self._call_fetch_callbacks(fetched[-len(self._fetches):])
3294 output_structure = nest.pack_sequence_as(
~\.conda\envs\tensorflow_cpu\lib\site-packages\tensorflow\python\client\session.py in __call__(self, *args, **kwargs)
1456 ret = tf_session.TF_SessionRunCallable(self._session._session,
1457 self._handle, args,
-> 1458 run_metadata_ptr)
1459 if run_metadata:
1460 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: You must feed a value for placeholder tensor 'Placeholder' with dtype float and shape [?,32]
[[{{node Placeholder}}]]
You cannot mix tf.placeholder and tf.keras.Input. In other words, if you want to use the tf.keras API, use tf.keras.Input; if you want the native tf API, use tf.placeholder. Moreover, this choice affects the rest of your code. Assuming you want to use the tf.keras API, the following approach works:
import numpy as np
import tensorflow as tf

class CustomLayer(tf.keras.layers.Layer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(CustomLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=tf.TensorShape((input_shape[1], self.output_dim)),
                                      initializer=tf.truncated_normal_initializer(stddev=0.01),
                                      trainable=True)
        super(CustomLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return tf.nn.relu(tf.matmul(x, self.kernel))

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

inputs = tf.keras.Input(shape=(32,), dtype=tf.float32)
h = CustomLayer(output_dim=64)(inputs)
predictions = CustomLayer(output_dim=10)(h)
model = tf.keras.Model(inputs=inputs, outputs=predictions)
model.compile(loss='mean_squared_error', optimizer='sgd')  # sgd stands for stochastic gradient descent
model.summary()
model.fit(np.random.rand(100, 32), np.random.rand(100, 10), batch_size=32, epochs=5)
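As a sanity check on what these layers compute, the forward pass of the model above is just two ReLU-activated matrix multiplications. Here is a minimal NumPy sketch of that forward pass (the random weights here merely stand in for the untrained kernels; this is an illustration, not TensorFlow code):

```python
import numpy as np

def relu(z):
    # Element-wise rectified linear unit
    return np.maximum(z, 0.0)

rng = np.random.default_rng(0)
x = rng.standard_normal((4, 32))          # a batch of 4 inputs of width 32
W_h = rng.normal(0.0, 0.01, (32, 64))     # stands in for the first CustomLayer kernel
W_out = rng.normal(0.0, 0.01, (64, 10))   # stands in for the second CustomLayer kernel

h = relu(x @ W_h)              # first layer: relu(matmul)
predictions = relu(h @ W_out)  # second layer: relu(matmul)

print(predictions.shape)  # (4, 10)
```

The output shape (4, 10) matches what model.predict would return for a batch of 4 samples.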
Note that this approach matters if, as mentioned in the comments, you are on TF 1.14; on TF 2.0 the whole thing is less complicated and more intuitive. On the other hand, if you want to stick with the TF 1.14 native API and tf.placeholder, you should build the graph using tf.placeholder nodes that will later be fed with data. As for your question of whether tf.keras.Input returns a placeholder: it does return a placeholder node that you can feed data through, but it does not return a tf.placeholder. tf.placeholder is used as follows:
X = tf.placeholder(shape=None, dtype=tf.int32)
Y = tf.placeholder(shape=None, dtype=tf.int32)
add = X + Y

with tf.Session() as sess:
    print(sess.run(add, feed_dict={X: 2, Y: 3}))
    # 5
    print(sess.run(add, feed_dict={X: 10, Y: 9}))
    # 19
As you can see, a static graph is built first, and its nodes are then executed with a tf.Session while data is fed into the graph through tf.placeholder. tf.keras.Input serves the same purpose (which is why the documentation calls it a placeholder), but its use case belongs to the tf.keras API rather than the native tf API.
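This build-first, feed-later pattern can be mimicked in plain Python as a toy illustration (the Placeholder and Add classes here are hypothetical, not part of any TensorFlow API): nodes are constructed up front, and evaluation is deferred until values are supplied, much like feed_dict in a session run.

```python
class Placeholder:
    """Toy stand-in for tf.placeholder: holds no value until fed."""
    def evaluate(self, feed_dict):
        return feed_dict[self]

class Add:
    """Toy graph node that adds two upstream nodes, like X + Y above."""
    def __init__(self, a, b):
        self.a, self.b = a, b
    def evaluate(self, feed_dict):
        # Evaluation is deferred: operand values come from feed_dict at run time.
        return self.a.evaluate(feed_dict) + self.b.evaluate(feed_dict)

X = Placeholder()
Y = Placeholder()
add = Add(X, Y)                    # build the static graph first...
print(add.evaluate({X: 2, Y: 3}))  # ...then run it with fed values: prints 5
print(add.evaluate({X: 10, Y: 9})) # prints 19
```

The same graph object is reused with different inputs, which is exactly what the repeated sess.run calls do above.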