fix handling of CLIP and fix MT5 like we did for T5
lenglaender committed Jan 14, 2025
1 parent b89b941 commit e7de20b
Showing 3 changed files with 28 additions and 3 deletions.
src/adapters/model_mixin.py: 11 additions & 2 deletions
@@ -1675,8 +1675,17 @@ def gradient_checkpointing_function(function, *args, **kwargs):
             )
 
         # >>> START AH Changes <<<
-        # For adapter training, we always require requires_grad=True for the input embeddings.
-        self.enable_input_require_grads()
+        # For adapter training, we set requires_grad=True for the input embeddings, just like Hugging Face does for training with PEFT.
+        try:
+            self.enable_input_require_grads()
+        except NotImplementedError:
+            # Some models (e.g. CLIP) don't have input embeddings, so Hugging Face's implementation raises a NotImplementedError.
+            logger.warning(
+                "Model does not have input embeddings, so model.enable_input_require_grads() is not implemented for it. Gradient checkpointing should nevertheless work. If you encounter errors or unexpected behaviour, this might be the reason; in that case, please implement the method on the model yourself or open an issue on our GitHub."
+            )
+        except Exception as e:
+            # Every other exception is unexpected and should be raised.
+            raise e
         # >>> END AH Changes <<<


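To illustrate the effect of this change, here is a minimal usage sketch (not part of the commit; the checkpoint and adapter names are placeholders). With the try/except above, enabling gradient checkpointing on a model without input embeddings, such as CLIP, logs a warning instead of raising NotImplementedError:

import adapters
from transformers import CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
adapters.init(model)                  # add adapter support to the model
model.add_adapter("demo_adapter")
model.train_adapter("demo_adapter")

# Previously this call failed because CLIPModel does not implement
# enable_input_require_grads(); with the fix it only emits the warning above.
model.gradient_checkpointing_enable()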
src/adapters/models/mt5/modeling_mt5.py: 15 additions & 1 deletion
@@ -419,8 +419,22 @@ def forward(
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             if self.gradient_checkpointing and self.training:
+                # >>> START AH Changes <<<
+                # Without this change, MT5 training with gradient checkpointing will fail for ReFT.
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # Ensure all inputs are on the same device
+                        inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs)
+                        return module(*inputs)
+
+                    return custom_forward
+
+                # >>> END AH Changes <<<
+
                 layer_outputs = self._gradient_checkpointing_func(
-                    layer_module.forward,
+                    # >>> START AH Changes <<<
+                    create_custom_forward(layer_module),
+                    # >>> END AH Changes <<<
                     hidden_states,
                     causal_mask,
                     position_bias,
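As a standalone sketch of the wrapper pattern used above (illustrative only; a plain Linear layer stands in for the MT5 block and the shapes are arbitrary): torch.utils.checkpoint re-runs the wrapped callable during the backward pass, and moving every tensor argument onto the device of the first one keeps the recomputed inputs aligned, which is what the "same device" comment in the diff is guarding:

import torch
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module):
    def custom_forward(*inputs):
        # Move every tensor argument to the device of the first input before calling the module
        inputs = tuple(x.to(inputs[0].device) if isinstance(x, torch.Tensor) else x for x in inputs)
        return module(*inputs)

    return custom_forward

layer = torch.nn.Linear(16, 16)
hidden_states = torch.randn(2, 16, requires_grad=True)

# The checkpointed call stores no intermediate activations; the wrapped
# forward is recomputed when backward() runs.
out = checkpoint(create_custom_forward(layer), hidden_states, use_reentrant=False)
out.sum().backward()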
src/adapters/models/t5/modeling_t5.py: 2 additions & 0 deletions
@@ -432,7 +432,9 @@ def custom_forward(*inputs):
                 # >>> END AH Changes <<<
 
                 layer_outputs = self._gradient_checkpointing_func(
+                    # >>> START AH Changes <<<
                     create_custom_forward(layer_module),
+                    # >>> END AH Changes <<<
                     hidden_states,
                     causal_mask,
                     position_bias,

0 comments on commit e7de20b
