diff --git a/notebooks/Gradient_Checkpointing_Llama.ipynb b/notebooks/Gradient_Checkpointing_Llama.ipynb
index dee4d3abd3..b48390d846 100644
--- a/notebooks/Gradient_Checkpointing_Llama.ipynb
+++ b/notebooks/Gradient_Checkpointing_Llama.ipynb
@@ -185,9 +185,6 @@
     "# Activate gradient checkpointing\n",
     "model.gradient_checkpointing_enable()\n",
     "\n",
-    "# For gradient checkpointing with adapters, it is beneficial to set enable_input_require_grads.\n",
-    "model.enable_input_require_grads()\n",
-    "\n",
     "print(model.adapter_summary())"
    ]
   },
diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py
index 8e07559713..00714ad873 100644
--- a/src/adapters/model_mixin.py
+++ b/src/adapters/model_mixin.py
@@ -1674,12 +1674,10 @@ def gradient_checkpointing_function(function, *args, **kwargs):
                 "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
             )
 
-        if getattr(self, "_hf_peft_config_loaded", False):
-            # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
-            # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
-            # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
-            # the gradients to make sure the gradient flows.
-            self.enable_input_require_grads()
+        # >>> START AH Changes <<<
+        # For adapter training, we always require requires_grad=True for the input embeddings.
+        self.enable_input_require_grads()
+        # >>> END AH Changes <<<
 
 
 @inherit_doc
diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py
index 5667fa098d..578142ea11 100644
--- a/src/adapters/models/beit/adapter_model.py
+++ b/src/adapters/models/beit/adapter_model.py
@@ -36,6 +36,20 @@ def __init__(self, config):
         # Initialize weights and apply final processing
         self.post_init()
 
+    # Overwrites the function from: transformers.modeling_utils.PreTrainedModel
+    def enable_input_require_grads(self):
+        """
+        Enables the gradients for the input embeddings specifically for BEiT's tuple output format.
+        """
+
+        def make_inputs_require_grads(module, input, output):
+            # >>> START AH Changes <<<
+            # Handle BEiT's specific tuple output format. Hugging Face's implementation is buggy and doesn't work for BEiT.
+            output[0].requires_grad_(True)
+            # >>> END AH Changes <<<
+
+        self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)
+
     @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
     def forward(
         self,
diff --git a/src/adapters/models/t5/modeling_t5.py b/src/adapters/models/t5/modeling_t5.py
index 09b969bb1b..e4af5f04c4 100644
--- a/src/adapters/models/t5/modeling_t5.py
+++ b/src/adapters/models/t5/modeling_t5.py
@@ -419,8 +419,20 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
+                # >>> START AH Changes <<<
+                # Without this change, T5 training with gradient checkpointing will fail for reft.
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # Ensure all inputs are on the same device
+                        inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs)
+                        return module(*inputs)
+
+                    return custom_forward
+
+                # >>> END AH Changes <<<
+
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.forward,
+                    create_custom_forward(layer_module),
                     hidden_states,
                     causal_mask,
                     position_bias,
diff --git a/tests/methods/base.py b/tests/methods/base.py
index 46df3fdc2a..e0e3abce19 100644
--- a/tests/methods/base.py
+++ b/tests/methods/base.py
@@ -394,13 +394,17 @@ def _run_gradient_checkpointing_test_helper(self, adapter_setup_fn: Callable[[ad
         # Initialize model
         model = adapters.AutoAdapterModel.from_config(config)
+
+        # if model doesn't support gradient checkpointing, skip the test
+        if not model.supports_gradient_checkpointing:
+            self.skipTest("Model does not support gradient checkpointing")
+
         model.to(torch_device)
 
         adapter_setup_fn(model)
 
         # Enable gradient checkpointing
         if train_with_checkpointing:
             model.gradient_checkpointing_enable()
-            model.enable_input_require_grads()
 
         # Train & store state dict
         self.trainings_run(model, batch_size=1, gradient_accumulation_steps=2)
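For reference, below is a minimal usage sketch (not part of the diff) of adapter training with gradient checkpointing after this change: gradient_checkpointing_enable() now calls enable_input_require_grads() internally, so the manual call is no longer needed. The backbone ("bert-base-uncased"), adapter name, adapter config, and dummy inputs are illustrative assumptions, not taken from the PR.

# Minimal sketch of adapter training with gradient checkpointing after this change.
# Backbone, adapter name/config, and dummy inputs are illustrative.
import adapters
import torch
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bert-base-uncased")
model = adapters.AutoAdapterModel.from_config(config)
model.add_adapter("demo_adapter", config="seq_bn")
model.add_classification_head("demo_adapter", num_labels=2)
model.train_adapter("demo_adapter")
model.train()

# gradient_checkpointing_enable() now handles enable_input_require_grads() itself,
# so no extra call is required for adapter training.
model.gradient_checkpointing_enable()

# Dummy forward/backward pass to verify gradients flow through the checkpointed layers.
input_ids = torch.randint(0, 1000, (1, 16))
loss = model(input_ids=input_ids, labels=torch.tensor([0])).loss
loss.backward()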