Federated Learning (FL) has emerged as a promising approach for real-world medical research, enabling the analysis of distributed patient data while preserving privacy. However, a key challenge in FL is covariate shift: data distributions vary significantly across hospitals, so model performance degrades when models are deployed in new environments.
To address this, we introduce Federatively Weighted (FedWeight), a novel framework that re-weights patient data from source hospitals using density estimation techniques. This approach aligns the trained model with the target hospital’s data distribution, thus mitigating covariate shift.
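As a rough sketch of the idea (not the repository's implementation), each source patient's re-weight can be derived from the target and source hospitals' estimated densities. The function name `importance_weights` and the tempering factor `lam` (cf. the `reweight_lambda` hyper-parameter described below) are illustrative assumptions:

```python
import numpy as np

def importance_weights(log_p_target, log_p_source, lam=1.0):
    """Per-patient re-weights from log-densities.

    lam tempers the density ratio: lam=0 gives uniform weights
    (plain FedAvg), lam=1 the full ratio p_target(x) / p_source(x).
    """
    log_ratio = log_p_target - log_p_source
    w = np.exp(lam * log_ratio)
    return w / w.mean()  # normalize to mean 1

# Toy log-densities for four source patients
log_pt = np.array([-1.0, -2.0, -0.5, -3.0])
log_ps = np.array([-1.5, -1.0, -1.0, -2.0])
print(importance_weights(log_pt, log_ps))
```

Patients that look more typical of the target hospital (larger log-ratio) receive larger weights, steering the source hospital's local training toward the target distribution.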
To extend our method to unsupervised learning, we integrate FedWeight into a federated embedded topic model (ETM), resulting in FedWeight-ETM, which enhances topic modelling performance across diverse hospital datasets.
FedWeight training process. a. The target and source hospitals independently train their density estimators. b. The target hospital shares its density estimator with all source hospitals. c. Each source hospital computes a patient-specific re-weight, which is used to train its local model. d. The target hospital aggregates the parameters from the source hospitals using the FL algorithm, yielding a model that generalizes better to the target hospital's data. Symmetric FedWeight training process: e. Hospitals A and B independently train their density estimators. f. Hospitals A and B share their density estimators with each other. g. Hospitals A and B each regard the other as the target, estimating the re-weights used to train their local models. h. The local models from Hospitals A and B are aggregated using the FL algorithm.
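The aggregation in step d can be sketched as a sample-size-weighted parameter average, FedAvg-style. The function and variable names here are illustrative assumptions, not the repository's API:

```python
import numpy as np

def aggregate(params_per_hospital, effective_sizes):
    """Weighted average of local model parameters.

    Each hospital's contribution is proportional to its effective
    (re-weighted) sample count, as in FedAvg.
    """
    total = sum(effective_sizes)
    return sum(p * (n / total)
               for p, n in zip(params_per_hospital, effective_sizes))

# Two hospitals' local parameters and effective sample counts
theta_a = np.array([1.0, 2.0])
theta_b = np.array([3.0, 4.0])
print(aggregate([theta_a, theta_b], [100, 300]))  # → [2.5 3.5]
```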
FedAvg-ETM and FedWeight-ETM training process. a. FedAvg-ETM. In each local hospital
- Python 3.8+
- Required libraries (see `requirements.txt`)
- Clone the repository from GitHub:

```bash
git clone https://github.com/li-lab-mcgill/FedWeight.git
```

- Create a Python virtual environment:

```bash
python -m venv venv
source venv/bin/activate
```

- Install the required libraries:

```bash
pip install -r requirements.txt
```
- Create a `data` directory under the repository root.
- Inside the `data` directory, create `eicu` and `pretrained` subdirectories.
- Download the eICU Collaborative Research Database (eICU) and put it inside the `data/eicu` folder.
- Download the BioWordVec model from https://github.com/ncbi-nlp/BioSentVec and put it inside the `data/pretrained` folder.
- Create an `output` directory under the repository root.
- Run the `drug_harmonization.py` script. It may take a couple of hours.
- Find the harmonized dataset `eicu_harmonized.csv` inside the `output` folder.
- Run similar steps for the Medical Information Mart for Intensive Care (MIMIC-III) dataset: https://mimic.mit.edu/.
- Combine the harmonized eICU and MIMIC-III datasets and put the result inside the `data` folder.
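The final combination step might look like the following pandas sketch. The column names, `source` label, and output filename are toy assumptions, not the real harmonized schema:

```python
import pandas as pd

# Toy stand-ins for the two harmonized tables; in practice these would be
# read from the harmonization outputs, e.g. pd.read_csv("output/eicu_harmonized.csv")
eicu = pd.DataFrame({"patient_id": [1, 2], "drug": ["aspirin", "heparin"]})
mimic = pd.DataFrame({"patient_id": [3], "drug": ["insulin"]})

# Tag each row with its dataset of origin before concatenating
eicu["source"] = "eicu"
mimic["source"] = "mimic"

combined = pd.concat([eicu, mimic], ignore_index=True)
print(len(combined))  # → 3
# combined.to_csv("data/combined_harmonized.csv", index=False)
```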
This step runs FedWeight experiments to predict clinical outcomes (mortality, ventilator use, sepsis, and ICU length of stay).
```bash
python3 main.py \
    experiment=${experiment} \
    experiment.task=${task} \
    experiment.reweight_lambda=${reweight_lambda} \
    experiment.algorithm=${algorithm} \
    experiment.target_hospital_id=${target_hospital_id} \
    experiment.density_estimator=${estimator}
```
Parameters:

- `experiment`: The experiment name
  - `eicu`: Run FL model on eICU
  - `mimic`: Run FL model on MIMIC
  - `eicu_central`: Run centralized model on eICU
  - `mimic_central`: Run centralized model on MIMIC
- `task`: The task name
  - `death`: Mortality prediction
  - `ventilator`: Ventilator use prediction
  - `sepsis`: Sepsis diagnosis prediction
  - `length`: ICU length of stay prediction
- `reweight_lambda`: Hyper-parameter that controls the degree of re-weighting
- `algorithm`: The FL algorithm
  - `weighted`: FedWeight
  - `unweighted`: FedAvg
- `target_hospital_id`: The target hospital ID (`167`, `420`, `199`, `458`, or `252`)
- `density_estimator`: The density estimator
  - `made`: MADE
  - `vae`: VAE
  - `vqvae`: VQ-VAE
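To sweep several configurations, the command above can be driven from a small Python script. The grid values mirror the parameter list; `subprocess.run` is left commented out so the sketch only prints the commands it would launch:

```python
import itertools
import subprocess

# Example grid: all target hospitals crossed with all density estimators,
# for one fixed task and re-weighting strength.
hospitals = ["167", "420", "199", "458", "252"]
estimators = ["made", "vae", "vqvae"]

for hid, est in itertools.product(hospitals, estimators):
    cmd = [
        "python3", "main.py",
        "experiment=eicu",
        "experiment.task=death",
        "experiment.reweight_lambda=1.0",
        "experiment.algorithm=weighted",
        f"experiment.target_hospital_id={hid}",
        f"experiment.density_estimator={est}",
    ]
    print(" ".join(cmd))
    # subprocess.run(cmd, check=True)  # uncomment to actually run each job
```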
After running the experiments, you can find the trained models saved in the `output` folder.
With a trained FedWeight model, we can identify drugs and lab tests associated with clinical outcomes. Under the `notebook` directory, run the `analyze_causal.ipynb` notebook.
This step runs experiments to identify latent topics in federated learning settings.
Inside the `notebook/etm` folder, run the following commands:
```bash
# Non-federated ETM
python3 etm.py --num_topics ${num_topics}

# FedAvg-ETM
python3 fedavg_etm.py --num_topics ${num_topics}

# FedWeight-ETM
python3 fedweight_etm.py --num_topics ${num_topics}
```
Parameters:

- `num_topics`: Number of topics (e.g. 64)
Using FedWeight-ETM, we can identify mortality-associated topics and ICD codes. Under the `notebook/etm` directory, run the `etm_analyse_icd_result.ipynb` notebook.