diff --git a/notebooks/QLoRA_Llama_Finetuning.ipynb b/notebooks/QLoRA_Llama_Finetuning.ipynb index af2f363f29..f2f5044291 100644 --- a/notebooks/QLoRA_Llama_Finetuning.ipynb +++ b/notebooks/QLoRA_Llama_Finetuning.ipynb @@ -261,9 +261,7 @@ "metadata": {}, "outputs": [], "source": [ - "# for _, v in model.get_adapter(\"assistant_adapter\").items():\n", - "# for _, module in v.items():\n", - "# module.to(\"cuda\")" + "# model.adapter_to(\"assistant_adapter\", device=\"cuda\")" ] }, { diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index 6681ed4bc5..7ebae5221a 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -176,6 +176,13 @@ def get_adapter(self, adapter_name: str): else: return None + def get_adapter_fusion(self, adapter_names: Union[List, str]): + adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) + if adapter_names in self.adapter_fusion_layer: + return self.adapter_fusion_layer[adapter_names] + else: + return None + def pre_block(self, adapter_setup: Union[AdapterCompositionBlock, str], state: BottleneckState) -> BottleneckState: if isinstance(adapter_setup, AdapterCompositionBlock): adapter_name = adapter_setup.first() diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 56d643b06c..f172230c28 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1035,6 +1035,42 @@ def get_adapter(self, name) -> dict: return dict(destination) + def adapter_to( + self, name: str, device: Optional[Union[torch.device, str]] = None, dtype: Optional[torch.dtype] = None + ): + """ + Moves the adapter with the given name to the specified device and data type. + + Args: + name (str): The name of the adapter to be moved. + device (torch.device or str, optional): The device on which the adapter should be moved. + dtype (torch.dtype, optional): The data type to which the adapter should be cast. + """ + for _, v in self.get_adapter(name).items(): + for _, module in v.items(): + module.to(device=device, dtype=dtype) + + def adapter_fusion_to( + self, + adapter_names: Union[Fuse, list, str], + device: Optional[Union[torch.device, str]] = None, + dtype: Optional[torch.dtype] = None, + ): + """ + Moves the adapter fusion layer with the given name to the specified device and data type. + + Args: + adapter_names (Union[Fuse, list, str]): The name of the adapter fusion layer to be moved. + device (torch.device or str, optional): The device on which the adapter fusion layer should be moved. + dtype (torch.dtype, optional): The data type to which the adapter fusion layer should be cast. + """ + for _, layer in self.iter_layers(): + for module in layer.modules(): + if isinstance(module, BottleneckLayer): + fusion = module.get_adapter_fusion(adapter_names) + if fusion is not None: + fusion.to(device=device, dtype=dtype) + def adapter_summary(self, as_dict=False) -> Union[str, dict]: """ Returns a string summary of all adapters currently added to the model. Each entry in the summary table has the diff --git a/src/adapters/trainer.py b/src/adapters/trainer.py index ff915afceb..6be5b3ee7c 100644 --- a/src/adapters/trainer.py +++ b/src/adapters/trainer.py @@ -212,6 +212,7 @@ def _load_best_model(self): adapter_dir = os.path.join(self.state.best_model_checkpoint, adapter) if os.path.exists(adapter_dir): model.load_adapter(adapter_dir) + model.adapter_to(adapter, device=self.args.device) if self.train_adapter_fusion: logger.info( f"Loading best adapter fusion(s) from {self.state.best_model_checkpoint} (score:" @@ -222,7 +223,7 @@ def _load_best_model(self): fusion_dir = os.path.join(self.state.best_model_checkpoint, fusion) if os.path.exists(fusion_dir): model.load_adapter_fusion(fusion_dir) - model.to(self.args.device) + model.adapter_fusion_to(fusion, device=self.args.device) class AdapterTrainerCallback(TrainerCallback):