diff --git a/README.md b/README.md index 415bcc1..0b7024b 100644 --- a/README.md +++ b/README.md @@ -83,6 +83,7 @@ FL-bench welcomes PR on everything that can make this project better. - ***FedPAC*** -- [Personalized Federated Learning with Feature Alignment and Classifier Collaboration](https://arxiv.org/abs/2306.11867v1) (ICLR'23) - ***PeFLL*** -- [PeFLL: Personalized Federated Learning by Learning to Learn](https://openreview.net/forum?id=MrYiwlDRQO) (ICLR'24) - ***FLUTE*** -- [Federated Representation Learning in the Under-Parameterized Regime](https://openreview.net/forum?id=LIQYhV45D4) (ICML'24) +- ***FedAS*** -- [FedAS: Bridging Inconsistency in Personalized Federated Learning](https://openaccess.thecvf.com/content/CVPR2024/html/Yang_FedAS_Bridging_Inconsistency_in_Personalized_Federated_Learning_CVPR_2024_paper.html) (CVPR'24) diff --git a/src/client/fedas.py b/src/client/fedas.py new file mode 100644 index 0000000..e45b522 --- /dev/null +++ b/src/client/fedas.py @@ -0,0 +1,98 @@ +from copy import deepcopy +from typing import Any + +import torch +import torch.nn.functional as F + +from src.client.fedavg import FedAvgClient +from src.utils.constants import NUM_CLASSES +from src.utils.models import DecoupledModel + + +class FedASClient(FedAvgClient): + def __init__(self, **commons): + super().__init__(**commons) + self.prev_model: DecoupledModel = deepcopy(self.model) + + def get_fim_trace_sum(self) -> float: + self.model.eval() + self.dataset.eval() + + fim_trace_sum = 0 + + for x, y in self.trainloader: + x, y = x.to(self.device), y.to(self.device) + logits = self.model(x) + loss = ( + -F.log_softmax(logits, dim=1).gather(dim=1, index=y.unsqueeze(1)).mean() + ) + + self.model.zero_grad() + loss.backward() + + for param in self.model.parameters(): + if param.grad is not None: + fim_trace_sum += (param.grad.data**2).sum().item() + + return fim_trace_sum + + def package(self): + client_package = super().package() + # FedAS uses the sum of FIM traces as the 
weight + client_package["weight"] = self.get_fim_trace_sum() + client_package["prev_model_state"] = deepcopy(self.model.state_dict()) + return client_package + + def set_parameters(self, package: dict[str, Any]) -> None: + super().set_parameters(package) + if package["prev_model_state"] is not None: + self.prev_model.load_state_dict(package["prev_model_state"]) + else: + self.prev_model.load_state_dict(self.model.state_dict()) + if not self.testing: + self.align_federated_parameters() + else: + # FedAS evaluates clients' personalized models + self.model.load_state_dict(self.prev_model.state_dict()) + + def align_federated_parameters(self): + self.prev_model.eval() + self.prev_model.to(self.device) + self.model.train() + self.dataset.train() + + prototypes = [[] for _ in range(NUM_CLASSES[self.args.dataset.name])] + + with torch.no_grad(): + for x, y in self.trainloader: + x, y = x.to(self.device), y.to(self.device) + features = self.prev_model.get_last_features(x) + + for y, feat in zip(y, features): + prototypes[y].append(feat) + + mean_prototypes = [ + torch.stack(prototype).mean(dim=0) if prototype else None + for prototype in prototypes + ] + + alignment_optimizer = torch.optim.SGD( + self.model.base.parameters(), lr=self.args.fedas.alignment_lr + ) + + for _ in range(self.args.fedas.alignment_epoch): + for x, y in self.trainloader: + x, y = x.to(self.device), y.to(self.device) + features = self.model.get_last_features(x, detach=False) + loss = 0 + for label in y.unique().tolist(): + if mean_prototypes[label] is not None: + loss += F.mse_loss( + features[y == label].mean(dim=0), mean_prototypes[label] + ) + + alignment_optimizer.zero_grad() + loss.backward() + alignment_optimizer.step() + + self.prev_model.cpu() diff --git a/src/server/fedas.py b/src/server/fedas.py new file mode 100644 index 0000000..c4b1305 --- /dev/null +++ b/src/server/fedas.py @@ -0,0 +1,49 @@ +from argparse import ArgumentParser, Namespace +from typing import Any, Dict + +from omegaconf 
import DictConfig + +from src.client.fedas import FedASClient +from src.server.fedavg import FedAvgServer + + +class FedASServer(FedAvgServer): + @staticmethod + def get_hyperparams(args_list=None) -> Namespace: + parser = ArgumentParser() + parser.add_argument("--alignment_lr", type=float, default=0.01) + parser.add_argument("--alignment_epoch", type=int, default=1) + return parser.parse_args(args_list) + + def __init__( + self, + args: DictConfig, + algorithm_name: str = "FedAS", + unique_model=False, + use_fedavg_client_cls=False, + return_diff=False, + ): + super().__init__( + args, algorithm_name, unique_model, use_fedavg_client_cls, return_diff + ) + self.client_prev_model_states: Dict[int, Dict[str, Any]] = {} + self.init_trainer(FedASClient) + + def train_one_round(self): + """The function of indicating specific things FL method need to do (at + server side) in each communication round.""" + + client_packages = self.trainer.train() + for client_id, package in client_packages.items(): + self.client_prev_model_states[client_id] = package["prev_model_state"] + self.aggregate(client_packages) + + def package(self, client_id: int): + server_package = super().package(client_id) + if client_id in self.client_prev_model_states: + server_package["prev_model_state"] = self.client_prev_model_states[ + client_id + ] + else: + server_package["prev_model_state"] = None + return server_package diff --git a/src/server/fedavg.py b/src/server/fedavg.py index 657665a..66130ec 100644 --- a/src/server/fedavg.py +++ b/src/server/fedavg.py @@ -4,6 +4,8 @@ import os import pickle import random +import shutil +import traceback import time import warnings from collections import OrderedDict @@ -697,15 +699,17 @@ def run(self): try: self.train() except KeyboardInterrupt: - # when user press Ctrl+C - # indicates that this run should be considered as useless and deleted. 
+ # when the user manually terminates the run, FL-bench + # indicates that the run should be considered as useless and deleted. self.logger.close() del self.train_progress_bar if self.args.common.delete_useless_run: if os.path.isdir(self.output_dir): - os.removedirs(self.output_dir) + shutil.rmtree(self.output_dir) return - except: + except Exception as e: + self.logger.log(traceback.format_exc()) + self.logger.log(f"Exception occurred: {e}") self.logger.close() del self.train_progress_bar raise