Add OpenVINO pipelines (#740)
* add openvino accelerator for pipelines

* add warning

* add test

* format

* load image preprocessor if needed

* rename

* add max_new_tokens

* fix ov model detection
echarlaix authored Jun 13, 2024
1 parent 0a6075b commit 0c2217d
Showing 3 changed files with 398 additions and 55 deletions.
229 changes: 203 additions & 26 deletions optimum/intel/pipelines/pipeline_base.py
@@ -16,23 +16,40 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import torch
from transformers import (
    AudioClassificationPipeline,
    AutoConfig,
    AutoFeatureExtractor,
    AutoImageProcessor,
    AutomaticSpeechRecognitionPipeline,
    AutoTokenizer,
    FeatureExtractionPipeline,
    FillMaskPipeline,
    ImageClassificationPipeline,
    ImageToTextPipeline,
    Pipeline,
    PreTrainedTokenizer,
    PreTrainedTokenizerFast,
    QuestionAnsweringPipeline,
    SummarizationPipeline,
    Text2TextGenerationPipeline,
    TextClassificationPipeline,
    TextGenerationPipeline,
    TokenClassificationPipeline,
    TranslationPipeline,
    ZeroShotClassificationPipeline,
)
from transformers import pipeline as transformers_pipeline
from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
from transformers.utils import logging

from optimum.intel.utils.import_utils import (
    IPEX_IMPORT_ERROR,
    OPENVINO_IMPORT_ERROR,
    is_ipex_available,
    is_openvino_available,
)
from optimum.intel.utils.modeling_utils import _find_files_matching_pattern


if is_ipex_available():
@@ -95,21 +112,161 @@
    IPEX_SUPPORTED_TASKS = {}


if is_openvino_available():
    from ..openvino import (
        OVModelForAudioClassification,
        OVModelForCausalLM,
        OVModelForFeatureExtraction,
        OVModelForImageClassification,
        OVModelForMaskedLM,
        OVModelForQuestionAnswering,
        OVModelForSeq2SeqLM,
        OVModelForSequenceClassification,
        OVModelForSpeechSeq2Seq,
        OVModelForTokenClassification,
        OVModelForVision2Seq,
    )
    from ..openvino.modeling_base import OVBaseModel

    OPENVINO_SUPPORTED_TASKS = {
        "feature-extraction": {
            "impl": FeatureExtractionPipeline,
            "class": (OVModelForFeatureExtraction,),
            "default": "distilbert-base-cased",
            "type": "text",  # feature extraction is only supported for text at the moment
        },
        "fill-mask": {
            "impl": FillMaskPipeline,
            "class": (OVModelForMaskedLM,),
            "default": "bert-base-cased",
            "type": "text",
        },
        "image-classification": {
            "impl": ImageClassificationPipeline,
            "class": (OVModelForImageClassification,),
            "default": "google/vit-base-patch16-224",
            "type": "image",
        },
        "question-answering": {
            "impl": QuestionAnsweringPipeline,
            "class": (OVModelForQuestionAnswering,),
            "default": "distilbert-base-cased-distilled-squad",
            "type": "text",
        },
        "text-classification": {
            "impl": TextClassificationPipeline,
            "class": (OVModelForSequenceClassification,),
            "default": "distilbert-base-uncased-finetuned-sst-2-english",
            "type": "text",
        },
        "text-generation": {
            "impl": TextGenerationPipeline,
            "class": (OVModelForCausalLM,),
            "default": "distilgpt2",
            "type": "text",
        },
        "token-classification": {
            "impl": TokenClassificationPipeline,
            "class": (OVModelForTokenClassification,),
            "default": "dbmdz/bert-large-cased-finetuned-conll03-english",
            "type": "text",
        },
        "zero-shot-classification": {
            "impl": ZeroShotClassificationPipeline,
            "class": (OVModelForSequenceClassification,),
            "default": "facebook/bart-large-mnli",
            "type": "text",
        },
        "summarization": {
            "impl": SummarizationPipeline,
            "class": (OVModelForSeq2SeqLM,),
            "default": "t5-base",
            "type": "text",
        },
        "translation": {
            "impl": TranslationPipeline,
            "class": (OVModelForSeq2SeqLM,),
            "default": "t5-small",
            "type": "text",
        },
        "text2text-generation": {
            "impl": Text2TextGenerationPipeline,
            "class": (OVModelForSeq2SeqLM,),
            "default": "t5-small",
            "type": "text",
        },
        "automatic-speech-recognition": {
            "impl": AutomaticSpeechRecognitionPipeline,
            "class": (OVModelForSpeechSeq2Seq,),
            "default": "openai/whisper-tiny.en",
            "type": "multimodal",
        },
        "image-to-text": {
            "impl": ImageToTextPipeline,
            "class": (OVModelForVision2Seq,),
            "default": "nlpconnect/vit-gpt2-image-captioning",
            "type": "multimodal",
        },
        "audio-classification": {
            "impl": AudioClassificationPipeline,
            "class": (OVModelForAudioClassification,),
            "default": "superb/hubert-base-superb-ks",
            "type": "audio",
        },
    }
else:
    OPENVINO_SUPPORTED_TASKS = {}


def load_openvino_model(
    model,
    targeted_task,
    SUPPORTED_TASKS,
    hub_kwargs: Optional[Dict[str, Any]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    hub_kwargs = hub_kwargs or {}
    model_kwargs = model_kwargs or {}
    ov_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]

    if model is None:
        model_id = SUPPORTED_TASKS[targeted_task]["default"]
        model = ov_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
    elif isinstance(model, str):
        model_id = model
        pattern = r"(.*)?openvino(.*)?\_model.xml"
        ov_files = _find_files_matching_pattern(
            model,
            pattern,
            use_auth_token=hub_kwargs.get("token", None),
            revision=hub_kwargs.get("revision", None),
        )
        export = len(ov_files) == 0
        model = ov_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
    elif isinstance(model, OVBaseModel):
        model_id = model.model_save_dir
    else:
        raise ValueError(
            f"""Model {model} is not supported. Please provide a valid model, either as a string or an OVModel.
            You can also provide no model, in which case a default one will be used for the task."""
        )
    return model, model_id


def load_ipex_model(
    model,
    targeted_task,
    SUPPORTED_TASKS,
    hub_kwargs: Optional[Dict[str, Any]] = None,
    model_kwargs: Optional[Dict[str, Any]] = None,
):
    hub_kwargs = hub_kwargs or {}
    model_kwargs = model_kwargs or {}
    ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]

    if model is None:
        model_id = SUPPORTED_TASKS[targeted_task]["default"]
        model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
    elif isinstance(model, str):
        model_id = model
        try:
@@ -118,7 +275,7 @@ def load_ipex_model(
        except RuntimeError:
            logger.warning("We will use IPEXModel with export=True to export the model")
            export = True
        model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
    elif isinstance(model, IPEXModel):
        model_id = getattr(model.config, "name_or_path", None)
    else:
@@ -132,6 +289,7 @@

MAPPING_LOADING_FUNC = {
    "ipex": load_ipex_model,
    "openvino": load_openvino_model,
}


@@ -154,8 +312,8 @@ def pipeline(
    revision: Optional[str] = None,
    trust_remote_code: Optional[bool] = None,
    torch_dtype: Optional[Union[str, torch.dtype]] = None,
    model_kwargs: Dict[str, Any] = None,
    **kwargs,
) -> Pipeline:
"""
Utility factory method to build a [`Pipeline`].
@@ -214,9 +372,12 @@
    >>> pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16)
    >>> pipe("Describe a real-world application of AI in sustainable energy.")
    ```"""

    if model_kwargs is None:
        model_kwargs = {}

    commit_hash = kwargs.pop("_commit_hash", None)

    if task is None and model is None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline without either a task or a model "
@@ -240,12 +401,19 @@
        raise ValueError(msg + f" Supported list of `accelerator` is : {', '.join(MAPPING_LOADING_FUNC)}.")

if accelerator == "ipex":
if task not in list(IPEX_SUPPORTED_TASKS.keys()):
raise ValueError(
f"Task {task} is not supported for the IPEX pipeline. Supported tasks are { list(IPEX_SUPPORTED_TASKS.keys())}"
)
if not is_ipex_available():
raise RuntimeError(IPEX_IMPORT_ERROR.format("`accelerator=ipex`"))
supported_tasks = IPEX_SUPPORTED_TASKS

supported_tasks = IPEX_SUPPORTED_TASKS if accelerator == "ipex" else None
if accelerator == "openvino":
if not is_openvino_available():
raise RuntimeError(OPENVINO_IMPORT_ERROR.format("`accelerator=openvino`"))
supported_tasks = OPENVINO_SUPPORTED_TASKS

if task not in supported_tasks:
raise ValueError(
f"Task {task} is not supported for the {accelerator} pipelines. Supported tasks are {', '.join(supported_tasks)}"
)

    no_feature_extractor_tasks = set()
    no_tokenizer_tasks = set()
@@ -272,6 +440,7 @@
    if isinstance(model, Path):
        model = str(model)

    tokenizer_kwargs = model_kwargs.copy()
    if torch_dtype is not None:
        if "torch_dtype" in model_kwargs:
            raise ValueError(
@@ -280,20 +449,28 @@
            )
        model_kwargs["torch_dtype"] = torch_dtype

    # Load the correct model and convert it to the expected format if needed
    model, model_id = MAPPING_LOADING_FUNC[accelerator](
        model,
        task,
        SUPPORTED_TASKS=supported_tasks,
        hub_kwargs=hub_kwargs,
        model_kwargs=model_kwargs,
        **kwargs,
    )

    if load_tokenizer and tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs, **tokenizer_kwargs)
    if load_feature_extractor and feature_extractor is None:
        try:
            feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **tokenizer_kwargs)
        except Exception:
            feature_extractor = AutoImageProcessor.from_pretrained(model_id, **hub_kwargs, **tokenizer_kwargs)

    return transformers_pipeline(
        task,
        model=model,
        tokenizer=tokenizer,
        feature_extractor=feature_extractor,
        torch_dtype=torch_dtype,
    )
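
For context, a minimal usage sketch of the pipeline factory changed above (not part of the diff; it assumes optimum-intel is installed with the openvino extra, and the model id and prompt are illustrative):

from optimum.intel.pipelines import pipeline

# accelerator="openvino" routes through load_openvino_model above; distilgpt2
# ships no openvino_model.xml, so the model is exported to OpenVINO IR on the fly.
ov_pipe = pipeline("text-generation", "distilgpt2", accelerator="openvino")

# Generation kwargs such as max_new_tokens are forwarded to the underlying
# TextGenerationPipeline (cf. the "add max_new_tokens" item in the commit message).
print(ov_pipe("The weather is", max_new_tokens=20))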
51 changes: 50 additions & 1 deletion optimum/intel/utils/modeling_utils.py
@@ -12,9 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from pathlib import Path
from typing import List, Optional, Tuple, Union

import torch
from huggingface_hub import HfApi, HfFolder
from transformers.modeling_utils import PreTrainedModel


@@ -191,3 +194,49 @@ def _setattr_from_module(new_module, module):
        if k.startswith("__") or k.startswith("forward"):
            continue
        setattr(new_module.__class__, k, getattr(module.__class__, k))


def _find_files_matching_pattern(
    model_name_or_path: Union[str, Path],
    pattern: str,
    subfolder: str = "",
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
) -> List[Path]:
    """
    Scans either a model repo or a local directory to find filenames matching the pattern.

    Args:
        model_name_or_path (`Union[str, Path]`):
            The name of the model repo on the Hugging Face Hub or the path to a local directory.
        pattern (`str`):
            The pattern to use to look for files.
        subfolder (`str`, defaults to `""`):
            In case the model files are located inside a subfolder of the model directory / repo on the Hugging
            Face Hub, you can specify the subfolder name here.
        use_auth_token (`Optional[bool, str]`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision (`Optional[str]`, defaults to `None`):
            Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.

    Returns:
        `List[Path]`
    """
    model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path
    pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern)
    subfolder = subfolder or "."

    if model_path.is_dir():
        glob_pattern = subfolder + "/*"
        files = model_path.glob(glob_pattern)
        files = [p for p in files if re.search(pattern, str(p))]
    else:
        if isinstance(use_auth_token, bool):
            token = HfFolder().get_token()
        else:
            token = use_auth_token
        repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token))
        files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) == subfolder]

    return files
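
As a rough illustration of how load_openvino_model in pipeline_base.py uses this helper (the local path and the outcome are assumptions for the example, not part of the commit):

from pathlib import Path

from optimum.intel.utils.modeling_utils import _find_files_matching_pattern

# Same pattern as in load_openvino_model: matches IR files such as
# "openvino_model.xml" or "openvino_decoder_model.xml".
ov_files = _find_files_matching_pattern(Path("./my_exported_model"), r"(.*)?openvino(.*)?\_model.xml")

# An empty result means no OpenVINO IR was found, so the pipeline re-exports the model.
export = len(ov_files) == 0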
