Fixes warning when initializing prompt encoder (#716)
Right now, when the user initializes a prompt encoder with MLP, they get a warning that a certain argument is ignored, and there is no possible value for that argument that would stop the warning. Usually, a warning signals that something is (probably) going wrong, but here everything is going as expected. Therefore, by default, I would not give this warning, thus avoiding user confusion.

However, I would still give the warning if the user explicitly sets `encoder_num_layers` to a non-default value. In that case, they expect the change to make a difference, but since the argument is ignored, their expectation is not met, which warrants a warning.
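
A minimal sketch of the resulting behavior from the user's side (the config kwargs mirror those in the test added below; the exact values are illustrative):

import warnings

from peft import PromptEncoder, PromptEncoderConfig

kwargs = {
    "num_virtual_tokens": 20,
    "num_transformer_submodules": 1,
    "token_dim": 768,
    "encoder_hidden_size": 768,
}

# MLP is the default reparameterization type; with the default
# encoder_num_layers, initialization is now silent.
PromptEncoder(PromptEncoderConfig(**kwargs))

# An explicit non-default value is still ignored for MLP (exactly 2
# layers are built), so initialization now emits a UserWarning.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PromptEncoder(PromptEncoderConfig(encoder_num_layers=5, **kwargs))
assert any("encoder_num_layers" in str(w.message) for w in caught)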
BenjaminBossan authored Jul 19, 2023
1 parent 1681ceb commit e06d94d
Showing 2 changed files with 32 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/peft/tuners/p_tuning.py
@@ -143,9 +143,12 @@ def __init__(self, config):
             )

         elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
-            warnings.warn(
-                f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
-            )
+            encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers
+            if config.encoder_num_layers != encoder_num_layers_default:
+                warnings.warn(
+                    f"for {self.encoder_type}, the argument `encoder_num_layers` is ignored. "
+                    f"Exactly {encoder_num_layers_default} MLP layers are used."
+                )
             layers = [
                 torch.nn.Linear(self.input_size, self.hidden_size),
                 torch.nn.ReLU(),
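The check above works because PromptEncoderConfig is a dataclass, so reading the field on the class itself (rather than on an instance) yields its default value. The same pattern in isolation, using a hypothetical ExampleConfig:

import warnings
from dataclasses import dataclass


@dataclass
class ExampleConfig:
    encoder_num_layers: int = 2  # same default as PromptEncoderConfig


def build(config: ExampleConfig) -> None:
    # Class-attribute access returns the field's default value.
    default = ExampleConfig.encoder_num_layers
    if config.encoder_num_layers != default:
        warnings.warn(
            f"the argument `encoder_num_layers` is ignored. "
            f"Exactly {default} MLP layers are used."
        )


build(ExampleConfig())                      # silent
build(ExampleConfig(encoder_num_layers=5))  # warns
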
26 changes: 26 additions & 0 deletions tests/test_config.py
@@ -17,13 +17,17 @@
 import pickle
 import tempfile
 import unittest
+import warnings

+import pytest
+
 from peft import (
     AdaptionPromptConfig,
     IA3Config,
     LoraConfig,
     PeftConfig,
     PrefixTuningConfig,
+    PromptEncoder,
     PromptEncoderConfig,
     PromptTuningConfig,
 )
@@ -156,3 +160,25 @@ def test_config_pickle_roundtrip(self):
             config = config_class()
             copied = pickle.loads(pickle.dumps(config))
             self.assertEqual(config.to_dict(), copied.to_dict())
+
+    def test_prompt_encoder_warning_num_layers(self):
+        # This test checks that if a prompt encoder config is created with an argument that is ignored, there should
+        # be a warning. However, there should be no warning if the default value is used.
+        kwargs = {
+            "num_virtual_tokens": 20,
+            "num_transformer_submodules": 1,
+            "token_dim": 768,
+            "encoder_hidden_size": 768,
+        }
+
+        # there should be no warning with just the default argument for encoder_num_layers
+        config = PromptEncoderConfig(**kwargs)
+        with warnings.catch_warnings():
+            PromptEncoder(config)
+
+        # when changing encoder_num_layers, there should be a warning for MLP since that value is not used
+        config = PromptEncoderConfig(encoder_num_layers=123, **kwargs)
+        with pytest.warns(UserWarning) as record:
+            PromptEncoder(config)
+        expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
+        assert str(record.list[0].message) == expected_msg
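
Note that warnings.catch_warnings() on its own only saves and restores the warning filter state; it does not make the test fail if a warning is emitted. A stricter variant of the first check (a sketch, not part of this commit) would escalate warnings to errors:

import warnings

from peft import PromptEncoder, PromptEncoderConfig

kwargs = {
    "num_virtual_tokens": 20,
    "num_transformer_submodules": 1,
    "token_dim": 768,
    "encoder_hidden_size": 768,
}

# Turn any warning into an exception so an unexpected warning fails the test.
with warnings.catch_warnings():
    warnings.simplefilter("error")
    PromptEncoder(PromptEncoderConfig(**kwargs))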
