繁体   English   中英

在联合服务器上访问和修改从客户端发送的权重

[英]Access and modify weights sent from client on the server tensorflow federated

我正在使用 Tensorflow Federated,但实际上在读取客户端更新后尝试在服务器上执行某些操作时遇到了一些问题。

这是功能

@tff.federated_computation(federated_server_state_type,
                           federated_dataset_type)
def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight


    tf.print(client_outputs.weights_delta)
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, client_outputs.weights_delta.comp

我想打印client_outputs.weights_delta并在使用tff.federated_mean之前对客户端发送到服务器的权重进行一些操作,但我不知道该怎么做。

当我尝试打印时,我得到了这个

Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))

有什么办法可以修改这些元素吗?

我尝试使用return client_outputs.weights_delta.compreturn client_outputs.weights_delta.comp进行修改(我可以这样做),然后我尝试调用一种新方法来执行服务器更新的其余操作,但错误是:

AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean' ,其中calculate_federated_mean 是我创建的新函数的名称。

这是主要的:

 for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

基于 github Tensorflow Federated simple_fedavg作为基础项目的simple_fedavg项目。

我觉得这个答复到其他的问题,我只是写在这里也适用。

当您打印client_outputs.weights_delta您将获得另一个计算结果的抽象表示,这是 TFF 的主要内部实现细节。

使用 TensorFlow 代码编写一个tff.tf_computation装饰方法,该方法会执行您需要的修改,然后使用tff.federated_map运算符从您尝试打印值的位置调用它。

暂无
暂无

声明:本站的技术帖子网页,遵循CC BY-SA 4.0协议,如果您需要转载,请注明本站网址或者原文地址。任何问题请咨询:yoyou2525@163.com.

 
粤ICP备18138465号  © 2020-2024 STACKOOM.COM