簡體   English   中英

TensorFlow：使用 load_model 加載模型與使用 load_weights 加載權重，兩者行為不同

[英]TensorFlow: load_model and load_weights give different behavior

我有一個可以正常運作的 TensorFlow RNN 模型（稱為 `my_model`），並使用 `my_model.save("<dir_to_my_model>")` 將其保存。

但是，當我用以下方式重新加載這個模型時：

# Method 1: reload the complete model previously written by my_model.save("<dir_to_my_model>").
model = tf.keras.models.load_model("<dir_to_my_model>")

並運行若干個增量學習（incremental learning）的訓練週期（epoch）時：

# Record the forward pass in training mode so gradients can be taken afterwards.
with tf.GradientTape() as tape:
    logits = model(x_test_batch, training=True)

我收到一個錯誤（錯誤訊息很長，請參閱底部）。

但是，如果我使用 `model.save_weights()` 將這個模型的權重保存到檢查點（checkpoint），再建立一個新模型，並用 `model.load_weights` 從該檢查點加載權重，上面的代碼就能成功運行。

我的問題是：為什麼這兩種方法的行為不同？我應該怎麼做才能讓方法 1 正常運作？


InvalidArgumentError                      Traceback (most recent call last)
~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in get_attr(self, name)
   2325       with c_api_util.tf_buffer() as buf:
-> 2326         c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
   2327         data = c_api.TF_GetBuffer(buf)

InvalidArgumentError: Operation 'StatefulPartitionedCall' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _MaybeCompile(scope, op, func, grad_fn)
    330     try:
--> 331       xla_compile = op.get_attr("_XlaCompile")
    332       xla_separate_compiled_gradients = op.get_attr(

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in get_attr(self, name)
   2329       # Convert to ValueError for backwards compatibility.
-> 2330       raise ValueError(str(e))
   2331     x = attr_value_pb2.AttrValue()

ValueError: Operation 'StatefulPartitionedCall' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

InvalidArgumentError                      Traceback (most recent call last)
~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in get_attr(self, name)
   2325       with c_api_util.tf_buffer() as buf:
-> 2326         c_api.TF_OperationGetAttrValueProto(self._c_op, name, buf)
   2327         data = c_api.TF_GetBuffer(buf)

InvalidArgumentError: Operation 'StatefulPartitionedCall' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

ValueError                                Traceback (most recent call last)
~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _MaybeCompile(scope, op, func, grad_fn)
    330     try:
--> 331       xla_compile = op.get_attr("_XlaCompile")
    332       xla_separate_compiled_gradients = op.get_attr(

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in get_attr(self, name)
   2329       # Convert to ValueError for backwards compatibility.
-> 2330       raise ValueError(str(e))
   2331     x = attr_value_pb2.AttrValue()

ValueError: Operation 'StatefulPartitionedCall' has no attr named '_XlaCompile'.

During handling of the above exception, another exception occurred:

LookupError                               Traceback (most recent call last)
~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    606           try:
--> 607             grad_fn = ops.get_gradient_function(op)
    608           except LookupError:

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/ops.py in get_gradient_function(op)
   2494     op_type = op.type
-> 2495   return _gradient_registry.lookup(op_type)
   2496 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/registry.py in lookup(self, name)
     96       raise LookupError(
---> 97           "%s registry has no entry for: %s" % (self._name, name))

LookupError: gradient registry has no entry for: While

During handling of the above exception, another exception occurred:

LookupError                               Traceback (most recent call last)
<ipython-input-7-55225939f96a> in <module>
     10             mode = 'DNN' if 'ANN' in m else 'RNN'
     11             c_datasets = continual_dataset(datasets[i], mode=mode)
---> 12             res = learning_rate_test(continual_fine_tune, mname, c_datasets, lr_list=lrs)
     13             dfs.append(res)

<ipython-input-5-f4d4cc64dc17> in learning_rate_test(continual_func, model_fname, c_datasets, epochs, lr_list, loss)
     14                 model.save_weights('Model/tmp')
     15                 model.load_weights('Model/tmp')
---> 16             test_predictions = continual_func(c_datasets, model, loss_fn, lr, epochs=e, save_model='')
     17             test_predictions = (np.array(test_predictions) > 0.5).astype(int)
     18             epoch_test_dict[e].append(np.sum(np.equal(y_test_list, test_predictions)) / len(y_test_list) * 100)

<ipython-input-4-9ae956b3e918> in continual_fine_tune(c_datasets, model, loss_fn, lr, epochs, save_model)
     12         for e in range(epochs):
     13             with tf.GradientTape() as tape:
---> 14                 logits = model(x_test_batch, training=True)
     15                 loss_value = loss_fn(y_test_batch, logits)
     16             grads = tape.gradient(loss_value, model.trainable_weights)

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    820           with base_layer_utils.autocast_context_manager(
    821               self._compute_dtype):
--> 822             outputs = self.call(cast_inputs, *args, **kwargs)
    823           self._handle_activity_regularization(inputs, outputs)
    824           self._set_mask_metadata(inputs, outputs, input_masks)

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/sequential.py in call(self, inputs, training, mask)
    265       if not self.built:
    266         self._init_graph_network(self.inputs, self.outputs, name=self.name)
--> 267       return super(Sequential, self).call(inputs, training=training, mask=mask)
    268 
    269     outputs = inputs  # handle the corner case where self.layers is empty

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in call(self, inputs, training, mask)
    715     return self._run_internal_graph(
    716         inputs, training=training, mask=mask,
--> 717         convert_kwargs_to_constants=base_layer_utils.call_context().saving)
    718 
    719   def compute_output_shape(self, input_shape):

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/network.py in _run_internal_graph(self, inputs, training, mask, convert_kwargs_to_constants)
    889 
    890           # Compute outputs.
--> 891           output_tensors = layer(computed_tensors, **kwargs)
    892 
    893           # Update tensor_dict.

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/engine/base_layer.py in __call__(self, inputs, *args, **kwargs)
    820           with base_layer_utils.autocast_context_manager(
    821               self._compute_dtype):
--> 822             outputs = self.call(cast_inputs, *args, **kwargs)
    823           self._handle_activity_regularization(inputs, outputs)
    824           self._set_mask_metadata(inputs, outputs, input_masks)

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in return_outputs_and_add_losses(*args, **kwargs)
     57     inputs = args[inputs_arg_index]
     58     args = args[inputs_arg_index + 1:]
---> 59     outputs, losses = fn(inputs, *args, **kwargs)
     60     layer.add_loss(losses, inputs)
     61     return outputs

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in wrap_with_training_arg(*args, **kwargs)
    111         training,
    112         lambda: replace_training_and_call(True),
--> 113         lambda: replace_training_and_call(False))
    114 
    115   # Create arg spec for decorated function. If 'training' is not defined in the

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/utils/tf_utils.py in smart_cond(pred, true_fn, false_fn, name)
     57         pred, true_fn=true_fn, false_fn=false_fn, name=name)
     58   return smart_module.smart_cond(
---> 59       pred, true_fn=true_fn, false_fn=false_fn, name=name)
     60 
     61 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
     52   if pred_value is not None:
     53     if pred_value:
---> 54       return true_fn()
     55     else:
     56       return false_fn()

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in <lambda>()
    110     return tf_utils.smart_cond(
    111         training,
--> 112         lambda: replace_training_and_call(True),
    113         lambda: replace_training_and_call(False))
    114 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/keras/saving/saved_model/utils.py in replace_training_and_call(training)
    106     def replace_training_and_call(training):
    107       set_training_arg(training, training_arg_index, args, kwargs)
--> 108       return wrapped_call(*args, **kwargs)
    109 
    110     return tf_utils.smart_cond(

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    566         xla_context.Exit()
    567     else:
--> 568       result = self._call(*args, **kwds)
    569 
    570     if tracing_count == self._get_tracing_count():

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    604       # In this case we have not created variables on the first call. So we can
    605       # run the first trace but we should fail if variables are created.
--> 606       results = self._stateful_fn(*args, **kwds)
    607       if self._created_variables:
    608         raise ValueError("Creating variables on a non-first call to a function"

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2361     with self._lock:
   2362       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
-> 2363     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2364 
   2365   @property

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _filtered_call(self, args, kwargs)
   1609          if isinstance(t, (ops.Tensor,
   1610                            resource_variable_ops.BaseResourceVariable))),
-> 1611         self.captured_inputs)
   1612 
   1613   def _call_flat(self, args, captured_inputs, cancellation_manager=None):

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1695         possible_gradient_type,
   1696         executing_eagerly)
-> 1697     forward_function, args_with_tangents = forward_backward.forward()
   1698     if executing_eagerly:
   1699       flat_outputs = forward_function.call(

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in forward(self)
   1421     """Builds or retrieves a forward function for this call."""
   1422     forward_function = self._functions.forward(
-> 1423         self._inference_args, self._input_tangents)
   1424     return forward_function, self._inference_args + self._input_tangents
   1425 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in forward(self, inference_args, input_tangents)
   1183       (self._forward, self._forward_graph, self._backward,
   1184        self._forwardprop_output_indices, self._num_forwardprop_outputs) = (
-> 1185            self._forward_and_backward_functions(inference_args, input_tangents))
   1186     return self._forward
   1187 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _forward_and_backward_functions(self, inference_args, input_tangents)
   1329     outputs = self._func_graph.outputs[:self._num_inference_outputs]
   1330     return self._build_functions_for_outputs(
-> 1331         outputs, inference_args, input_tangents)
   1332 
   1333 

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _build_functions_for_outputs(self, outputs, inference_args, input_tangents)
    888             self._func_graph.inputs,
    889             grad_ys=gradients_wrt_outputs,
--> 890             src_graph=self._func_graph)
    891 
    892       captures_from_forward = [

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    667                 # functions.
    668                 in_grads = _MaybeCompile(grad_scope, op, func_call,
--> 669                                          lambda: grad_fn(op, *out_grads))
    670               else:
    671                 # For function call ops, we add a 'SymbolicGradient'

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _MaybeCompile(scope, op, func, grad_fn)
    334       xla_scope = op.get_attr("_XlaScope").decode()
    335     except ValueError:
--> 336       return grad_fn()  # Exit early
    337 
    338   if not xla_compile:

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in <lambda>()
    667                 # functions.
    668                 in_grads = _MaybeCompile(grad_scope, op, func_call,
--> 669                                          lambda: grad_fn(op, *out_grads))
    670               else:
    671                 # For function call ops, we add a 'SymbolicGradient'

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _rewrite_forward_and_call_backward(self, op, *doutputs)
    705   def _rewrite_forward_and_call_backward(self, op, *doutputs):
    706     """Add outputs to the forward call and feed them to the grad function."""
--> 707     forward_function, backwards_function = self.forward_backward(len(doutputs))
    708     if not backwards_function.outputs:
    709       return backwards_function.structured_outputs

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in forward_backward(self, num_doutputs)
    614     if forward_backward is not None:
    615       return forward_backward
--> 616     forward, backward = self._construct_forward_backward(num_doutputs)
    617     self._cached_function_pairs[num_doutputs] = (forward, backward)
    618     return forward, backward

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _construct_forward_backward(self, num_doutputs)
    662           args=[], kwargs={},
    663           signature=signature,
--> 664           func_graph=backwards_graph)
    665       backwards_graph_captures = backwards_graph.external_captures
    666       captures_from_forward = [

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _backprop_function(*grad_ys)
    652             self._func_graph.inputs,
    653             grad_ys=grad_ys,
--> 654             src_graph=self._func_graph)
    655 
    656     with self._func_graph.as_default():

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    667                 # functions.
    668                 in_grads = _MaybeCompile(grad_scope, op, func_call,
--> 669                                          lambda: grad_fn(op, *out_grads))
    670               else:
    671                 # For function call ops, we add a 'SymbolicGradient'

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _MaybeCompile(scope, op, func, grad_fn)
    334       xla_scope = op.get_attr("_XlaScope").decode()
    335     except ValueError:
--> 336       return grad_fn()  # Exit early
    337 
    338   if not xla_compile:

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in <lambda>()
    667                 # functions.
    668                 in_grads = _MaybeCompile(grad_scope, op, func_call,
--> 669                                          lambda: grad_fn(op, *out_grads))
    670               else:
    671                 # For function call ops, we add a 'SymbolicGradient'

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _rewrite_forward_and_call_backward(self, op, *doutputs)
    705   def _rewrite_forward_and_call_backward(self, op, *doutputs):
    706     """Add outputs to the forward call and feed them to the grad function."""
--> 707     forward_function, backwards_function = self.forward_backward(len(doutputs))
    708     if not backwards_function.outputs:
    709       return backwards_function.structured_outputs

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in forward_backward(self, num_doutputs)
    614     if forward_backward is not None:
    615       return forward_backward
--> 616     forward, backward = self._construct_forward_backward(num_doutputs)
    617     self._cached_function_pairs[num_doutputs] = (forward, backward)
    618     return forward, backward

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _construct_forward_backward(self, num_doutputs)
    662           args=[], kwargs={},
    663           signature=signature,
--> 664           func_graph=backwards_graph)
    665       backwards_graph_captures = backwards_graph.external_captures
    666       captures_from_forward = [

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    976                                           converted_func)
    977 
--> 978       func_outputs = python_func(*func_args, **func_kwargs)
    979 
    980       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/eager/function.py in _backprop_function(*grad_ys)
    652             self._func_graph.inputs,
    653             grad_ys=grad_ys,
--> 654             src_graph=self._func_graph)
    655 
    656     with self._func_graph.as_default():

~/work/tf/lib/python3.7/site-packages/tensorflow_core/python/ops/gradients_util.py in _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops, gate_gradients, aggregation_method, stop_gradients, unconnected_gradients, src_graph)
    621               raise LookupError(
    622                   "No gradient defined for operation '%s' (op type: %s)" %
--> 623                   (op.name, op.type))
    624         if loop_state:
    625           loop_state.EnterGradWhileContext(op, before=False)

LookupError: No gradient defined for operation 'while' (op type: While) 

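我過去也遇到過這種情況。從錯誤追蹤可以看出，`load_model` 加載的 SavedModel 在 `training=True` 時調用的是保存時追蹤（trace）下來的函數，而在這個 TensorFlow 版本中，無法對其中 RNN 的 `While` 運算求梯度（`No gradient defined for operation 'while'`）。用 `model_from_json` 重建模型則會重新建立原始的 Keras 層，因此不會遇到這個問題。以下是我保存／加載 Keras 模型的方式。請注意，架構和權重是分開保存的：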

# Save and reload a Keras model by storing the architecture (JSON) and the
# weights (HDF5) in two separate files, then rebuilding the model from them.
from tensorflow.keras.models import model_from_json

ARCHITECTURE_PATH = 'architecture.json'
WEIGHTS_PATH = 'model_weights.h5'

# Saving model: the architecture goes to the JSON file, the weights to the .h5
# file. The original snippet had these two file names swapped, so the load step
# below would have parsed the weights as JSON.
with open(ARCHITECTURE_PATH, 'w') as f:  # `with` closes the handle reliably
    f.write(model.to_json())
model.save_weights(WEIGHTS_PATH, overwrite=True)

# Loading model: rebuild the layers from the JSON architecture, then restore
# the saved weights into the freshly built model.
with open(ARCHITECTURE_PATH) as f:
    model = model_from_json(f.read())
model.load_weights(WEIGHTS_PATH)

暫無
暫無

聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.

 
粵ICP備18138465號  © 2020-2024 STACKOOM.COM