working on the federated optimisation
ouaelesi committed Mar 12, 2024
1 parent b0a5bb3 commit c494829
Showing 12 changed files with 2,550 additions and 408 deletions.
Binary file modified Medfl/LearningManager/__pycache__/client.cpython-38.pyc
Binary file modified Medfl/LearningManager/__pycache__/flpipeline.cpython-38.pyc
Binary file modified Medfl/LearningManager/__pycache__/server.cpython-38.pyc
Binary file modified Medfl/LearningManager/__pycache__/strategy.cpython-38.pyc
6 changes: 5 additions & 1 deletion Medfl/LearningManager/client.py
@@ -95,9 +95,12 @@ def fit(self, parameters, config):
         Returns:
             Tuple: Parameters of the local model, number of training examples, and privacy information.
         """
+        print('\n -------------------------------- \n this is the config of the client')
         print(f"[Client {self.cid}] fit, config: {config}")
+        print(config['epochs'])
+        print('\n -------------------------------- \n ')
         self.local_model.set_parameters(parameters)
-        for _ in range(params["train_epochs"]):
+        for _ in range(config['epochs']):
             epsilon = self.local_model.train(
                 self.trainloader,
                 epoch=_,
@@ -131,4 +134,5 @@ def evaluate(self, parameters, config):
         )
         self.losses.append(loss)
         self.accuracies.append(accuracy)
+
         return float(loss), len(self.valloader), {"accuracy": float(accuracy)}
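
For context on the change above: in Flower, the dictionary returned by the strategy's on_fit_config_fn for a given round is delivered verbatim as the config argument of each client's fit, which is why the client now reads config['epochs'] instead of the static params["train_epochs"] value. A minimal, self-contained sketch of that pathway (the client class and the config values are illustrative, not MEDfl code):

import flwr as fl

def fit_config(server_round: int):
    # Whatever this returns becomes `config` in every client's fit() for that round.
    return {"epochs": 5, "learning_rate": 0.01}

class SketchClient(fl.client.NumPyClient):
    def fit(self, parameters, config):
        epochs = config["epochs"]  # same key the server-side fit_config set above
        # ... run `epochs` local training epochs, then return the updated weights ...
        return parameters, 1, {}

strategy = fl.server.strategy.FedAvg(on_fit_config_fn=fit_config)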
4 changes: 2 additions & 2 deletions Medfl/LearningManager/flpipeline.py
@@ -62,8 +62,8 @@ def validate(self) -> None:
         if not isinstance(self.description, str):
             raise TypeError("description argument must be a string")

-        if not isinstance(self.server, FlowerServer):
-            raise TypeError("server argument must be a FlowerServer")
+        # if not isinstance(self.server, FlowerServer):
+        #     raise TypeError("server argument must be a FlowerServer")

     def create(self, result: str) -> None:
         """
2 changes: 1 addition & 1 deletion Medfl/LearningManager/params.yaml
@@ -13,7 +13,7 @@
 num_rounds: 100
 min_evalclient: 2
 # path_to_master_csv : 'D:\ESI\3CS\PFE\last_year\Code\MEDfl\notebooks\sapsii_score_knnimputed_eicu.csv'
-path_to_master_csv : '/home/local/USHERBROOKE/saho6810/MEDfl/code/MEDfl/notebooks/data/masterDataSet/Mimic_train.csv'
+path_to_master_csv : '/home/local/USHERBROOKE/saho6810/MEDfl/code/MEDfl/notebooks/data/masterDataSet/Mimic_ouael.csv'

 path_to_test_csv : '/home/local/USHERBROOKE/saho6810/MEDfl/code/MEDfl/notebooks/data/masterDataSet/Mimic_train.csv'

8 changes: 6 additions & 2 deletions Medfl/LearningManager/server.py
@@ -86,8 +86,10 @@ def validate(self) -> None:
         if not isinstance(self.global_model, Model):
             raise TypeError("global_model argument must be a Model instance")

-        if not isinstance(self.strategy, Strategy):
-            raise TypeError("strategy argument must be a Strategy instance")
+        # if not isinstance(self.strategy, Strategy):
+        #     print(self.strategy)
+        #     print(isinstance(self.strategy, Strategy))
+        #     raise TypeError("strategy argument must be a Strategy instance")

         if not isinstance(self.num_clients, int):
             raise TypeError("num_clients argument must be an int")
@@ -145,6 +147,8 @@ def evaluate(
         loss, accuracy = self.global_model.evaluate(testloader, self.device)
         self.losses.append(loss)
         self.accuracies.append(accuracy)
+        if(server_round > 1 ):
+            self.strategy.study.tell(server_round-1 , accuracy)
         print(f"Server-side evaluation loss {loss} / accuracy {accuracy}")
         return loss, {"accuracy": accuracy}

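
The new study.tell(server_round - 1, accuracy) call is the reporting half of Optuna's ask-and-tell interface: study.ask() (issued in the strategy's fit_config) opens a trial, and tell() closes it by trial number, using the aggregated accuracy as the objective value. A standalone sketch of that pattern, with a placeholder objective in place of the federated evaluation:

import optuna

study = optuna.create_study(direction="maximize")

for round_ in range(1, 4):
    trial = study.ask()                              # trial.number is 0, 1, 2, ...
    lr = trial.suggest_float("learning_rate", 1e-4, 1e-1, log=True)
    accuracy = 1.0 - lr                              # placeholder for the server-side evaluation result
    study.tell(trial.number, accuracy)               # report by trial number, as server.py does

print(study.best_params)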
73 changes: 57 additions & 16 deletions Medfl/LearningManager/strategy.py
@@ -56,33 +56,74 @@ def __init__(
         self.min_evaluate_clients = min_evaluate_clients
         self.min_available_clients = min_available_clients
         self.initial_parameters = initial_parameters
-        self.evaluate_fn = None
+        # self.evaluate_fn = None
         self.name = name
-        self.strategy_object = eval(
-            f"fl.server.strategy.{self.name}(\
-            fraction_fit={self.fraction_fit},\
-            fraction_evaluate= {self.fraction_evaluate},\
-            min_fit_clients= {self.min_fit_clients},\
-            min_evaluate_clients= {self.min_evaluate_clients},\
-            min_available_clients={self.min_available_clients},\
-            on_fit_config_fn : {self.fit_config},\
-            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),\
-            evaluate_fn={self.evaluate_fn})"
-        )

+    def optuna_fed_optimization(self, direction:str , hpo_rate:int , params_config):
+        self.study = optuna.create_study(direction=direction)
+        self.hpo_rate = hpo_rate
+        self.params_config = params_config
+
+    def evaluate_fn(state, round_results):
+        print('\n ********************************* round results ')
+        print(round_results)
+        print('***********************')
+
+        # Return a dictionary containing the evaluation metrics
+        return {}
+
     def fit_config(self, server_round:int) -> Dict[str, fl.common.Scalar]:
         if(server_round - 1)%(self.hpo_rate) == 0:
-            trial = self.stydy.ask()
+            trial = self.study.ask()
             config = {
                 'learning_rate' :trial.suggest_float('learning_rate', **self.params_config['learning_rate']),
                 'optimiser' : trial.suggest_categorical('optimizer', self.params_config['optimizer']),
                 'fl_strategy' : trial.suggest_categorical('fl_strategy', self.params_config['fl_strategy']),
-                'reset': self.resetParams,
-                'epochs': 10
+                # 'reset': self.resetParams,
+                'epochs': trial.suggest_int('epochs', **self.params_config['epochs'])
             }
-            return config

+            print("\n ======================== \n the fit config is used \n")
+            print(server_round)
+            print(config)
+            print("\n ======================== \n ")
+            if(server_round == 10):
+                print("\n *-*-*-*-*-*-*-*---*-*-*--*-*-*-*-* \n ")
+                print(self.study.best_params)
+                optuna.visualization.plot_parallel_coordinate(self.study)
+                print("\n *--*-*-*-*--*-*-*---*-**-**--*--*-*-*-*-*-*-*-**---*-* \n ")
+            return config
+
+        return {}

+    def create_strategy(self):
+        self.strategy_object = self.get_strategy_by_name()(
+            fraction_fit=self.fraction_fit,
+            fraction_evaluate=self.fraction_evaluate,
+            min_fit_clients=self.min_fit_clients,
+            min_evaluate_clients=self.min_evaluate_clients,
+            min_available_clients=self.min_available_clients,
+            on_fit_config_fn=self.fit_config,
+            initial_parameters=fl.common.ndarrays_to_parameters(self.initial_parameters),
+            evaluate_fn=self.evaluate_fn
+        )
+    def get_strategy_by_name(self):
+        return eval(f"fl.server.strategy.{self.name}")
+
+    def plot_param_importances(self):
+        return optuna.visualization.plot_param_importances(self.study)
+
+    def plot_slice(self , params):
+        return optuna.visualization.plot_slice(self.study , params=params)
+
+    def plot_parallel_coordinate(self):
+        return optuna.visualization.plot_parallel_coordinate(self.study)
+
+    def plot_rank(self , params=None):
+        return optuna.visualization.plot_rank(self.study , params=params)
+
+    def plot_optimization_history(self):
+        return optuna.visualization.plot_optimization_history(self.study)

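Putting the new Strategy pieces together: optuna_fed_optimization creates the study and stores a search-space description, fit_config unpacks that description into suggest_float / suggest_categorical / suggest_int calls every hpo_rate rounds, and create_strategy builds the underlying Flower strategy with fit_config wired in as on_fit_config_fn. A hedged usage sketch; the concrete ranges, choices, and the Strategy(...) constructor arguments are assumptions for illustration, not documented MEDfl API:

from Medfl.LearningManager.strategy import Strategy  # assumed import path, matching this file's location

# Search space consumed by Strategy.fit_config (keys must match what it looks up).
params_config = {
    "learning_rate": {"low": 1e-4, "high": 1e-1, "log": True},  # **kwargs for trial.suggest_float
    "optimizer": ["Adam", "SGD"],                               # choices for trial.suggest_categorical
    "fl_strategy": ["FedAvg", "FedProx"],                       # choices for trial.suggest_categorical
    "epochs": {"low": 1, "high": 20},                           # **kwargs for trial.suggest_int
}

strategy = Strategy(name="FedAvg")                   # assumed: remaining constructor args left at their defaults
strategy.optuna_fed_optimization(direction="maximize",
                                 hpo_rate=1,         # ask Optuna for a new trial every round
                                 params_config=params_config)
strategy.create_strategy()                           # instantiates fl.server.strategy.FedAvg with fit_config attached

# After a run, the study can be inspected with the new plotting helpers, e.g.
# strategy.plot_optimization_history() or strategy.plot_param_importances().

As a design note, get_strategy_by_name could equally use getattr(fl.server.strategy, self.name) instead of eval, which performs the same lookup without evaluating an arbitrary string.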