
TypeError: Missing required positional argument when building a TFF iterative process

I am trying to implement federated learning with an LSTM model in TensorFlow Federated.

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM

def create_keras_model():
    model = Sequential()
    model.add(LSTM(32, input_shape=(3, 1)))
    return model

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
      input_spec=(look_back, 1),

But I get this error when I try to define iterative_process:

iterative_process = tff.learning.build_federated_averaging_process(
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

TypeError: Missing required positional argument

How do I fix it?

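tff.learning.build_federated_averaging_process takes model_fn as its required first positional argument, and your call passes only the two optimizer functions, which is what raises the TypeError. Separately, the input_spec you pass to tff.learning.from_keras_model (together with the Keras model itself and a loss) should describe the client training data, as a structure of tf.TensorSpec matching one batch, instead of the (look_back, 1) tuple.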
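To see why Python reports a missing positional argument, here is a minimal pure-Python sketch. build_process, make_model and the string placeholders are hypothetical stand-ins, not TFF APIs; the only thing carried over is the assumption that the real builder's first parameter, model_fn, has no default value.

```python
# Hypothetical stand-in for a builder whose first parameter is required,
# mirroring how build_federated_averaging_process takes model_fn first.
def build_process(model_fn, client_optimizer_fn, server_optimizer_fn=None):
    """Return a tuple describing the process (illustration only)."""
    return (model_fn(), client_optimizer_fn(), server_optimizer_fn)


def make_model():
    # Placeholder for a real model_fn that would build a TFF model.
    return "model"


# Calling with only the optimizer arguments omits model_fn entirely.
try:
    build_process(client_optimizer_fn=lambda: "sgd",
                  server_optimizer_fn=lambda: "sgd")
    error_message = None
except TypeError as err:
    error_message = str(err)

print(error_message)
# build_process() missing 1 required positional argument: 'model_fn'

# Passing model_fn as the first positional argument fixes the call.
process = build_process(make_model, client_optimizer_fn=lambda: "sgd")
```

The fix in the real code is the same: pass model_fn first, as the sample below does.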

The same pattern works for other input shapes: only the model's input layer and the input_spec need to match the client data.
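For the LSTM in the question (input_shape=(3, 1), i.e. look_back = 3), each client dataset would need to yield windows of shape (look_back, 1). The sketch below builds such windows from a 1-D series with NumPy; make_windows is a hypothetical helper, and predicting the next value of the series is an assumed target, since the question does not say what the model predicts.

```python
import numpy as np


def make_windows(series, look_back):
    """Split a 1-D series into (samples, look_back, 1) inputs and next-step targets.

    Assumes len(series) > look_back.
    """
    series = np.asarray(series, dtype=np.float32)
    n = len(series) - look_back
    # Each input window is look_back consecutive values.
    x = np.stack([series[i:i + look_back] for i in range(n)])
    # The target for each window is the value that follows it.
    y = series[look_back:]
    # Add a trailing feature axis so each window has shape (look_back, 1).
    return x[..., np.newaxis], y


x, y = make_windows(np.arange(10), look_back=3)
print(x.shape, y.shape)  # (7, 3, 1) (7,)
```

Wrapping these arrays as tf.data.Dataset.from_tensor_slices((x, y)).batch(...) should give a dataset whose element_spec is (TensorSpec(shape=(None, 3, 1), dtype=tf.float32), TensorSpec(shape=(None,), dtype=tf.float32)), which is the kind of structure input_spec expects.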

[ Sample ]:

import tensorflow as tf
import tensorflow_federated as tff

# Load simulation data.
source, _ = tff.simulation.datasets.emnist.load_data()

def client_data(n):
    # Flatten each 28x28 image to 784 values, then batch the examples.
    return source.create_tf_dataset_for_client(source.client_ids[n]).map(
        lambda e: (tf.reshape(e['pixels'], [-1]), e['label'])
    ).batch(20)

train_data = [client_data(n) for n in range(3)]

def create_keras_model():
    model = tf.keras.models.Sequential()
    model.add(tf.keras.layers.InputLayer(input_shape=(784,)))
    # Treat the 784 pixels as a sequence of 784 steps with 1 feature each.
    model.add(tf.keras.layers.Reshape((784, 1)))
    model.add(tf.keras.layers.LSTM(32))
    model.add(tf.keras.layers.Dense(1))
    return model

def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        loss=tf.keras.losses.MeanSquaredError(),
        # input_spec describes one batch of client data, not (look_back, 1).
        input_spec=train_data[0].element_spec,
        metrics=[tf.keras.metrics.MeanSquaredError(), tf.keras.metrics.Accuracy()])

# Simulate a few rounds of training with the selected client devices.
# model_fn is the required first positional argument.
trainer = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.1),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(1.0))

state = trainer.initialize()
for _ in range(50):
    state, metrics = trainer.next(state, train_data)
    print(metrics)

[ Output ]:

OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('mean_squared_error', 8.502816), ('accuracy', 0.0), ('loss', 8.5030365)]))])
OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('mean_squared_error', 8.500688), ('accuracy', 0.0), ('loss', 8.500914)]))])
OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('value_sum_process', ()), ('weight_sum_process', ())])), ('train', OrderedDict([('mean_squared_error', 8.498711), ('accuracy', 0.0), ('loss', 8.498943)]))])

