diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py
index edb1273381..04294735d3 100644
--- a/src/peft/tuners/lora/model.py
+++ b/src/peft/tuners/lora/model.py
@@ -444,6 +444,19 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
         if self.training:
             raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
 
+        # Check that users only passed actually existing adapters.
+        # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
+        # to check that there is at least one layer with the given name, or else something like typos can easily slip.
+        expected_adapters = set()
+        for layer in self.modules():
+            if isinstance(layer, LoraLayer):
+                expected_adapters |= layer.lora_A.keys()
+                expected_adapters |= layer.lora_embedding_A.keys()
+        unique_adapters = {name for name in adapter_names if name != "__base__"}
+        unexpected_adapters = unique_adapters - expected_adapters
+        if unexpected_adapters:
+            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
+
         hook_handles = []
         for module in self.modules():
             if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
diff --git a/tests/test_custom_models.py b/tests/test_custom_models.py
index aa747ad245..611b07bf97 100644
--- a/tests/test_custom_models.py
+++ b/tests/test_custom_models.py
@@ -3536,13 +3536,40 @@ def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora):
         # When there are merged adapters, passing adapter names should raise an error
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
-            "adapter_names": ["default"] * 9,
+            "adapter_names": ["adapter0"] * 9,
         }
         mlp_lora.merge_adapter(["adapter0"])
         msg = r"Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
         with pytest.raises(ValueError, match=msg):
             mlp_lora.forward(**inputs)
 
+    def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self):
+        # Ensure that all of the adapter names that are being passed actually exist
+        torch.manual_seed(0)
+        x = torch.arange(90).view(-1, 10).to(self.torch_device)
+
+        base_model = MLP().to(self.torch_device).eval()
+        config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config).eval()
+        peft_model.add_adapter(adapter_name="other", peft_config=config)
+
+        # sanity check: this works
+        peft_model.forward(x, adapter_names=["default"] * 5 + ["other"] * 4)
+
+        # check one correct and one incorrect adapter
+        msg = re.escape("Trying to infer with non-existing adapter(s): does-not-exist")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["default"] * 5 + ["does-not-exist"] * 4)
+
+        # check two correct adapters and one incorrect adapter
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["default"] * 3 + ["does-not-exist"] * 4 + ["other"] * 2)
+
+        # check only incorrect adapters
+        msg = re.escape("Trying to infer with non-existing adapter(s): does-not-exist, other-does-not-exist")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["does-not-exist"] * 5 + ["other-does-not-exist"] * 4)
+
     def test_mixed_adapter_batches_lora_with_dora_raises(self):
         # When there are DoRA adapters, passing adapter names should raise an error
         torch.manual_seed(0)
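
For illustration, here is a minimal sketch (not part of the patch) of the behavior the new model-level check enforces during mixed-adapter batch inference. The `SmallMLP` module below is a hypothetical stand-in for the `MLP` fixture used in `tests/test_custom_models.py`; the PEFT calls (`LoraConfig`, `get_peft_model`, `add_adapter`, passing `adapter_names` to `forward`) mirror the new test, and `"defualt"` is a deliberate typo of the kind the diff's comment says could previously slip through.

import torch
import torch.nn as nn

from peft import LoraConfig, get_peft_model


class SmallMLP(nn.Module):
    # Hypothetical toy module; only `lin0` is targeted by LoRA below.
    def __init__(self):
        super().__init__()
        self.lin0 = nn.Linear(10, 20)
        self.relu = nn.ReLU()
        self.lin1 = nn.Linear(20, 2)

    def forward(self, X):
        return self.lin1(self.relu(self.lin0(X)))


config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
peft_model = get_peft_model(SmallMLP(), config).eval()
peft_model.add_adapter(adapter_name="other", peft_config=config)

x = torch.arange(90, dtype=torch.float32).view(-1, 10)

# Mixed-adapter batch with known adapter names (plus "__base__" rows) is accepted;
# "__base__" is excluded from the check, as in the patch.
peft_model.forward(x, adapter_names=["default"] * 3 + ["other"] * 3 + ["__base__"] * 3)

# A misspelled adapter name now fails fast with the new ValueError.
try:
    peft_model.forward(x, adapter_names=["default"] * 5 + ["defualt"] * 4)
except ValueError as exc:
    print(exc)  # Trying to infer with non-existing adapter(s): defualt

As the patch's comment explains, the check has to live at the model level: it collects the union of adapter names across all `LoraLayer` instances (including embedding adapters via `lora_embedding_A`), since any single layer may legitimately lack a given adapter.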