Commit 4bf5ffc: couple initial fixes
IlyasMoutawwakil committed Jan 31, 2025
Parent: 3ad88da
Showing 1 changed file with 15 additions and 11 deletions.
optimum/exporters/openvino/model_patcher.py (15 additions, 11 deletions)
@@ -718,14 +718,15 @@ def _mistral_update_causal_mask(
 class MistralModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
-        if is_transformers_version(">=", "4.42.0"):
+        if is_transformers_version(">=", "4.42.0") and is_transformers_version("<", "4.48.0"):
             # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)
 
         else:
             for layer in self._model.model.layers:
-                _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
+                if hasattr(layer.self_attn, "rotary_emb"):
+                    _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
@@ -734,7 +735,7 @@ def __exit__(self, exc_type, exc_value, traceback):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
 
         for layer in self._model.model.layers:
-            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
+            if hasattr(layer.self_attn, "rotary_emb") and hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
                 layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward


@@ -2493,7 +2494,9 @@ class UpdateCausalMaskModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         patch_update_causal_mask(self._model, "4.42.0")
-        if hasattr(self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"):
+        if hasattr(self._model.model.layers[0].self_attn, "rotary_emb") and hasattr(
+            self._model.model.layers[0].self_attn.rotary_emb, "_set_cos_sin_cache"
+        ):
             for layer in self._model.model.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
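
A note on the recurring pattern in the three hunks above: every per-layer rotary_emb access is now guarded, since the transformers 4.48 attention refactor moved rotary embeddings off the per-layer self_attn module. A minimal standalone sketch of that guard, with a hypothetical reinit_fn standing in for _reinitialize_cos_sin_cached_fp32:

def reinit_rope_fp32_if_present(model, reinit_fn):
    # Hypothetical helper, not part of the patch itself: re-initialize each
    # layer's rotary embedding in fp32 only when the layer still owns one
    # (transformers >= 4.48 keeps rotary embeddings on the model instead).
    for layer in model.model.layers:
        rotary_emb = getattr(layer.self_attn, "rotary_emb", None)
        if rotary_emb is not None:
            reinit_fn(rotary_emb)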

@@ -3045,15 +3048,16 @@ def patched_forward(self, fn):
     def __enter__(self):
         if is_torch_version(">=", "2.1.0"):
             if self._model.config.model_type == "qwen2" and self._model.config._attn_implementation != "sdpa":
-                from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
+                if is_transformers_version("<", "4.48"):
+                    from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES
 
-                sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
-                self._model.config._orig_attn_implementation = self._model.config._attn_implementation
-                self._model.config._attn_implementation = "sdpa"
+                    sdpa_attn = QWEN2_ATTENTION_CLASSES["sdpa"]
+                    self._model.config._orig_attn_implementation = self._model.config._attn_implementation
+                    self._model.config._attn_implementation = "sdpa"
 
-                for layer in self._model.model.layers:
-                    layer.self_attn._orig_forward = layer.self_attn.forward
-                    layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
+                    for layer in self._model.model.layers:
+                        layer.self_attn._orig_forward = layer.self_attn.forward
+                        layer.self_attn.forward = types.MethodType(sdpa_attn.forward, layer.self_attn)
 
             if self._model.config.model_type == "llama" and self._model.config._attn_implementation != "sdpa":
                 self._model.config._orig_attn_implementation = self._model.config._attn_implementation
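
The Qwen2 hunk applies the same idea one step earlier: QWEN2_ATTENTION_CLASSES only exists before the transformers 4.48 refactor, so the whole SDPA forward swap is version-gated instead of failing at import time. Below is an equivalent availability check expressed with try/except rather than the version gate, as a rough sketch (the helper name is illustrative, not part of the patch):

def legacy_qwen2_sdpa_available():
    # Illustrative sketch only: the QWEN2_ATTENTION_CLASSES mapping was removed
    # in the transformers 4.48 attention refactor, so the import itself tells
    # us whether the legacy SDPA attention class can still be patched in.
    try:
        from transformers.models.qwen2.modeling_qwen2 import QWEN2_ATTENTION_CLASSES  # noqa: F401
    except ImportError:
        return False
    return True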
