Skip to content

Commit

Permalink
Merge pull request #100 from KarhouTam/dev
Browse files Browse the repository at this point in the history
Periodic update from `dev`
  • Loading branch information
KarhouTam authored Oct 6, 2024
2 parents d0fe5a3 + ddf06bb commit dc52369
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 4 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ FL-bench welcomes PR on everything that can make this project better.
- ***FedPAC*** -- [Personalized Federated Learning with Feature Alignment and Classifier Collaboration](https://arxiv.org/abs/2306.11867v1) (ICLR'23)
- ***PeFLL*** -- [PeFLL: Personalized Federated Learning by Learning to Learn](https://openreview.net/forum?id=MrYiwlDRQO) (ICLR'24)
- ***FLUTE*** -- [Federated Representation Learning in the Under-Parameterized Regime](https://openreview.net/forum?id=LIQYhV45D4) (ICML'24)
- ***FedAS*** -- [FedAS: Bridging Inconsistency in Personalized Federated Learning](https://openaccess.thecvf.com/content/CVPR2024/html/Yang_FedAS_Bridging_Inconsistency_in_Personalized_Federated_Learning_CVPR_2024_paper.html) (CVPR'24)
<!-- </details> -->


Expand Down
98 changes: 98 additions & 0 deletions src/client/fedas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from copy import deepcopy
from typing import Any

import torch
import torch.nn.functional as F

from src.client.fedavg import FedAvgClient
from src.utils.constants import NUM_CLASSES
from src.utils.models import DecoupledModel


class FedASClient(FedAvgClient):
    """Client-side logic of FedAS (CVPR'24).

    Before local training, aligns the received global feature extractor with
    class prototypes extracted by the previous personalized model; reports the
    sum of Fisher Information Matrix (FIM) traces as its aggregation weight.
    """

    def __init__(self, **commons):
        super().__init__(**commons)
        # Personalized model from the previous round this client participated in.
        self.prev_model: DecoupledModel = deepcopy(self.model)

    def get_fim_trace_sum(self) -> float:
        """Return the trace of the empirical FIM, summed over all batches.

        Approximated as the sum of squared gradients of the negative
        log-likelihood w.r.t. every model parameter. Used as this client's
        aggregation weight in `package()`.
        """
        self.model.eval()
        self.dataset.eval()

        fim_trace_sum = 0.0

        for x, y in self.trainloader:
            x, y = x.to(self.device), y.to(self.device)
            logits = self.model(x)
            # Batch-mean negative log-likelihood of the true labels.
            loss = (
                -F.log_softmax(logits, dim=1).gather(dim=1, index=y.unsqueeze(1)).mean()
            )

            self.model.zero_grad()
            loss.backward()

            # trace(FIM) contribution: squared gradient magnitudes.
            for param in self.model.parameters():
                if param.grad is not None:
                    fim_trace_sum += (param.grad.data**2).sum().item()

        return fim_trace_sum

    def package(self):
        client_package = super().package()
        # FedAS uses the sum of FIM traces as the aggregation weight.
        client_package["weight"] = self.get_fim_trace_sum()
        client_package["prev_model_state"] = deepcopy(self.model.state_dict())
        return client_package

    def set_parameters(self, package: dict[str, Any]) -> None:
        super().set_parameters(package)
        if package["prev_model_state"] is not None:
            self.prev_model.load_state_dict(package["prev_model_state"])
        else:
            # First participation: no history yet, start from the global model.
            self.prev_model.load_state_dict(self.model.state_dict())
        if not self.testing:
            self.align_federated_parameters()
        else:
            # FedAS evaluates clients' personalized models.
            self.model.load_state_dict(self.prev_model.state_dict())

    def align_federated_parameters(self):
        """Align `model.base` (the feature extractor) of the received global
        model with per-class feature prototypes computed by `prev_model`.
        """
        self.prev_model.eval()
        self.prev_model.to(self.device)
        self.model.train()
        self.dataset.train()

        prototypes = [[] for _ in range(NUM_CLASSES[self.args.dataset.name])]

        # Collect per-class features from the previous personalized model.
        with torch.no_grad():
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                features = self.prev_model.get_last_features(x)

                # `label` (not `y`) to avoid shadowing the batch label tensor.
                for label, feat in zip(y, features):
                    prototypes[label].append(feat)

        # `None` marks classes absent from this client's local data.
        mean_prototypes = [
            torch.stack(prototype).mean(dim=0) if prototype else None
            for prototype in prototypes
        ]

        alignment_optimizer = torch.optim.SGD(
            self.model.base.parameters(), lr=self.args.fedas.alignment_lr
        )

        for _ in range(self.args.fedas.alignment_epoch):
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                features = self.model.get_last_features(x, detach=False)
                loss = 0
                for label in y.unique().tolist():
                    if mean_prototypes[label] is not None:
                        loss += F.mse_loss(
                            features[y == label].mean(dim=0), mean_prototypes[label]
                        )

                # If no label in this batch has a prototype, `loss` is still the
                # int 0 and calling .backward() on it would crash — skip.
                if torch.is_tensor(loss):
                    alignment_optimizer.zero_grad()
                    loss.backward()
                    alignment_optimizer.step()

        self.prev_model.cpu()
49 changes: 49 additions & 0 deletions src/server/fedas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from argparse import ArgumentParser, Namespace
from typing import Any, Dict

from omegaconf import DictConfig

from src.client.fedas import FedASClient
from src.server.fedavg import FedAvgServer


class FedASServer(FedAvgServer):
    """Server-side logic of FedAS (CVPR'24).

    Stores each client's latest personalized model state after every round and
    hands it back to the client on its next participation.
    """

    @staticmethod
    def get_hyperparams(args_list=None) -> Namespace:
        """Parse FedAS-specific hyperparameters.

        --alignment_lr: learning rate of the client-side feature-alignment
            optimizer.
        --alignment_epoch: number of feature-alignment epochs per round.
        """
        parser = ArgumentParser()
        parser.add_argument("--alignment_lr", type=float, default=0.01)
        parser.add_argument("--alignment_epoch", type=int, default=1)
        return parser.parse_args(args_list)

    def __init__(
        self,
        args: DictConfig,
        algorithm_name: str = "FedAS",
        unique_model=False,
        use_fedavg_client_cls=False,
        return_diff=False,
    ):
        super().__init__(
            args, algorithm_name, unique_model, use_fedavg_client_cls, return_diff
        )
        # client_id -> state_dict of that client's previous personalized model.
        self.client_prev_model_states: Dict[int, Dict[str, Any]] = {}
        self.init_trainer(FedASClient)

    def train_one_round(self):
        """Run one communication round: train the selected clients, record
        their personalized model states, then aggregate."""
        client_packages = self.trainer.train()
        for client_id, package in client_packages.items():
            self.client_prev_model_states[client_id] = package["prev_model_state"]
        self.aggregate(client_packages)

    def package(self, client_id: int):
        """Attach the client's previous personalized model state to the
        standard server package (None on the client's first participation)."""
        server_package = super().package(client_id)
        # dict.get() yields None for first-time participants.
        server_package["prev_model_state"] = self.client_prev_model_states.get(
            client_id
        )
        return server_package
12 changes: 8 additions & 4 deletions src/server/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import os
import pickle
import random
import shutil
import traceback
import time
import warnings
from collections import OrderedDict
Expand Down Expand Up @@ -697,15 +699,17 @@ def run(self):
try:
self.train()
except KeyboardInterrupt:
# when user press Ctrl+C
# indicates that this run should be considered as useless and deleted.
# When the user manually terminates the run, FL-bench
# treats the run as useless and deletes its output directory.
self.logger.close()
del self.train_progress_bar
if self.args.common.delete_useless_run:
if os.path.isdir(self.output_dir):
os.removedirs(self.output_dir)
shutil.rmtree(self.output_dir)
return
except:
except Exception as e:
self.logger.log(traceback.format_exc())
self.logger.log(f"Exception occurred: {e}")
self.logger.close()
del self.train_progress_bar
raise
Expand Down

0 comments on commit dc52369

Please sign in to comment.