

Where can I find the 'fit function'/model.fit in TensorFlow Federated?

I am using TFF (TensorFlow Federated). I have the following model, which is pruned layer-wise:

import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D

# pruning_params (a dict of pruning settings) is defined elsewhere; it is not shown here.
def model_net():
  model = tf.keras.Sequential([
      tfmot.sparsity.keras.prune_low_magnitude(Conv2D(6, 5, padding='same', activation='relu', input_shape=(28, 28, 1)), **pruning_params),
      MaxPooling2D((2, 2)),
      tfmot.sparsity.keras.prune_low_magnitude(Conv2D(3, 5, padding='same', activation='relu'), **pruning_params),
      MaxPooling2D((2, 2)),
      Flatten(),
      tfmot.sparsity.keras.prune_low_magnitude(Dense(2, activation='relu'), **pruning_params),
      tfmot.sparsity.keras.prune_low_magnitude(Dense(10, activation='softmax'), **pruning_params),
  ])
  return model

I have the following model function, which creates the global model for TFF:

def model_fn():
  # We _must_ create a new model here, and _not_ capture it from an external
  # scope. TFF will call this within different graph contexts.
  global_model = model_net()

  return tff.learning.from_keras_model(
      global_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  

The iterative process is built as follows:

iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
  model_fn,
  client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
  server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.00))
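For context on what this iterative process does each round: conceptually, weighted FedAvg trains a copy of the model on each client with the client optimizer, averages the resulting model deltas weighted by each client's number of examples (which, as far as I know, is the default weighting of build_weighted_fed_avg), and the server optimizer then applies that average delta as if it were a negative gradient. The plain-Python sketch below illustrates only that server-side arithmetic on flat parameter lists. weighted_mean and server_round are hypothetical helpers, not TFF functions, and real TFF operates on per-layer tensors.

```python
from typing import List, Sequence

Vector = List[float]


def weighted_mean(vectors: Sequence[Sequence[float]], weights: Sequence[float]) -> Vector:
    """Average vectors coordinate-wise, weighting each one by its client weight."""
    total = float(sum(weights))
    dim = len(vectors[0])
    return [sum(w * v[i] for v, w in zip(vectors, weights)) / total for i in range(dim)]


def server_round(global_model: Sequence[float],
                 client_models: Sequence[Sequence[float]],
                 num_examples: Sequence[int],
                 server_lr: float = 1.0) -> Vector:
    """One hypothetical weighted-FedAvg server update on flat parameter lists."""
    # Each client reports how far local training moved it away from the global model.
    deltas = [[c - g for c, g in zip(client, global_model)] for client in client_models]
    # Aggregate the deltas, weighting clients by how many examples they trained on.
    avg_delta = weighted_mean(deltas, num_examples)
    # Server SGD treats the negated average delta as a gradient:
    # w <- w - lr * (-avg_delta), i.e. w + lr * avg_delta.
    return [g + server_lr * d for g, d in zip(global_model, avg_delta)]


if __name__ == "__main__":
    # Two clients; the second one trained on three times as many examples.
    print(server_round([0.0, 0.0], [[1.0, 2.0], [3.0, 6.0]], [1, 3]))  # [2.5, 5.0]
```

With a server SGD learning rate of 1.00, as in the code above, the new global model is exactly the example-weighted average of the client models; a smaller server rate moves the global model only part of the way toward that average.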

When I run the training loop below, the following error shows up:

  state = iterative_process.initialize()

  Accuracy=[]
  loss=[]

  for round_num in range(1, 100):
    result = iterative_process.next(state, federated_train_data)
    state = result.state
    metrics = result.metrics['client_work']['train']['sparse_categorical_accuracy']
    Accuracy.append(metrics)
    los=result.metrics['client_work']['train']['loss']
    loss.append(los)
    #print('round {:2d}, metrics={}'.format(round_num, metrics))
  Accuracy=tf.stack(Accuracy)



  state = iterative_process.initialize()
  result = iterative_process.next(state, federated_train_data)
  state = result.state
  metrics = result.metrics
  #print('round  1, metrics={}'.format(metrics))

InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-115-930f71d2e953> in <module>
     13 
     14   for round_num in range(1, 100):
---> 15     result = iterative_process.next(state, federated_train_data)
     16     state = result.state
     17     metrics = result.metrics['client_work']['train']['sparse_categorical_accuracy']

395 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     53     ctx.ensure_initialized()
     54     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55                                         inputs, attrs, num_outputs)
     56   except core._NotOkStatusException as e:
     57     if name is not None:

InvalidArgumentError: Graph execution error:

assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (sequential/prune_low_magnitude_conv2d/assert_greater_equal/ReadVariableOp:0) = ] [0] [y (sequential/prune_low_magnitude_conv2d/assert_greater_equal/y:0) = ] [1]
     [[{{node Assert}}]]
     [[StatefulPartitionedCall/ReduceDataset]] [Op:__inference_pruned_439714]

Either I need to find the model.fit call (which I am unable to find) so that I can add the callback there, or I need a way to get rid of the error without going through model.fit. How can I do either?
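Some background on why the error appears: according to its message, each prune_low_magnitude wrapper keeps a pruning-step variable that the UpdatePruningStep callback is expected to advance during model.fit. As far as I understand, tff.learning.from_keras_model only uses the Keras model's variables and forward pass inside TFF's own client training loop (note the ReduceDataset frame in the traceback), so model.fit and its callbacks never run and the step stays at 0. The toy sketch below is plain Python with entirely hypothetical classes and functions (ToyPrunedLayer, ToyStepCallback, toy_fit, toy_custom_loop); none of them are tfmot or Keras APIs. It only illustrates the pattern: a callback-driven counter advances under a fit()-style loop, fails under a loop with no callback hooks, and works again if that loop advances the counter itself.

```python
from typing import List, Sequence


class ToyPrunedLayer:
    """Hypothetical stand-in for a pruning wrapper that owns a step counter."""

    def __init__(self) -> None:
        self.pruning_step = 0  # matches "x = 0" in the traceback

    def train_step(self) -> int:
        # Mirrors the failing check reported in the error: the step must be >= 1.
        if self.pruning_step < 1:
            raise AssertionError("step-update callback required during training")
        return self.pruning_step


class ToyStepCallback:
    """Advances the layer's counter before every batch, as a fit() callback would."""

    def __init__(self, layer: ToyPrunedLayer) -> None:
        self.layer = layer

    def on_train_batch_begin(self) -> None:
        self.layer.pruning_step += 1


def toy_fit(layer: ToyPrunedLayer, num_batches: int,
            callbacks: Sequence[ToyStepCallback]) -> List[int]:
    """fit()-style loop: it calls every callback hook before each batch."""
    steps = []
    for _ in range(num_batches):
        for cb in callbacks:
            cb.on_train_batch_begin()
        steps.append(layer.train_step())
    return steps


def toy_custom_loop(layer: ToyPrunedLayer, num_batches: int,
                    advance_step: bool = False) -> List[int]:
    """Hand-written loop with no callback hooks, like a framework's own training loop."""
    steps = []
    for _ in range(num_batches):
        if advance_step:
            layer.pruning_step += 1  # do the callback's job explicitly
        steps.append(layer.train_step())
    return steps


if __name__ == "__main__":
    layer = ToyPrunedLayer()
    print(toy_fit(layer, 3, [ToyStepCallback(layer)]))  # [1, 2, 3]
    print(toy_custom_loop(ToyPrunedLayer(), 3, advance_step=True))  # [1, 2, 3]
    try:
        toy_custom_loop(ToyPrunedLayer(), 3)
    except AssertionError as err:
        print("fails without the callback:", err)
```

By analogy, one possible direction (an assumption, not verified here) is a customized client training step in TFF that updates each pruning wrapper's step variable before every batch, rather than relying on model.fit.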
