Skip to content

Commit

Permalink
Formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Jan 18, 2025
1 parent e668d89 commit 121ed80
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
10 changes: 5 additions & 5 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1753,12 +1753,13 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
# def inputs(self) -> Dict[str, Dict[int, str]]:
# return {"input_features": {0: "batch_size", 1: "sequence_classification"}}


class MoonshineOnnxConfig(AudioToTextOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedSeq2SeqConfig

# torch.onnx.errors.UnsupportedOperatorError: Exporting the operator 'aten::triu' to ONNX opset version 11 is not supported.
# Support for this operator was added in version 14, try exporting with this version.
DEFAULT_ONNX_OPSET = 14
DEFAULT_ONNX_OPSET = 14

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
Expand All @@ -1780,7 +1781,6 @@ def inputs(self) -> Dict[str, Dict[int, str]]:
return common_inputs



class WhisperOnnxConfig(AudioToTextOnnxConfig):
DEFAULT_ONNX_OPSET = 14 # Whisper now uses F.scaled_dot_product_attention by default for torch>=2.1.1.

Expand Down Expand Up @@ -2329,9 +2329,9 @@ def outputs(self) -> Dict[str, Dict[int, str]]:
# for Speech2text, we need to name the second axis as
# encoder_sequence_length / 2 * self._config.num_conv_layers as the axis name is
# used for dummy input generation
common_outputs["last_hidden_state"][
1
] = f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}"
common_outputs["last_hidden_state"][1] = (
f"{common_outputs['last_hidden_state'][1]} / {(2 * self._config.num_conv_layers)}"
)
return common_outputs


Expand Down
14 changes: 6 additions & 8 deletions optimum/exporters/onnx/model_patcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
input_tensor (torch.Tensor): The input tensor.
repeats (int or torch.Tensor): The number of repetitions for each element.
dim (int, optional): The dimension along which to repeat. Defaults to None.
Returns:
torch.Tensor: The repeated tensor.
"""
Expand Down Expand Up @@ -199,32 +199,30 @@ def onnx_compatible_repeat_interleave(input_tensor, repeats, dim=None):
# Patches applied during ONNX export: each spec swaps a torch.Tensor method for an
# ONNX-exportable replacement (the original callable is kept as the last argument so
# it can be restored after export).
UNSUPPORTED_OPS_PATCHING_SPEC = [
    PatchingSpec(torch.Tensor, "unfold", onnx_compatible_unfold, torch.Tensor.unfold),
    PatchingSpec(torch.Tensor, "repeat_interleave", onnx_compatible_repeat_interleave, torch.Tensor.repeat_interleave),

    # TracerWarning: Using len to get tensor shape might cause the trace to be incorrect. Recommended usage would be tensor.shape[0]. Passing a tensor of different shape might lead to errors or silently give incorrect results.
    PatchingSpec(torch.Tensor, "__len__", lambda x: x.shape[0], torch.Tensor.__len__),
]


def patched_module_call(self, *args, **kwargs):
    """Call a module while converting legacy `past_key_values` tuples to Cache objects.

    Recent `transformers` models expect a Cache object, but the ONNX export machinery
    feeds legacy tuples. This wrapper converts on the way in and converts back to the
    legacy tuple format on the way out so the JIT tracer only ever sees tuples.

    Args:
        self: The wrapped ``torch.nn.Module``.
        *args: Positional arguments forwarded to the module call.
        **kwargs: Keyword arguments forwarded to the module call; ``past_key_values``,
            if present and not None, is converted from its legacy tuple form.

    Returns:
        The module output, with ``past_key_values`` (if any) converted back to the
        legacy cache format.

    Raises:
        ValueError: If the per-layer tuple in ``past_key_values`` has neither 2 items
            (decoder-only -> DynamicCache) nor 4 items (encoder-decoder ->
            EncoderDecoderCache).
    """
    if kwargs.get("past_key_values") is not None:
        # The number of tensors per layer tells us which cache class to build.
        num_items = len(kwargs["past_key_values"][0])
        if num_items == 2:
            cls = transformers.DynamicCache
        elif num_items == 4:
            cls = transformers.EncoderDecoderCache
        else:
            raise ValueError(f"Unexpected number of items in past_key_values: {num_items}")
        kwargs["past_key_values"] = cls.from_legacy_cache(kwargs["past_key_values"])

    # NOTE: We cannot use .forward directly as this will
    # lose optimization opportunities in the ONNX export.
    output = self._wrapped_call_impl(*args, **kwargs)

    # RuntimeError: Only tuples, lists and Variables are supported as JIT inputs/outputs.
    # Dictionaries and strings are also accepted, but their usage is not recommended.
    # Here, received an input of unsupported type: XXXCache
    if getattr(output, "past_key_values", None) is not None and hasattr(output.past_key_values, "to_legacy_cache"):
        output.past_key_values = output.past_key_values.to_legacy_cache()
    return output

Expand Down

0 comments on commit 121ed80

Please sign in to comment.