Tensorflow Federated object is not subscriptable

I have a run_one_round function like this:

def run_one_round(server_state, federated_dataset):
    """Orchestration logic for one round of computation.
    Args:
      server_state: A `ServerState`.
      federated_dataset: A federated `tf.data.Dataset` with placement
        `tff.CLIENTS`.
    Returns:
      A tuple of updated `ServerState` and `tf.Tensor` of average loss.
    """
    tf.print("run_one_round")
    server_message = tff.federated_map(server_message_fn, server_state)
    server_message_at_client = tff.federated_broadcast(server_message)

    client_outputs = tff.federated_map(
        client_update_fn, (federated_dataset, server_message_at_client))

    weight_denom = client_outputs.client_weight

    from tensorflow_federated.python.core.impl.federated_context import value_impl
    value = value_impl.to_value(client_outputs.test, None)
    from tensorflow_federated.python.core.impl.types import placements
    from tensorflow_federated.python.core.impl.federated_context import value_utils
    value = value_utils.ensure_federated_value(value, placements.CLIENTS,
                                               'value to be averaged')

    value_comp = value.comp
    testing = []
    import sparse_ternary_compression
    for index in range(len(value_comp[0])):
        testing.append(
            sparse_ternary_compression.stc_decompression(value_comp[0][index][0], value_comp[0][index][1],
                                                         value_comp[0][index][2], value_comp[0][index][3],
                                                         value_comp[0][index][4]))

    # round_model_delta holds the weights used by server_update, so this is what needs to change.
    round_model_delta = tff.federated_mean(
        client_outputs.weights_delta, weight=weight_denom)

    server_state = tff.federated_map(server_update_fn, (server_state, round_model_delta))
    round_loss_metric = tff.federated_mean(client_outputs.model_output, weight=weight_denom)

    return server_state, round_loss_metric, value.comp

But when I try to run this part inside the function:

value_comp = value.comp
testing = []
import sparse_ternary_compression
for index in range(len(value_comp[0])):
    testing.append(
        sparse_ternary_compression.stc_decompression(value_comp[0][index][0], value_comp[0][index][1],
                                                     value_comp[0][index][2], value_comp[0][index][3],
                                                     value_comp[0][index][4]))

I get this error:

File "/mnt/d/Davide/Uni/TesiMagistrale/ProgettoTesi/simple_fedavg_tff.py", line 137, in run_one_round
    for index in range(len(value_comp[0])):
TypeError: 'Call' object is not subscriptable

However, if I return value.comp and then do the same operations in main, it works fine:

    for round_num in range(FLAGS.total_rounds):
        print("--------------------------------------------------------")
        sampled_clients = np.random.choice(train_data.client_ids, size=FLAGS.train_clients_per_round, replace=False)
        sampled_train_data = [train_data.create_tf_dataset_for_client(client) for client in sampled_clients]

        server_state, train_metrics, value_comp = iterative_process.next(server_state, sampled_train_data)

        testing = []
        import sparse_ternary_compression
        for index in range(len(value_comp[0])):
            testing.append(sparse_ternary_compression.stc_decompression(value_comp[0][index][0], value_comp[0][index][1],
                                                                   value_comp[0][index][2], value_comp[0][index][3],
                                                                   value_comp[0][index][4]))
        print(testing)
        print(f'Round {round_num}')
        print(f'\tTraining loss: {train_metrics:.4f}')
        if round_num % FLAGS.rounds_per_eval == 0:
            server_state.model_weights.assign_weights_to(keras_model)
            accuracy = evaluate(keras_model, test_data)
            print(f'\tValidation accuracy: {accuracy * 100.0:.2f}%')
            tf.print(tf.compat.v2.summary.scalar("Accuracy", accuracy * 100.0, step=round_num))

The code is the same, so why can't I use the for loop inside the run_one_round function?

Basically, I just want to access the test variable that each client sends from client_update and do some operations on that list before calling tff.federated_mean.

Maybe the problem is that run_one_round is a tff.federated_computation?

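Maybe try unstacking value_comp:

import tensorflow as tf
import numpy as np
import sparse_ternary_compression  # the asker's own module, not a published package

# Dummy tensor standing in for value_comp.
value_comp = tf.constant(np.random.random((1, 8, 5)))
# tf.unstack splits along axis 0 into a Python list of tensors.
value_comp = tf.unstack(value_comp)
testing = []
for index in value_comp:
    testing.append(
        sparse_ternary_compression.stc_decompression(index[0], index[1],
                                                     index[2], index[3],
                                                     index[4]))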

It may be helpful to go through the Building Your Own Federated Learning Algorithm tutorial first.

I think the main thing to keep in mind is that any TensorFlow code, and any structure manipulation, should be inside a tff.tf_computation-decorated method. Such building blocks are then connected using the tff.federated_* operators inside the scope of a tff.federated_computation-decorated method.

I assume that the stc_decompression in your code snippet is some kind of TensorFlow logic. What you are passing into it, however, are not any TF-understandable values, but abstract representations of results of computations which are primarily internal implementation details of TFF.

So, whatever you want to do with those values, do it in a tff.tf_computation-decorated method, inside of which you can write any TF code. You get your value into it by invoking the method with the tff.federated_map operator.
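To make the distinction concrete, here is a minimal pure-Python sketch of the idea. It is a toy analogy, not TFF's actual implementation: the names Call, Placeholder, toy_map, execute, and first_element are all hypothetical and chosen only for illustration. It shows why indexing a symbolic node fails while the graph is being built, and why moving the indexing into a function that is handed to a map operator works once real data arrives.

```python
class Call:
    """Toy symbolic node: records 'apply fn to arg' but holds no data."""

    def __init__(self, fn, arg):
        self.fn = fn
        self.arg = arg


class Placeholder:
    """Toy symbolic input whose concrete value is only known at execution."""


def toy_map(fn, arg):
    # Analogue of a map operator: builds a node instead of running fn.
    return Call(fn, arg)


def execute(node, concrete_input):
    # Toy runtime: walks the recorded graph with real per-client data.
    if isinstance(node, Placeholder):
        return concrete_input
    return [node.fn(item) for item in execute(node.arg, concrete_input)]


def first_element(item):
    # Structure manipulation lives inside the mapped function,
    # where `item` is a concrete value.
    return item[0]


def build_graph():
    client_outputs = toy_map(list, Placeholder())
    try:
        client_outputs[0]  # same kind of mistake as value_comp[0]
        error = None
    except TypeError as exc:
        error = str(exc)
    # The fix: pass the logic to the map operator instead of indexing here.
    return toy_map(first_element, client_outputs), error


graph, tracing_error = build_graph()
result = execute(graph, [(1, 2), (3, 4)])
print(tracing_error)
print(result)
```

In this analogy, build_graph plays the role of the tff.federated_computation body, first_element plays the role of a tff.tf_computation, and execute stands in for iterative_process.next, which is why the same indexing works in main on the returned values.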
