FIX: A few issues with AdaLora, extending GPU tests #1146

Merged
5 changes: 3 additions & 2 deletions src/peft/tuners/adalora/bnb.py
@@ -70,7 +70,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             if requires_conversion:
                 output = output.to(expected_dtype)
             output = output * scaling / ranknum
-            result += output
+            # inplace operation on view is forbidden for MatMul8bitLtBackward, so avoid it
+            result = result + output
         return result
 
     def __repr__(self) -> str:
@@ -127,7 +128,7 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor:
             requires_conversion = not torch.is_autocast_enabled()
             if requires_conversion:
                 expected_dtype = result.dtype
-                compute_dtype = lora_A.weight.dtype
+                compute_dtype = lora_A.dtype
                 if x.dtype != compute_dtype:
                     x = x.to(compute_dtype)
 
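Note on the bnb.py changes: the in-place `result += output` is replaced with an out-of-place add because, as the new comment states, `MatMul8bitLtBackward` forbids in-place operations on the view it produces. The second hunk reads `lora_A.dtype` instead of `lora_A.weight.dtype`, since AdaLoRA stores its adapter matrices as plain `nn.Parameter`s in a `ParameterDict` rather than as `nn.Linear` submodules, so there is no `.weight` attribute to query. A minimal sketch of the general autograd restriction being avoided (illustrative only, not the bitsandbytes code path):

```python
import torch

x = torch.randn(4, 4, requires_grad=True)
view = x.view(16)          # a view of a tensor that participates in autograd
delta = torch.randn(16)

# An in-place update of the view is rejected by autograd, e.g.:
#   view += delta   ->  RuntimeError about an in-place operation on a view
# The out-of-place form allocates a new tensor and avoids the restriction:
view = view + delta
print(view.requires_grad)  # True; gradients still flow back to x
```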
2 changes: 1 addition & 1 deletion src/peft/tuners/adalora/model.py
@@ -236,7 +236,7 @@ def __getattr__(self, name: str):
     def forward(self, *args, **kwargs):
         outputs = self.model.forward(*args, **kwargs)
 
-        if getattr(outputs, "loss", None) is not None:
+        if (getattr(outputs, "loss", None) is not None) and isinstance(outputs.loss, torch.Tensor):
             # Calculate the orthogonal regularization
             orth_reg_weight = self.peft_config[self.trainable_adapter_name].orth_reg_weight
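Note on the model.py change: the orthogonal regularization term is now added only when `outputs.loss` is an actual `torch.Tensor`, so a `loss` field that is `None` or holds a non-tensor value no longer breaks `forward`. A small sketch of the guard pattern (the helper name and the `SimpleNamespace` outputs are invented for illustration):

```python
import types
import torch

def add_reg_if_tensor_loss(outputs, reg_term):
    # Only touch the loss when it is a real tensor.
    loss = getattr(outputs, "loss", None)
    if (loss is not None) and isinstance(loss, torch.Tensor):
        outputs.loss = loss + reg_term
    return outputs

with_loss = types.SimpleNamespace(loss=torch.tensor(1.0))
no_tensor = types.SimpleNamespace(loss="not-a-tensor")  # placeholder, non-tensor value

add_reg_if_tensor_loss(with_loss, torch.tensor(0.1))  # loss becomes 1.1
add_reg_if_tensor_loss(no_tensor, torch.tensor(0.1))  # left untouched
```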
84 changes: 84 additions & 0 deletions tests/test_gpu_examples.py
@@ -125,6 +125,14 @@ def tearDown(self):
         torch.cuda.empty_cache()
         gc.collect()
 
+    def _check_inference_finite(self, model, batch):
+        # try inference without Trainer class
+        training = model.training
+        model.eval()
+        output = model(**batch.to(model.device))
+        self.assertTrue(torch.isfinite(output.logits).all())
+        model.train(training)
+
     @pytest.mark.single_gpu_tests
     def test_causal_lm_training(self):
         r"""
@@ -335,6 +343,71 @@ def test_4bit_adalora_causalLM(self):
 
         data = load_dataset("ybelkada/english_quotes_copy")
         data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+        batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
+        self._check_inference_finite(model, batch)
+
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            trainer = Trainer(
+                model=model,
+                train_dataset=data["train"],
+                args=TrainingArguments(
+                    per_device_train_batch_size=4,
+                    gradient_accumulation_steps=4,
+                    warmup_steps=2,
+                    max_steps=3,
+                    learning_rate=2e-4,
+                    fp16=True,
+                    logging_steps=1,
+                    output_dir=tmp_dir,
+                ),
+                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
+            )
+            model.config.use_cache = False
+            trainer.train()
+
+            model.cpu().save_pretrained(tmp_dir)
+
+            self.assertTrue("adapter_config.json" in os.listdir(tmp_dir))
+            self.assertTrue(SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir))
+
+            # assert loss is not None
+            self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])
+
+    @pytest.mark.single_gpu_tests
+    @require_torch_gpu
+    def test_8bit_adalora_causalLM(self):
+        r"""
+        Tests the 8bit training with adalora
+        """
+        model_id = "facebook/opt-350m"
+
+        model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
+        tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+        model.gradient_checkpointing_enable()
+        model = prepare_model_for_kbit_training(model)
+
+        peft_config = AdaLoraConfig(
+            init_r=6,
+            target_r=4,
+            tinit=50,
+            tfinal=100,
+            deltaT=5,
+            beta1=0.3,
+            beta2=0.3,
+            orth_reg_weight=0.2,
+            lora_alpha=32,
+            lora_dropout=0.05,
+            bias="none",
+            task_type="CAUSAL_LM",
+        )
+
+        model = get_peft_model(model, peft_config)
+
+        data = load_dataset("ybelkada/english_quotes_copy")
+        data = data.map(lambda samples: tokenizer(samples["quote"]), batched=True)
+        batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
+        self._check_inference_finite(model, batch)
 
         with tempfile.TemporaryDirectory() as tmp_dir:
             trainer = Trainer(
@@ -671,6 +744,14 @@ def tearDown(self):
         gc.collect()
         torch.cuda.empty_cache()
 
+    def _check_inference_finite(self, model, batch):
+        # try inference without Trainer class
+        training = model.training
+        model.eval()
+        output = model(**batch.to(model.device))
+        self.assertTrue(torch.isfinite(output.logits).all())
+        model.train(training)
+
     @pytest.mark.single_gpu_tests
     def test_causal_lm_training(self):
         r"""
@@ -738,6 +819,7 @@ def test_adalora_causalLM(self):
             quantization_config=self.quantization_config,
         )
 
+        tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id)
         model = prepare_model_for_kbit_training(model)
 
         peft_config = AdaLoraConfig(
Expand All @@ -759,6 +841,8 @@ def test_adalora_causalLM(self):

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True)
batch = tokenizer(data["train"][:3]["quote"], return_tensors="pt", padding=True)
self._check_inference_finite(model, batch)

with tempfile.TemporaryDirectory() as tmp_dir:
trainer = Trainer(
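For completeness, the finiteness smoke test that the new `_check_inference_finite` helper performs can also be run outside the test classes. A rough standalone sketch under the same setup as the new 8bit test (the prompt text and the default `AdaLoraConfig` arguments are placeholders, not taken from the PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import AdaLoraConfig, get_peft_model, prepare_model_for_kbit_training

model_id = "facebook/opt-350m"  # base model used by the new 8bit test
model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True)
model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, AdaLoraConfig(task_type="CAUSAL_LM"))
tokenizer = AutoTokenizer.from_pretrained(model_id)

batch = tokenizer(["A quick smoke-test sentence."], return_tensors="pt", padding=True)

# Forward pass in eval mode; assert the logits contain no NaN/inf values,
# mirroring what _check_inference_finite does inside the test classes.
model.eval()
with torch.no_grad():
    output = model(**batch.to(model.device))
assert torch.isfinite(output.logits).all()
```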