-
-
Notifications
You must be signed in to change notification settings - Fork 74
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #100 from KarhouTam/dev
Periodic update from `dev`
- Loading branch information
Showing
4 changed files
with
156 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from copy import deepcopy | ||
from typing import Any | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
|
||
from src.client.fedavg import FedAvgClient | ||
from src.utils.constants import NUM_CLASSES | ||
from src.utils.models import DecoupledModel | ||
|
||
|
||
class FedASClient(FedAvgClient):
    """Client of FedAS.

    FedAS keeps a personalized model per client (``prev_model``), weighs each
    client's update by the trace of its empirical Fisher Information Matrix,
    and aligns the freshly received global feature extractor to class
    prototypes computed from the previous personalized model.
    """

    def __init__(self, **commons):
        super().__init__(**commons)
        # Personalized model from the previous round this client participated
        # in; refreshed from the server package in `set_parameters()`.
        self.prev_model: DecoupledModel = deepcopy(self.model)

    def get_fim_trace_sum(self) -> float:
        """Return the sum over parameters of the empirical FIM trace.

        For each batch, computes the NLL of the true labels, backprops, and
        accumulates the sum of squared gradients over all parameters.
        """
        self.model.eval()
        self.dataset.eval()

        fim_trace_sum = 0.0

        for x, y in self.trainloader:
            x, y = x.to(self.device), y.to(self.device)
            logits = self.model(x)
            # Negative log-likelihood of the true labels (equivalent to
            # cross-entropy), averaged over the batch.
            loss = (
                -F.log_softmax(logits, dim=1).gather(dim=1, index=y.unsqueeze(1)).mean()
            )

            self.model.zero_grad()
            loss.backward()

            # Trace of the empirical FIM == sum of squared gradients.
            for param in self.model.parameters():
                if param.grad is not None:
                    fim_trace_sum += (param.grad.data**2).sum().item()

        return fim_trace_sum

    def package(self):
        """Build the package uploaded to the server.

        Extends the FedAvg package with the FIM-trace aggregation weight and
        a snapshot of this client's (now personalized) model state.
        """
        client_package = super().package()
        # FedAS uses the sum of FIM traces as the aggregation weight.
        client_package["weight"] = self.get_fim_trace_sum()
        client_package["prev_model_state"] = deepcopy(self.model.state_dict())
        return client_package

    def set_parameters(self, package: dict[str, Any]) -> None:
        """Load global parameters and restore/align the personalized model."""
        super().set_parameters(package)
        if package["prev_model_state"] is not None:
            self.prev_model.load_state_dict(package["prev_model_state"])
        else:
            # First participation: no personalized state exists yet, so fall
            # back to the freshly received global model.
            self.prev_model.load_state_dict(self.model.state_dict())
        if not self.testing:
            self.align_federated_parameters()
        else:
            # FedAS evaluates clients' personalized models.
            self.model.load_state_dict(self.prev_model.state_dict())

    def align_federated_parameters(self):
        """Align the global feature extractor to prototypes of `prev_model`.

        Class prototypes (mean last-layer features) are computed with the
        previous personalized model, then the base (feature extractor) of the
        current model is trained to match per-class feature means to those
        prototypes via MSE.
        """
        self.prev_model.eval()
        self.prev_model.to(self.device)
        self.model.train()
        self.dataset.train()

        prototypes = [[] for _ in range(NUM_CLASSES[self.args.dataset.name])]

        with torch.no_grad():
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                features = self.prev_model.get_last_features(x)

                # BUGFIX: the original inner loop (`for y, feat in zip(y, ...)`)
                # shadowed the batch-label tensor `y`; use a distinct name.
                for label, feat in zip(y.tolist(), features):
                    prototypes[label].append(feat)

        # Classes absent from this client's data get no prototype (None).
        mean_prototypes = [
            torch.stack(prototype).mean(dim=0) if prototype else None
            for prototype in prototypes
        ]

        alignment_optimizer = torch.optim.SGD(
            self.model.base.parameters(), lr=self.args.fedas.alignment_lr
        )

        for _ in range(self.args.fedas.alignment_epoch):
            for x, y in self.trainloader:
                x, y = x.to(self.device), y.to(self.device)
                features = self.model.get_last_features(x, detach=False)
                loss = None
                for label in y.unique().tolist():
                    if mean_prototypes[label] is not None:
                        term = F.mse_loss(
                            features[y == label].mean(dim=0), mean_prototypes[label]
                        )
                        loss = term if loss is None else loss + term

                # BUGFIX: if no label in this batch has a prototype, the
                # original `loss = 0` stayed a plain int and `loss.backward()`
                # raised AttributeError; skip such batches instead.
                if loss is None:
                    continue

                alignment_optimizer.zero_grad()
                loss.backward()
                alignment_optimizer.step()

        self.prev_model.cpu()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from argparse import ArgumentParser, Namespace | ||
from typing import Any, Dict | ||
|
||
from omegaconf import DictConfig | ||
|
||
from src.client.fedas import FedASClient | ||
from src.server.fedavg import FedAvgServer | ||
|
||
|
||
class FedASServer(FedAvgServer):
    """Server of FedAS.

    Caches each client's personalized model state between rounds and sends it
    back to the client on its next participation; clients that have never
    trained receive ``None`` and fall back to the global model.
    """

    @staticmethod
    def get_hyperparams(args_list=None) -> Namespace:
        """Parse FedAS-specific hyperparameters.

        --alignment_lr: learning rate of the client-side alignment optimizer.
        --alignment_epoch: number of local alignment epochs.
        """
        parser = ArgumentParser()
        parser.add_argument("--alignment_lr", type=float, default=0.01)
        parser.add_argument("--alignment_epoch", type=int, default=1)
        return parser.parse_args(args_list)

    def __init__(
        self,
        args: DictConfig,
        algorithm_name: str = "FedAS",
        unique_model=False,
        use_fedavg_client_cls=False,
        return_diff=False,
    ):
        super().__init__(
            args, algorithm_name, unique_model, use_fedavg_client_cls, return_diff
        )
        # client_id -> latest personalized model state_dict uploaded by that
        # client; empty until a client finishes its first round.
        self.client_prev_model_states: Dict[int, Dict[str, Any]] = {}
        self.init_trainer(FedASClient)

    def train_one_round(self):
        """The function of indicating specific things FL method need to do (at
        server side) in each communication round."""
        client_packages = self.trainer.train()
        # Cache every participant's personalized state for its next round.
        for client_id, package in client_packages.items():
            self.client_prev_model_states[client_id] = package["prev_model_state"]
        self.aggregate(client_packages)

    def package(self, client_id: int):
        """Build the package sent to `client_id`, attaching its cached
        personalized model state (or None on first participation)."""
        server_package = super().package(client_id)
        # dict.get() returns None for clients without a cached state, which is
        # exactly the sentinel FedASClient expects.
        server_package["prev_model_state"] = self.client_prev_model_states.get(
            client_id
        )
        return server_package
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters