How to save weights in tensorflow federated

I want to save the model weights only in rounds where the loss decreases, and then reuse the best weights for evaluation.

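lowest_loss = float('inf')  # Python has no bare `Inf`

for round_num in range(NUM_ROUNDS):  # NUM_ROUNDS: total number of training rounds
    # ... run one training round and record its loss in loss[round_num] ...
    if loss[round_num] < lowest_loss:
        lowest_loss = loss[round_num]
        # Keep the global model weights from the best round seen so far
        model_weights = transfer_learning_iterative_process.get_model_weights(state)

# Evaluate the best weights on the federated validation data
eval_metric = federated_eval(model_weights, [fed_valid_data])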

where:

  federated_eval = tff.learning.build_federated_evaluation(model_fn)

Is there a way to save the server weights in HDF5 format or as a checkpoint, and reuse them later?

Yes, this can be done with helpers in TFF. Saving and restoring program state (such as the server state that holds the model weights) is generally handled by tff.program.ProgramStateManager implementations. An implementation that saves to the filesystem can be found here, and example usages can be found in the implementation of tff.simulation.run_training_process.
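A minimal sketch of the "save only when the loss improves, then reuse" pattern follows. It does not use TFF: BestWeightsStore is a hypothetical, standard-library stand-in that loosely mirrors the versioned save / load-latest idea behind a file-based program state manager, and it pickles the weights instead of writing HDF5 or a TensorFlow checkpoint. It assumes the weights are picklable (for example, the structure returned by get_model_weights in a simulation) and that versions are increasing round numbers, so the newest saved version is always the best one. The losses and weights in the usage part are made up for illustration.

```python
import os
import pickle
import re
import tempfile


class BestWeightsStore:
    """Hypothetical stand-in: one pickle file per round in which the loss improved.

    Loosely mirrors the versioned save / load-latest idea behind TFF's
    program state managers, using only the Python standard library.
    """

    def __init__(self, root_dir, prefix="best_weights_"):
        self.root_dir = root_dir
        self.prefix = prefix
        os.makedirs(root_dir, exist_ok=True)
        # Resume the best loss from disk (if any), so a restarted run does not
        # save a worse checkpoint under a newer version number.
        _, loss, _ = self.load_latest()
        self.lowest_loss = float("inf") if loss is None else loss

    def _path(self, version):
        return os.path.join(self.root_dir, f"{self.prefix}{version}.pkl")

    def versions(self):
        """Sorted version numbers (e.g. round numbers) found in root_dir."""
        pattern = re.compile(re.escape(self.prefix) + r"(\d+)\.pkl$")
        matches = (pattern.match(name) for name in os.listdir(self.root_dir))
        return sorted(int(m.group(1)) for m in matches if m)

    def maybe_save(self, weights, loss, version):
        """Save (loss, weights) only if `loss` beats every loss saved so far."""
        if loss >= self.lowest_loss:
            return False
        self.lowest_loss = loss
        with open(self._path(version), "wb") as f:
            pickle.dump((loss, weights), f)
        return True

    def load_latest(self):
        """Return (weights, loss, version) of the newest save, i.e. the best one."""
        versions = self.versions()
        if not versions:
            return None, None, None
        latest = versions[-1]
        with open(self._path(latest), "rb") as f:
            loss, weights = pickle.load(f)
        return weights, loss, latest


# Usage with made-up per-round losses and stand-in weights.
losses = [0.9, 0.7, 0.8, 0.5, 0.6]
with tempfile.TemporaryDirectory() as tmp:
    store = BestWeightsStore(tmp)
    for round_num, loss in enumerate(losses):
        weights = [[float(round_num)]]  # stand-in for real model weights
        store.maybe_save(weights, loss, version=round_num)
    best_weights, best_loss, best_round = store.load_latest()
    # A fresh store over the same directory picks up the saved best loss.
    resumed_loss = BestWeightsStore(tmp).lowest_loss

print(best_round, best_loss, best_weights)  # 3 0.5 [[3.0]]
```

With TFF itself, the same idea would be to call the manager's save(program_state, version) only in rounds where the loss improves, and later restore with load_latest(structure); in recent TFF releases these ProgramStateManager methods are coroutines, so from synchronous code they are typically driven with asyncio.run.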
