From e06d94ddeb6c70913593740618df76908b918d66 Mon Sep 17 00:00:00 2001
From: Benjamin Bossan
Date: Wed, 19 Jul 2023 16:08:29 +0200
Subject: [PATCH] Fixes warning when initializing prompt encoder (#716)

Right now, when the user initializes a prompt encoder with MLP, they get
a warning that a certain argument is ignored, and there is no value for
that argument that would silence the warning. Usually, a warning signals
that something is (probably) going wrong, but here everything works as
expected. Therefore, by default, I would not raise this warning, thus
avoiding confusion for users.

However, I would still raise the warning if the user explicitly sets
encoder_num_layers to a non-default value. In that case, they expect the
change to make a difference, but since the argument is ignored, their
expectation is not met, which warrants a warning.
---
 src/peft/tuners/p_tuning.py |  9 ++++++---
 tests/test_config.py        | 27 +++++++++++++++++++++++++++
 2 files changed, 33 insertions(+), 3 deletions(-)

diff --git a/src/peft/tuners/p_tuning.py b/src/peft/tuners/p_tuning.py
index 4a272f357d..de52c2fae9 100644
--- a/src/peft/tuners/p_tuning.py
+++ b/src/peft/tuners/p_tuning.py
@@ -143,9 +143,12 @@ def __init__(self, config):
             )
 
         elif self.encoder_type == PromptEncoderReparameterizationType.MLP:
-            warnings.warn(
-                f"for {self.encoder_type}, the `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
-            )
+            encoder_num_layers_default = PromptEncoderConfig.encoder_num_layers
+            if config.encoder_num_layers != encoder_num_layers_default:
+                warnings.warn(
+                    f"for {self.encoder_type}, the argument `encoder_num_layers` is ignored. "
+                    f"Exactly {encoder_num_layers_default} MLP layers are used."
+                )
             layers = [
                 torch.nn.Linear(self.input_size, self.hidden_size),
                 torch.nn.ReLU(),
diff --git a/tests/test_config.py b/tests/test_config.py
index 28a61771c1..ca0ef0e3bd 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -17,6 +17,9 @@
 import pickle
 import tempfile
 import unittest
+import warnings
+
+import pytest
 
 from peft import (
     AdaptionPromptConfig,
@@ -24,6 +27,7 @@
     LoraConfig,
     PeftConfig,
     PrefixTuningConfig,
+    PromptEncoder,
     PromptEncoderConfig,
     PromptTuningConfig,
 )
@@ -156,3 +160,26 @@ def test_config_pickle_roundtrip(self):
         config = config_class()
         copied = pickle.loads(pickle.dumps(config))
         self.assertEqual(config.to_dict(), copied.to_dict())
+
+    def test_prompt_encoder_warning_num_layers(self):
+        # This test checks that if a prompt encoder config is created with an argument that is ignored, there should
+        # be a warning. However, there should be no warning if the default value is used.
+        kwargs = {
+            "num_virtual_tokens": 20,
+            "num_transformer_submodules": 1,
+            "token_dim": 768,
+            "encoder_hidden_size": 768,
+        }
+
+        # there should be no warning with just the default argument for encoder_num_layers
+        config = PromptEncoderConfig(**kwargs)
+        with warnings.catch_warnings():
+            warnings.simplefilter("error")  # turn any warning into an error so that it fails the test
+            PromptEncoder(config)
+
+        # when changing encoder_num_layers, there should be a warning for MLP since that value is not used
+        config = PromptEncoderConfig(encoder_num_layers=123, **kwargs)
+        with pytest.warns(UserWarning) as record:
+            PromptEncoder(config)
+        expected_msg = "for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used."
+        assert str(record.list[0].message) == expected_msg
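
To illustrate the behavior this patch introduces, here is a minimal sketch (not part of the patch itself) of how the warning now surfaces from a user's perspective. It assumes a `peft` installation that includes this change; the config values are copied from the test above and are illustrative, not required:

```python
import warnings

from peft import PromptEncoder, PromptEncoderConfig

# Illustrative shapes, matching the values used in the test above. With MLP
# (the default encoder_reparameterization_type), exactly 2 MLP layers are
# built regardless of `encoder_num_layers`.
common = {
    "num_virtual_tokens": 20,
    "num_transformer_submodules": 1,
    "token_dim": 768,
    "encoder_hidden_size": 768,
}

# Default encoder_num_layers: after this patch, no warning is raised.
PromptEncoder(PromptEncoderConfig(**common))

# Explicit non-default value: the argument is still ignored, so a
# UserWarning alerts the user that their setting has no effect.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    PromptEncoder(PromptEncoderConfig(encoder_num_layers=5, **common))

print(caught[0].message if caught else "no warning")
# -> for MLP, the argument `encoder_num_layers` is ignored. Exactly 2 MLP layers are used.
```

Gating the warning on a comparison against the dataclass default keeps the common path silent while still flagging a user-supplied setting that cannot take effect.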