[英]Access and modify weights sent from client on the server tensorflow federated
我正在使用 Tensorflow Federated,但實際上在讀取客戶端更新后嘗試在服務器上執行某些操作時遇到了一些問題。
這是功能
@tff.federated_computation(federated_server_state_type,
federated_dataset_type)
def run_one_round(server_state, federated_dataset):
"""Orchestration logic for one round of computation.
Args:
server_state: A `ServerState`.
federated_dataset: A federated `tf.data.Dataset` with placement
`tff.CLIENTS`.
Returns:
A tuple of updated `ServerState` and `tf.Tensor` of average loss.
"""
tf.print("run_one_round")
server_message = tff.federated_map(server_message_fn, server_state)
server_message_at_client = tff.federated_broadcast(server_message)
client_outputs = tff.federated_map(
client_update_fn, (federated_dataset, server_message_at_client))
weight_denom = client_outputs.client_weight
tf.print(client_outputs.weights_delta)
round_model_delta = tff.federated_mean(
client_outputs.weights_delta, weight=weight_denom)
server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)
return server_state, round_loss_metric, client_outputs.weights_delta.comp
我想打印client_outputs.weights_delta
並在使用tff.federated_mean
之前對客戶端發送到服務器的權重進行一些操作,但我不知道該怎么做。
當我嘗試打印時,我得到了這個
Call(Intrinsic('federated_map', FunctionType(StructType([FunctionType(StructType([('weights_delta', StructType([TensorType(tf.float32, [5, 5, 1, 32]), TensorType(tf.float32, [32]), ....]) as ClientOutput, PlacementLiteral('clients'), False)))]))
有什么辦法可以修改這些元素嗎?
我嘗試使用return client_outputs.weights_delta.comp
在return client_outputs.weights_delta.comp
進行修改(我可以這樣做),然后我嘗試調用一種新方法來執行服務器更新的其余操作,但錯誤是:
AttributeError: 'IterativeProcess' object has no attribute 'calculate_federated_mean'
,其中calculate_federated_mean 是我創建的新函數的名稱。
這是主要的:
for round_num in range(FLAGS.total_rounds):
print("--------------------------------------------------------")
sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]
server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)
print(f'Round {round_num}')
print(f'\tTraining loss: {train_metrics:.4f}')
if round_num % FLAGS.rounds_per_eval == 0:
server_state.model_weights.assign_weights_to(keras_model)
accuracy = evaluate(keras_model, test_data)
print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))
基於 github Tensorflow Federated simple_fedavg作為基礎項目的simple_fedavg項目。
我覺得這個答復到其他的問題,我只是寫在這里也適用。
當您打印client_outputs.weights_delta
您將獲得另一個計算結果的抽象表示,這是 TFF 的主要內部實現細節。
使用 TensorFlow 代碼編寫一個tff.tf_computation
裝飾方法,該方法會執行您需要的修改,然后使用tff.federated_map
運算符從您嘗試打印值的位置調用它。
聲明:本站的技術帖子網頁,遵循CC BY-SA 4.0協議,如果您需要轉載,請注明本站網址或者原文地址。任何問題請咨詢:yoyou2525@163.com.