FIX Raise mixed adapter infer with missing adapter (#2090)
PEFT allows mixed batch adapter inference, i.e. during inference, different samples within the same batch can use different adapters by passing the adapter_names argument. However, when users pass an adapter name that does not correspond to any of the existing adapters, the affected samples are currently silently ignored (i.e. just the base model output is used for them). This is unexpected and can easily lead to errors, e.g. when users mistype the name of an adapter.
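
For context, mixed adapter batch inference looks like this (a minimal sketch; the base model, adapter paths, and adapter names here are hypothetical):

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from peft import PeftModel

    base = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    model = PeftModel.from_pretrained(base, "path/to/adapter1", adapter_name="adapter1")
    model.load_adapter("path/to/adapter2", adapter_name="adapter2")

    inputs = tokenizer(["foo", "bar baz", "qux"], return_tensors="pt", padding=True)
    # sample 0 uses adapter1, sample 1 the plain base model, sample 2 adapter2
    outputs = model(**inputs, adapter_names=["adapter1", "__base__", "adapter2"])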

This PR fixes the issue by first collecting the names of all existing adapters and comparing them against the adapter_names passed by the user. If there are unexpected entries, a ValueError is raised.
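
With this check in place, a misspelled adapter name now fails fast instead of silently falling back to the base model output. Continuing the hypothetical example above:

    # Before this fix, the typo "adatper1" was silently treated like "__base__".
    # After this fix, the forward call raises immediately:
    outputs = model(**inputs, adapter_names=["adapter1", "adatper1", "adapter2"])
    # ValueError: Trying to infer with non-existing adapter(s): adatper1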

In the course of this fix, an error in the test
test_mixed_adapter_batches_lora_merged_raises was discovered and promptly
fixed: it passed "default" as the adapter name even though the fixture's
adapter is called "adapter0".
BenjaminBossan authored Oct 9, 2024
1 parent 85e3202 commit 8efa0cb
Showing 2 changed files with 41 additions and 1 deletion.
13 changes: 13 additions & 0 deletions src/peft/tuners/lora/model.py
@@ -444,6 +444,19 @@ def _enable_peft_forward_hooks(self, *args, **kwargs):
         if self.training:
             raise ValueError("Cannot pass `adapter_names` when the model is in training mode.")
 
+        # Check that users only passed actually existing adapters.
+        # Note: We cannot do this on the layer level, as each individual layer may not have each adapter. Still, we want
+        # to check that there is at least one layer with the given name, or else something like a typo can easily slip through.
+        expected_adapters = set()
+        for layer in self.modules():
+            if isinstance(layer, LoraLayer):
+                expected_adapters |= layer.lora_A.keys()
+                expected_adapters |= layer.lora_embedding_A.keys()
+        unique_adapters = {name for name in adapter_names if name != "__base__"}
+        unexpected_adapters = unique_adapters - expected_adapters
+        if unexpected_adapters:
+            raise ValueError(f"Trying to infer with non-existing adapter(s): {', '.join(sorted(unexpected_adapters))}")
+
         hook_handles = []
         for module in self.modules():
             if isinstance(module, LoraLayer) or isinstance(module, ModulesToSaveWrapper):
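
To illustrate the comment above about why the check happens at the model level rather than per layer: different layers can host different adapters, so any single layer's keys may miss a perfectly valid name. A minimal sketch, reusing the MLP toy model defined in the test file below:

    from peft import LoraConfig, get_peft_model

    model = get_peft_model(MLP(), LoraConfig(target_modules=["lin0"]))
    model.add_adapter(adapter_name="other", peft_config=LoraConfig(target_modules=["lin1"]))
    # lin0 only holds {"default"} and lin1 only holds {"other"}, so a per-layer
    # check on lin0 would wrongly reject "other". The union over all layers,
    # {"default", "other"}, is the correct set to validate against.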
29 changes: 28 additions & 1 deletion tests/test_custom_models.py
@@ -3536,13 +3536,40 @@ def test_mixed_adapter_batches_lora_merged_raises(self, mlp_lora):
         # When there are merged adapters, passing adapter names should raise an error
         inputs = {
             "X": torch.arange(90).view(-1, 10).to(self.torch_device),
-            "adapter_names": ["default"] * 9,
+            "adapter_names": ["adapter0"] * 9,
         }
         mlp_lora.merge_adapter(["adapter0"])
         msg = r"Cannot pass `adapter_names` when there are merged adapters, please call `unmerge_adapter` first."
         with pytest.raises(ValueError, match=msg):
             mlp_lora.forward(**inputs)
 
+    def test_mixed_adapter_batches_lora_wrong_adapter_name_raises(self):
+        # Ensure that all of the adapter names that are being passed actually exist
+        torch.manual_seed(0)
+        x = torch.arange(90).view(-1, 10).to(self.torch_device)
+
+        base_model = MLP().to(self.torch_device).eval()
+        config = LoraConfig(target_modules=["lin0"], init_lora_weights=False)
+        peft_model = get_peft_model(base_model, config).eval()
+        peft_model.add_adapter(adapter_name="other", peft_config=config)
+
+        # sanity check: this works
+        peft_model.forward(x, adapter_names=["default"] * 5 + ["other"] * 4)
+
+        # check one correct and one incorrect adapter
+        msg = re.escape("Trying to infer with non-existing adapter(s): does-not-exist")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["default"] * 5 + ["does-not-exist"] * 4)
+
+        # check two correct adapters and one incorrect adapter
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["default"] * 3 + ["does-not-exist"] * 4 + ["other"] * 2)
+
+        # check only incorrect adapters
+        msg = re.escape("Trying to infer with non-existing adapter(s): does-not-exist, other-does-not-exist")
+        with pytest.raises(ValueError, match=msg):
+            peft_model.forward(x, adapter_names=["does-not-exist"] * 5 + ["other-does-not-exist"] * 4)
+
     def test_mixed_adapter_batches_lora_with_dora_raises(self):
         # When there are DoRA adapters, passing adapter names should raise an error
         torch.manual_seed(0)
