diff --git a/src/main.py b/src/main.py index c2265ec..d6f004d 100755 --- a/src/main.py +++ b/src/main.py @@ -212,12 +212,11 @@ else: for key in v.keys(): v[key] = update_avg[key] + v[key] * args.server_momentum - #new_weights = deepcopy(model.state_dict()) - #for key in new_weights.keys(): - # new_weights[key] = new_weights[key] - v[key] * args.server_lr - #model.load_state_dict(new_weights) - for key in model.state_dict(): - model.state_dict()[key] -= v[key] * args.server_lr + + for key in model.state_dict(): + last_dot_index = key.rfind('.') + if key[last_dot_index + 1:] != "num_batches_tracked": + model.state_dict()[key] -= v[key] * args.server_lr # Compute round average loss and accuracies if round % args.server_stats_every == 0: