Fix all remaining bugs (T5 and BeIT HF bug)
lenglaender committed Jan 14, 2025
1 parent 5999bb7 commit b89b941
Showing 5 changed files with 36 additions and 11 deletions.
3 changes: 0 additions & 3 deletions notebooks/Gradient_Checkpointing_Llama.ipynb
@@ -185,9 +185,6 @@
"# Activate gradient checkpointing\n",
"model.gradient_checkpointing_enable()\n",
"\n",
-"# For gradient checkpointing with adapters, it is beneficial to set enable_input_require_grads.\n",
-"model.enable_input_require_grads()\n",
-"\n",
"print(model.adapter_summary())"
]
},
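With this commit the notebook no longer calls `model.enable_input_require_grads()` by hand, because the `model_mixin.py` change below performs that call whenever gradient checkpointing is enabled. For context, a minimal sketch of the resulting setup; the checkpoint name and the LoRA adapter config are illustrative assumptions, not taken from the notebook:

```python
# Minimal sketch (not part of the commit): adapter training with gradient checkpointing
# after this change. Model name and adapter config are illustrative assumptions.
import adapters
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
adapters.init(model)                       # add adapter support to the plain HF model
model.add_adapter("my_adapter", config="lora")
model.train_adapter("my_adapter")          # freeze the base model, train only the adapter

# Activate gradient checkpointing; enable_input_require_grads() is now called internally
model.gradient_checkpointing_enable()

print(model.adapter_summary())
```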
10 changes: 4 additions & 6 deletions src/adapters/model_mixin.py
@@ -1674,12 +1674,10 @@ def gradient_checkpointing_function(function, *args, **kwargs):
"Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
)

-if getattr(self, "_hf_peft_config_loaded", False):
-# When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
-# we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
-# When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
-# the gradients to make sure the gradient flows.
-self.enable_input_require_grads()
+# >>> START AH Changes <<<
+# For adapter training, we always require requires_grad=True for the input embeddings.
+self.enable_input_require_grads()
+# >>> END AH Changes <<<


@inherit_doc
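The replaced block enabled input gradients only when a PEFT config was loaded; the adapters mixin now calls `enable_input_require_grads()` unconditionally, because adapter training always freezes the embedding weights while the checkpointed blocks still need a differentiable input to rebuild their graph in the backward pass. A standalone sketch of that failure mode with toy modules (not the library's code):

```python
# Standalone sketch (toy modules, not library code): with frozen embeddings, reentrant
# gradient checkpointing has no grad path unless the embedding output is forced to
# require grad, which is what enable_input_require_grads() arranges via a forward hook.
import torch
from torch.utils.checkpoint import checkpoint

embed = torch.nn.Embedding(100, 16)
embed.weight.requires_grad_(False)   # frozen base weights, as in adapter training
block = torch.nn.Linear(16, 16)      # stands in for a checkpointed transformer block

def make_inputs_require_grads(module, inputs, output):
    # Same idea as PreTrainedModel.enable_input_require_grads(): make the embedding
    # output differentiable so the checkpointed segment re-enters autograd.
    output.requires_grad_(True)

hook = embed.register_forward_hook(make_inputs_require_grads)

tokens = torch.randint(0, 100, (2, 8))
hidden = checkpoint(block, embed(tokens), use_reentrant=True)
hidden.sum().backward()              # without the hook, this backward would fail
assert block.weight.grad is not None
hook.remove()
```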
14 changes: 14 additions & 0 deletions src/adapters/models/beit/adapter_model.py
@@ -36,6 +36,20 @@ def __init__(self, config):
# Initialize weights and apply final processing
self.post_init()

+# Overwrites the function from: transformers.modeling_utils.PreTrainedModel
+def enable_input_require_grads(self):
+"""
+Enables the gradients for the input embeddings specifically for BEiT's tuple output format.
+"""
+
+def make_inputs_require_grads(module, input, output):
+# >>> START AH Changes <<<
+# Handle BEiT's specific tuple output format. Hugging Face's implementation is buggy and doesn't work for BEiT.
+output[0].requires_grad_(True)
+# >>> END AH Changes <<<
+
+self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
def forward(
self,
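Hugging Face's stock `enable_input_require_grads()` registers a forward hook that calls `requires_grad_(True)` on the hook's `output` directly, which fails when the embedding module returns a tuple, as BEiT's embeddings do; the override above therefore targets `output[0]`. An illustrative, generalized hook that tolerates both shapes (not the commit's code; the stand-in embedding module is an assumption):

```python
# Illustrative sketch (not the commit's code): a forward hook that tolerates both tensor
# and tuple outputs from an input-embedding module. The BEiT override above hard-codes
# the tuple case; this variant only demonstrates the underlying issue.
import torch

def make_inputs_require_grads(module, inputs, output):
    # Modules like BEiT's embeddings return (embeddings, extra); plain ones return a tensor.
    target = output[0] if isinstance(output, tuple) else output
    target.requires_grad_(True)

embeddings = torch.nn.Embedding(10, 4)   # stand-in for get_input_embeddings()
embeddings.weight.requires_grad_(False)  # frozen, so the hook is what makes grads flow
hook = embeddings.register_forward_hook(make_inputs_require_grads)

out = embeddings(torch.tensor([[1, 2, 3]]))
assert out.requires_grad                 # set by the hook, not by the frozen weights
hook.remove()
```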
14 changes: 13 additions & 1 deletion src/adapters/models/t5/modeling_t5.py
@@ -419,8 +419,20 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if self.gradient_checkpointing and self.training:
+# >>> START AH Changes <<<
+# Without this change, T5 training with gradient checkpointing will fail for reft.
+def create_custom_forward(module):
+def custom_forward(*inputs):
+# Ensure all inputs are on the same device
+inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs)
+return module(*inputs)
+
+return custom_forward
+
+# >>> END AH Changes <<<
+
layer_outputs = self._gradient_checkpointing_func(
-layer_module.forward,
+create_custom_forward(layer_module),
hidden_states,
causal_mask,
position_bias,
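`self._gradient_checkpointing_func` re-runs the wrapped callable during the backward pass, so the `create_custom_forward` wrapper above moves every tensor argument onto the device of the first input before the real layer forward executes. A self-contained sketch of the same composition with `torch.utils.checkpoint.checkpoint`; the toy layer and input shapes are assumptions for the example:

```python
# Illustrative sketch (not the commit's code): how a create_custom_forward-style wrapper
# composes with torch.utils.checkpoint. The device alignment mirrors the T5 change above;
# the toy layer and inputs are assumptions.
import torch
from torch.utils.checkpoint import checkpoint


class ToyLayer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 8)

    def forward(self, hidden_states, bias):
        return self.proj(hidden_states) + bias


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Move every tensor argument to the device of the first one before the real
        # forward runs, so recomputation in backward sees consistent device placement.
        inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs)
        return module(*inputs)

    return custom_forward


layer = ToyLayer()
hidden = torch.randn(2, 8, requires_grad=True)
bias = torch.zeros(8)

out = checkpoint(create_custom_forward(layer), hidden, bias, use_reentrant=False)
out.sum().backward()   # recomputes custom_forward during the backward pass
```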
6 changes: 5 additions & 1 deletion tests/methods/base.py
@@ -394,13 +394,17 @@ def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[ad

# Initialize model
model = adapters.AutoAdapterModel.from_config(config)

# if model doesn't support gradient checkpointing, skip the test
if not model.supports_gradient_checkpointing:
self.skipTest("Model does not support gradient checkpointing")

model.to(torch_device)
adapter_setup_fn(model)

# Enable gradient checkpointing
if train_with_checkpointing:
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

# Train & store state dict
self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2)
