Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Whisper in new concise nn.Module API #868

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/mlc_chat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

MLC Chat is the app runtime of MLC LLM.
"""
from . import protocol, serve

# from . import protocol, serve
from .chat_module import ChatConfig, ChatModule, ConvConfig, GenerationConfig
from .libinfo import __version__
5 changes: 3 additions & 2 deletions python/mlc_chat/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Load MLC LLM library and _ffi_api functions."""

import ctypes
import os
import sys
Expand All @@ -24,5 +25,5 @@ def _load_mlc_llm_lib():


# only load once here
if SKIP_LOADING_MLCLLM_SO == "0":
_LIB, _LIB_PATH = _load_mlc_llm_lib()
# if SKIP_LOADING_MLCLLM_SO == "0":
# _LIB, _LIB_PATH = _load_mlc_llm_lib()
1 change: 1 addition & 0 deletions python/mlc_chat/compiler_pass/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""The compilation pipeline for LLM applications."""

from pathlib import Path
from typing import Any, Dict, List, Optional

Expand Down
1 change: 1 addition & 0 deletions python/mlc_chat/interface/compile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Python entrypoint of compilation."""

import dataclasses
import math
from io import StringIO
Expand Down
14 changes: 14 additions & 0 deletions python/mlc_chat/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from .phi import phi_loader, phi_model, phi_quantization
from .qwen import qwen_loader, qwen_model, qwen_quantization
from .stable_lm import stablelm_loader, stablelm_model, stablelm_quantization
from .whisper import whisper_loader, whisper_model, whisper_quantization

ModelConfig = Any
"""A ModelConfig is an object that represents a model architecture. It is required to have
Expand Down Expand Up @@ -195,4 +196,17 @@ class Model:
"group-quant": stablelm_quantization.group_quant,
},
),
"whisper": Model(
name="whisper",
model=whisper_model.WhisperForConditionalGeneration,
config=whisper_model.WhisperConfig,
source={
"huggingface-torch": whisper_loader.huggingface,
"huggingface-safetensor": whisper_loader.huggingface,
},
quantize={
"no-quant": whisper_quantization.no_quant,
"group-quant": whisper_quantization.group_quant,
},
),
}
Empty file.
51 changes: 51 additions & 0 deletions python/mlc_chat/model/whisper/whisper_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
This file specifies how MLC's Whisper parameter maps from other formats, for example HuggingFace
PyTorch, HuggingFace safetensors.
"""

import functools

from mlc_chat.loader import ExternMapping
from mlc_chat.quantization import Quantization

from .whisper_model import WhisperConfig, WhisperForConditionalGeneration


def huggingface(model_config: WhisperConfig, quantization: Quantization) -> ExternMapping:
"""Returns a parameter mapping that maps from the names of MLC LLM parameters to
the names of HuggingFace PyTorch parameters.

Parameters
----------
model_config : WhisperConfig
The configuration of the GPTNeoX model.

quantization : Quantization
The quantization configuration.

Returns
-------
param_map : ExternMapping
The parameter mapping from MLC to HuggingFace PyTorch.
"""
model = WhisperForConditionalGeneration(model_config)
if quantization is not None:
model.to(quantization.model_dtype)
_, _named_params, _ = model.export_tvm( # type: ignore[misc]
spec=model.get_default_spec(),
allow_extern=True,
)
named_parameters = dict(_named_params)

mapping = ExternMapping()

for mlc_name, mlc_param in named_parameters.items():
mapping.add_mapping(
mlc_name,
[mlc_name],
functools.partial(
lambda x, dtype: x.astype(dtype),
dtype=mlc_param.dtype,
),
)
return mapping
Loading