Where can I find the 'fit function/model.fit' in TensorFlow Federated?
I am using TFF (TensorFlow Federated). I have the following model, which is pruned layer-wise:
import tensorflow as tf
import tensorflow_model_optimization as tfmot
from tensorflow.keras.layers import Conv2D, Dense, Flatten, MaxPooling2D

def model_net():
    # pruning_params is defined elsewhere (not shown).
    model = tf.keras.Sequential([
        tfmot.sparsity.keras.prune_low_magnitude(Conv2D(6, 5, padding='same', activation='relu', input_shape=(28, 28, 1)), **pruning_params),
        MaxPooling2D((2, 2)),
        tfmot.sparsity.keras.prune_low_magnitude(Conv2D(3, 5, padding='same', activation='relu'), **pruning_params),
        MaxPooling2D((2, 2)),
        Flatten(),
        tfmot.sparsity.keras.prune_low_magnitude(Dense(2, activation='relu'), **pruning_params),
        tfmot.sparsity.keras.prune_low_magnitude(Dense(10, activation='softmax'), **pruning_params),
    ])
    return model
I have a model function that creates the global model for TFF:
def model_fn():
    # We _must_ create a new model here, and _not_ capture it from an external
    # scope. TFF will call this within different graph contexts.
    global_model = model_net()
    return tff.learning.from_keras_model(
        global_model,
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
The iterative process is built as follows:
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.00))
During training, the following error shows up:
state = iterative_process.initialize()
Accuracy = []
loss = []
for round_num in range(1, 100):
    result = iterative_process.next(state, federated_train_data)
    state = result.state
    metrics = result.metrics['client_work']['train']['sparse_categorical_accuracy']
    Accuracy.append(metrics)
    los = result.metrics['client_work']['train']['loss']
    loss.append(los)
    # print('round {:2d}, metrics={}'.format(round_num, metrics))
Accuracy = tf.stack(Accuracy)

state = iterative_process.initialize()
result = iterative_process.next(state, federated_train_data)
state = result.state
metrics = result.metrics
# print('round 1, metrics={}'.format(metrics))
InvalidArgumentError Traceback (most recent call last)
<ipython-input-115-930f71d2e953> in <module>
13
14 for round_num in range(1, 100):
---> 15 result = iterative_process.next(state, federated_train_data)
16 state = result.state
17 metrics = result.metrics['client_work']['train']['sparse_categorical_accuracy']
395 frames
/usr/local/lib/python3.7/dist-packages/tensorflow/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
53 ctx.ensure_initialized()
54 tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
---> 55 inputs, attrs, num_outputs)
56 except core._NotOkStatusException as e:
57 if name is not None:
InvalidArgumentError: Graph execution error:
assertion failed: [Prune() wrapper requires the UpdatePruningStep callback to be provided during training. Please add it as a callback to your model.fit call.] [Condition x >= y did not hold element-wise:] [x (sequential/prune_low_magnitude_conv2d/assert_greater_equal/ReadVariableOp:0) = ] [0] [y (sequential/prune_low_magnitude_conv2d/assert_greater_equal/y:0) = ] [1]
[[{{node Assert}}]]
[[StatefulPartitionedCall/ReduceDataset]] [Op:__inference_pruned_439714]
Either I have to find the model.fit function (which I am unable to find) in order to add the callback, or I need some other way to get rid of the error. How can I resolve this without going through model.fit?
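For reference, my reading of the assertion is roughly the following plain-Python sketch. It is only a toy model of the behaviour the error message describes (a step counter that reads 0 and is required to be at least 1), not tfmot or TFF code. `ToyPrunedLayer`, `ToyUpdateStep` and `run_batches` are hypothetical names made up for illustration, and the real wrapper's internals may well differ.

```python
class ToyPrunedLayer:
    """Hypothetical stand-in for a pruning wrapper; NOT the tfmot implementation."""

    def __init__(self):
        # The counter starts at 0, matching "x = [0]" in the error message.
        self.pruning_step = 0

    def call(self, x):
        # Mirrors the failed condition "x >= y" with y = 1 from the error message.
        if self.pruning_step < 1:
            raise AssertionError(
                "pruning step was never advanced; a step-update callback is required")
        return x


class ToyUpdateStep:
    """Hypothetical callback that advances the counter before each batch."""

    def __init__(self, layer):
        self.layer = layer

    def on_train_batch_begin(self):
        self.layer.pruning_step += 1


def run_batches(layer, batches, callback=None):
    """Bare training loop that optionally invokes the callback once per batch."""
    outputs = []
    for batch in batches:
        if callback is not None:
            callback.on_train_batch_begin()
        outputs.append(layer.call(batch))
    return outputs
```

Calling `run_batches(ToyPrunedLayer(), [1, 2])` without a callback raises the assertion, which is the situation I seem to be in: the TFF iterative process runs the training loop itself, so nothing advances the step. Passing a `ToyUpdateStep` bound to the same layer lets the loop run.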