简体   繁体   English

在联合服务器上访问和修改从客户端发送的权重

[英]Access and modify weights sent from client on the server tensorflow federated

I'm using Tensorflow Federated, but i'm actually have some problem while trying to executes some operation on the server after reading the client update.我正在使用 Tensorflow Federated,但实际上在读取客户端更新后尝试在服务器上执行某些操作时遇到了一些问题。

This is the function这是功能

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

I want to print the client_outputs.weights_delta and doing some operation on the weights that the client sent to the server before using the tff.federated_mean but i don't get how to do so.我想打印client_outputs.weights_delta并在使用tff.federated_mean之前对客户端发送到服务器的权重进行一些操作,但我不知道该怎么做。

When i try to print i get this当我尝试打印时,我得到了这个

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

Any way to modify those elements?有什么办法可以修改这些元素吗?

I tried with using return client_outputs.weights_delta.comp doing the modification in the main (i can do that) and then i tried to invocate a new method for doing the rest of the operations for the server update, but the error is:我尝试使用return client_outputs.weights_delta.compreturn client_outputs.weights_delta.comp进行修改(我可以这样做),然后我尝试调用一种新方法来执行服务器更新的其余操作,但错误是:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean' where calculate_federated_mean was the name of the new function i created. AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean' ,其中calculate_federated_mean 是我创建的新函数的名称。

This is the main:这是主要的:

 for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

Based on the simple_fedavg project from github Tensorflow Federated simple_fedavg as basic project.基于 github Tensorflow Federated simple_fedavg作为基础项目的simple_fedavg项目。

I think this reply to your other question I just wrote applies here, too.我觉得这个答复到其他的问题,我只是写在这里也适用。

When you print client_outputs.weights_delta you get abstract representation fo a result of another computation, a primarily internal implementation detail of TFF.当您打印client_outputs.weights_delta您将获得另一个计算结果的抽象表示,这是 TFF 的主要内部实现细节。

Write a tff.tf_computation -decorated method with TensorFlow code, which does the modification you need, and then invoke it using tff.federated_map operator from where you are trying to print the values.使用 TensorFlow 代码编写一个tff.tf_computation装饰方法,该方法会执行您需要的修改,然后使用tff.federated_map运算符从您尝试打印值的位置调用它。

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM