FIX: PiSSA now works with Conv1D layers (#2103) (#2104)
Transpose the weight matrix based on the fan_in_fan_out condition during PiSSA initialization.

Co-authored-by: Yang Su <suyang360@gmail.com>
suyang160 and Yang Su authored Oct 8, 2024
1 parent 9918977 commit a724834
Showing 2 changed files with 15 additions and 4 deletions.
src/peft/tuners/lora/layer.py (4 changes: 2 additions & 2 deletions)
@@ -221,7 +221,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
                 "Please initialize PiSSA under float32, float16, or bfloat16. "
                 "Subsequently, re-quantize the residual model to help minimize quantization errors."
             )
-        weight = weight.to(torch.float32)
+        weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
         if init_lora_weights == "pissa":
             # USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
             V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
@@ -245,7 +245,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
         self.lora_A[adapter_name].weight.data = lora_A
         self.lora_B[adapter_name].weight.data = lora_B
         weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
-        weight = weight.to(dtype)
+        weight = transpose(weight.to(dtype), self.fan_in_fan_out)
         self.get_base_layer().weight.data = weight

     def loftq_init(self, adapter_name):
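Why the transpose matters (a note beyond the commit message): transformers' Conv1D, which GPT-2 uses for its attention and MLP projections, stores its weight as (in_features, out_features), the transpose of nn.Linear's (out_features, in_features), and PEFT flags such layers with fan_in_fan_out=True. Below is a minimal sketch, with a locally defined transpose helper assumed to mirror PEFT's, showing that the flip brings both layouts to the same (out, in) orientation before the SVD:

import torch
from transformers.pytorch_utils import Conv1D

def transpose(weight, fan_in_fan_out):
    # Assumed to mirror PEFT's helper: flip the matrix only for layers that store W^T.
    return weight.T if fan_in_fan_out else weight

linear = torch.nn.Linear(in_features=16, out_features=8)  # weight: (8, 16) = (out, in)
conv1d = Conv1D(nf=8, nx=16)                              # weight: (16, 8) = (in, out)

w_lin = transpose(linear.weight.data.to(torch.float32), fan_in_fan_out=False)
w_conv = transpose(conv1d.weight.data.to(torch.float32), fan_in_fan_out=True)
# Both now have shape (out, in), so the SVD in pissa_init factorizes the same orientation.
assert w_lin.shape == w_conv.shape == (8, 16)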
tests/test_gpu_examples.py (15 changes: 13 additions & 2 deletions)
@@ -46,6 +46,7 @@
     WhisperProcessor,
     WhisperTokenizer,
 )
+from transformers.pytorch_utils import Conv1D

 from peft import (
     AdaLoraConfig,
@@ -1719,7 +1720,7 @@ def quantize_model(self, model, num_bits=4, device="cuda"):
         # Quantize the `weight.data` of the linear layer in the model to `num_bits` and store it with full precision.
         quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
         for name, module in model.named_modules():
-            if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
+            if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
                 quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device))
                 module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape)
         return model
@@ -1728,7 +1729,7 @@ def nuclear_norm(self, base_model, quantized_model):
         # Calculate the nuclear norm (sum of singular values) of the error matrices between the `quantized_model` and the `base_model`.
         error_list = []
         for name, module in base_model.named_modules():
-            if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
+            if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
                 quant_module = quantized_model.get_submodule(name)
                 error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum())
         return torch.Tensor(error_list).sum()
@@ -1822,6 +1823,16 @@ def test_t5_pissa_4bit(self, device, tmp_path):
     def test_t5_pissa_8bit(self, device, tmp_path):
         self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)

+    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    def test_gpt2_pissa_4bit(self, device, tmp_path):
+        # see 2104
+        self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)
+
+    @pytest.mark.parametrize("device", ["cuda", "cpu"])
+    def test_gpt2_pissa_8bit(self, device, tmp_path):
+        # see 2104
+        self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
+
     @require_bitsandbytes
     def test_lora_pissa_conversion_same_output_after_loading_with_quantization(self, tmp_path):
         # A copy of the test `test_lora_pissa_conversion_same_output_after_loading` in peft/tests/test_initialization.py,
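A hedged sketch of why the test helpers need the wider isinstance check: with a Linear-only match, GPT-2 contributes no quantizable modules (its projections are Conv1D and lm_head is excluded by name), so both quantize_model and nuclear_norm would silently do nothing. The module names printed below are illustrative of what the widened check picks up:

import torch
from transformers import AutoModelForCausalLM
from transformers.pytorch_utils import Conv1D

model = AutoModelForCausalLM.from_pretrained("gpt2")
# With a Linear-only check, this list would be empty for GPT-2 (lm_head is excluded);
# including Conv1D picks up the c_attn, c_proj, c_fc, ... projections in every block.
targets = [
    name
    for name, module in model.named_modules()
    if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name
]
print(targets[:4])  # e.g. transformer.h.0.attn.c_attn, transformer.h.0.attn.c_proj, ...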
