Skip to content

Commit

Permalink
🐛Fix Adversary Initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
lishenghui committed May 13, 2024
1 parent a78c089 commit 2dc39ad
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 95 deletions.
2 changes: 0 additions & 2 deletions blades/adversaries/alie_adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@

class ALIEAdversary(Adversary):
def on_trainer_init(self, trainer: Trainer):
# super().__init__(clients, global_config)
# trainer.config.num_clients = len(trainer.clients)
self.num_clients = trainer.config.num_clients
num_byzantine = len(self.clients)

Expand Down
8 changes: 3 additions & 5 deletions blades/adversaries/minmax_adversary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import random
from typing import Dict

import torch

Expand All @@ -10,10 +9,10 @@


class MinMaxAdversary(Adversary):
def __init__(self, clients, global_config: Dict = None):
super().__init__(clients, global_config)
def __init__(self, threshold=1.0):
super().__init__()

self.threshold = 3.0
self.threshold = threshold
self.threshold_diff = 1e-4
self.num_byzantine = None
self.negative_indices = None
Expand All @@ -27,7 +26,6 @@ def on_local_round_end(self, algorithm: Algorithm):
self.num_byzantine += 1

updates = self._attack_by_binary_search(algorithm)
# updates = self._attack_median_and_trimmedmean(algorithm)
self.num_byzantine = 0
for result in algorithm.local_results:
client = algorithm.client_manager.get_client_by_id(result[CLIENT_ID])
Expand Down
5 changes: 2 additions & 3 deletions blades/adversaries/signflip_adversary.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from fedlib.trainers import Trainer as Algorithm
from fedlib.trainers import Trainer
from fedlib.clients import ClientCallback
from .adversary import Adversary


class SignFlipAdversary(Adversary):
def on_algorithm_start(self, algorithm: Algorithm):
def on_trainer_init(self, trainer: Trainer):
class SignFlipCallback(ClientCallback):
def on_backward_end(self, task):
model = task.model
# breakpoint()
for _, para in model.named_parameters():
para.grad.data = -para.grad.data

Expand Down
8 changes: 0 additions & 8 deletions blades/adversaries/signguard_adversary.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Dict

import torch

from fedlib.trainers import Trainer as Algorithm
Expand All @@ -8,25 +6,19 @@


def find_orthogonal_unit_vector(v):
# 随机生成一个与v不平行的向量
random_vector = torch.randn_like(v)

# 使用Gram-Schmidt正交化方法
orthogonal_vector = (
random_vector - torch.dot(random_vector, v) / torch.dot(v, v) * v
)

# 标准化向量
orthogonal_vector[-200] = 20000
orthogonal_vector[-100] = 20000
orthogonal_unit_vector = orthogonal_vector / torch.norm(orthogonal_vector)
return orthogonal_unit_vector


class SignGuardAdversary(Adversary):
def __init__(self, clients, global_config: Dict = None):
super().__init__(clients, global_config)

def on_local_round_end(self, algorithm: Algorithm):
updates = self._attack_sign_guard(algorithm)
for result in algorithm.local_results:
Expand Down
4 changes: 2 additions & 2 deletions blades/aggregators/aggregators.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def obj_func(median, inputs, weights):
# return (torch.sum(norms * weights) / torch.sum(weights)).item()

return np.average(
[torch.norm(p - median).item() for p in inputs],
[torch.norm(p - median, p=1).item() for p in inputs],
weights=weights.cpu(),
)

Expand All @@ -107,7 +107,7 @@ def obj_func(median, inputs, weights):
# Weiszfeld iterations
for _ in range(maxiter):
prev_obj_value = objective_value
denom = torch.stack([torch.norm(p - median) for p in inputs])
denom = torch.stack([torch.norm(p - median, p=1) for p in inputs])
new_weights = weights / torch.clamp(denom, min=eps)
median = weighted_average(inputs, new_weights)

Expand Down
38 changes: 0 additions & 38 deletions blades/algorithms/fedavg/fedavg.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,44 +112,6 @@ def get_default_config(cls) -> FedavgTrainerConfig:
def setup(self, config: FedavgTrainerConfig):
super().setup(config)

# self.adversary = self.config.get_adversary_config().build(
# self.client_manager.clients[: self.config.num_malicious_clients]
# )
# self.adversary.on_algorithm_start(self)

# def training_step(self):
# self.worker_group.sync_weights(self.server.get_global_model().state_dict())

# def local_training(worker, client):
# dataset = worker.dataset.get_client_dataset(client.client_id)
# result = client.train_one_round(dataset)
# return result

# clients = self.client_manager.trainable_clients
# self.local_results = self.worker_group.foreach_execution(
# local_training, clients
# )

# self.adversary.on_local_round_end(self)
# updates = [result.pop(CLIENT_UPDATE, None) for result in self.local_results]

# losses = []
# for result in self.local_results:
# client = self.client_manager.get_client_by_id(result[CLIENT_ID])
# if not client.is_malicious:
# loss = result.pop("avg_loss")
# losses.append(loss)

# self._counters[NUM_GLOBAL_STEPS] += 1
# global_vars = {
# "timestep": self._counters[NUM_GLOBAL_STEPS],
# }
# results = {"train_loss": np.mean(losses)}
# server_return = self.server.step(updates, global_vars)
# results.update(server_return)

# return results

def evaluate(self):
self.worker_group.sync_weights(self.server.get_global_model().state_dict())

Expand Down
37 changes: 18 additions & 19 deletions blades/tuned_examples/fedavg_fashion_mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ fedavg_blades:
num_clients: 60
train_batch_size: 64

partitioner_config:
type: IIDPartitioner
num_clients: 60

evaluation_interval: 50

num_remote_workers: 0
Expand All @@ -21,11 +25,6 @@ fedavg_blades:
num_cpus_for_driver: 2
num_gpus_for_driver: 0.3

# num_remote_workers: 14
# num_gpus_per_worker: 0.25
# num_cpus_per_worker: 2
# num_cpus_for_driver: 2
# num_gpus_for_driver: 0.5

# num_clients: 2
# global_model:
Expand All @@ -47,12 +46,12 @@ fedavg_blades:
aggregator:
grid_search: [
# type: Mean,
type: Clippedclustering,
type: Median,
# type: Clippedclustering,
# type: Median,
type: GeoMed,
type: DnC,
type: Trimmedmean,
type: Signguard,
# type: DnC,
# type: Trimmedmean,
# type: Signguard,
]

optimizer:
Expand All @@ -67,18 +66,18 @@ fedavg_blades:

num_malicious_clients:
# grid_search: [0]
grid_search: [3, 6, 9, 12, 15, 18]
grid_search: [9, 12, 15]
adversary_config:
# type: blades.adversaries.LabelFlipAdversary
# type: blades.adversaries.SignFlipAdversary
# type: blades.adversaries.AdaptiveAdversary
grid_search:
- type: blades.adversaries.ALIEAdversary
- type: blades.adversaries.LabelFlipAdversary
- type: blades.adversaries.NoiseAdversary
- type: blades.adversaries.SignFlipAdversary
- type: blades.adversaries.IPMAdversary
scale: 0.1
- type: blades.adversaries.IPMAdversary
scale: 100
# - type: blades.adversaries.ALIEAdversary
# - type: blades.adversaries.LabelFlipAdversary
# - type: blades.adversaries.NoiseAdversary
# - type: blades.adversaries.SignFlipAdversary
# - type: blades.adversaries.IPMAdversary
# scale: 0.1
# - type: blades.adversaries.IPMAdversary
# scale: 100
- type: blades.adversaries.MinMaxAdversary
2 changes: 1 addition & 1 deletion blades/tuned_examples/fedavg_toy_mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ fedavg_blades:
train_batch_size: 64
partitioner_config:
type: DirichletPartitioner
# alpha: 100.
alpha: 100.
num_clients: 4

evaluation_interval: 50
Expand Down
9 changes: 5 additions & 4 deletions blades/tuned_examples/fedsgd_cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ fedavg_blades:

partitioner_config:
type: DirichletPartitioner
# alpha: 100.
alpha: 100.
num_clients: 60

evaluation_interval: 50
Expand All @@ -41,9 +41,10 @@ fedavg_blades:
server_config:
aggregator:
grid_search: [
type: Mean,
# type: Mean,
# type: Clippedclustering,
# type: Median,
type: blades.aggregators.GeoMed,
# type: GeoMed,
# type: DnC,
# type: Trimmedmean,
Expand All @@ -68,10 +69,10 @@ fedavg_blades:
grid_search:
# - type: blades.adversaries.ALIEAdversary
# - type: blades.adversaries.LabelFlipAdversary
# - type: blades.adversaries.NoiseAdversary
- type: blades.adversaries.NoiseAdversary
- type: blades.adversaries.SignFlipAdversary
# - type: blades.adversaries.IPMAdversary
# scale: 0.1
# - type: blades.adversaries.IPMAdversary
# scale: 100
# - type: blades.adversaries.MinMaxAdversary
- type: blades.adversaries.MinMaxAdversary
23 changes: 13 additions & 10 deletions blades/tuned_examples/fedsgd_cnn_fashion_mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ fedavg_blades:
type: FashionMNIST
num_clients: 60
train_batch_size: 64
partitioner_config:
type: IIDPartitioner
num_clients: 60

evaluation_interval: 50

Expand All @@ -36,12 +39,12 @@ fedavg_blades:
aggregator:
grid_search: [
# type: Mean,
type: Clippedclustering,
type: Median,
# type: Clippedclustering,
# type: Median,
type: GeoMed,
type: DnC,
type: Trimmedmean,
type: Signguard,
# type: DnC,
# type: Trimmedmean,
# type: Signguard,
# type: Multikrum,
# type: Centeredclipping
]
Expand All @@ -56,16 +59,16 @@ fedavg_blades:
num_malicious_clients:
# grid_search: [0,9]
# grid_search: [3, 6]
grid_search: [3, 6, 9, 12, 15, 18]
grid_search: [12, 15, 18]

adversary_config:
grid_search:
- type: blades.adversaries.ALIEAdversary
- type: blades.adversaries.LabelFlipAdversary
- type: blades.adversaries.NoiseAdversary
- type: blades.adversaries.SignFlipAdversary
- type: blades.adversaries.IPMAdversary
scale: 0.1
- type: blades.adversaries.IPMAdversary
scale: 100
# - type: blades.adversaries.IPMAdversary
# scale: 0.1
# - type: blades.adversaries.IPMAdversary
# scale: 100
- type: blades.adversaries.MinMaxAdversary
6 changes: 3 additions & 3 deletions blades/tuned_examples/local20_resnet_cifar10.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ fedavg_blades:
train_batch_size: 64

partitioner_config:
type: DirichletPartitioner
alpha:
grid_search: [0.1, 1.0]
type: IIDPartitioner
# alpha:
# grid_search: [0.1, 1.0]
num_clients: 20

evaluation_interval: 50
Expand Down
Binary file modified docs/source/images/client_pipeline.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/dirichlet_partition.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/source/images/shard_partition.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 2dc39ad

Please sign in to comment.