diff --git a/optimum/intel/pipelines/pipeline_base.py b/optimum/intel/pipelines/pipeline_base.py
index 6e0afa6c21..39f48df27f 100644
--- a/optimum/intel/pipelines/pipeline_base.py
+++ b/optimum/intel/pipelines/pipeline_base.py
@@ -16,23 +16,40 @@
 from typing import TYPE_CHECKING, Any, Dict, Optional, Union
 
 import torch
-from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer
-from transformers import pipeline as transformers_pipeline
-from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
-from transformers.pipelines import (
+from transformers import (
     AudioClassificationPipeline,
+    AutoConfig,
+    AutoFeatureExtractor,
+    AutoImageProcessor,
+    AutomaticSpeechRecognitionPipeline,
+    AutoTokenizer,
+    FeatureExtractionPipeline,
     FillMaskPipeline,
     ImageClassificationPipeline,
+    ImageToTextPipeline,
+    Pipeline,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
     QuestionAnsweringPipeline,
+    SummarizationPipeline,
+    Text2TextGenerationPipeline,
     TextClassificationPipeline,
     TextGenerationPipeline,
     TokenClassificationPipeline,
+    TranslationPipeline,
+    ZeroShotClassificationPipeline,
 )
-from transformers.pipelines.base import Pipeline
-from transformers.tokenization_utils import PreTrainedTokenizer
+from transformers import pipeline as transformers_pipeline
+from transformers.feature_extraction_utils import PreTrainedFeatureExtractor
 from transformers.utils import logging
 
-from optimum.intel.utils import is_ipex_available
+from optimum.intel.utils.import_utils import (
+    IPEX_IMPORT_ERROR,
+    OPENVINO_IMPORT_ERROR,
+    is_ipex_available,
+    is_openvino_available,
+)
+from optimum.intel.utils.modeling_utils import _find_files_matching_pattern
 
 
 if is_ipex_available():
@@ -95,21 +112,161 @@
     IPEX_SUPPORTED_TASKS = {}
 
 
-def load_ipex_model(
+if is_openvino_available():
+    from ..openvino import (
+        OVModelForAudioClassification,
+        OVModelForCausalLM,
+        OVModelForFeatureExtraction,
+        OVModelForImageClassification,
+        OVModelForMaskedLM,
+        OVModelForQuestionAnswering,
+        OVModelForSeq2SeqLM,
+        OVModelForSequenceClassification,
+        OVModelForSpeechSeq2Seq,
+        OVModelForTokenClassification,
+        OVModelForVision2Seq,
+    )
+    from ..openvino.modeling_base import OVBaseModel
+
+    OPENVINO_SUPPORTED_TASKS = {
+        "feature-extraction": {
+            "impl": FeatureExtractionPipeline,
+            "class": (OVModelForFeatureExtraction,),
+            "default": "distilbert-base-cased",
+            "type": "text",  # feature extraction is only supported for text at the moment
+        },
+        "fill-mask": {
+            "impl": FillMaskPipeline,
+            "class": (OVModelForMaskedLM,),
+            "default": "bert-base-cased",
+            "type": "text",
+        },
+        "image-classification": {
+            "impl": ImageClassificationPipeline,
+            "class": (OVModelForImageClassification,),
+            "default": "google/vit-base-patch16-224",
+            "type": "image",
+        },
+        "question-answering": {
+            "impl": QuestionAnsweringPipeline,
+            "class": (OVModelForQuestionAnswering,),
+            "default": "distilbert-base-cased-distilled-squad",
+            "type": "text",
+        },
+        "text-classification": {
+            "impl": TextClassificationPipeline,
+            "class": (OVModelForSequenceClassification,),
+            "default": "distilbert-base-uncased-finetuned-sst-2-english",
+            "type": "text",
+        },
+        "text-generation": {
+            "impl": TextGenerationPipeline,
+            "class": (OVModelForCausalLM,),
+            "default": "distilgpt2",
+            "type": "text",
+        },
+        "token-classification": {
+            "impl": TokenClassificationPipeline,
+            "class": (OVModelForTokenClassification,),
+            "default": "dbmdz/bert-large-cased-finetuned-conll03-english",
+            "type": "text",
+        },
+        "zero-shot-classification": {
+            "impl": ZeroShotClassificationPipeline,
+            "class": (OVModelForSequenceClassification,),
+            "default": "facebook/bart-large-mnli",
+            "type": "text",
+        },
+        "summarization": {
+            "impl": SummarizationPipeline,
+            "class": (OVModelForSeq2SeqLM,),
+            "default": "t5-base",
+            "type": "text",
+        },
+        "translation": {
+            "impl": TranslationPipeline,
+            "class": (OVModelForSeq2SeqLM,),
+            "default": "t5-small",
+            "type": "text",
+        },
+        "text2text-generation": {
+            "impl": Text2TextGenerationPipeline,
+            "class": (OVModelForSeq2SeqLM,),
+            "default": "t5-small",
+            "type": "text",
+        },
+        "automatic-speech-recognition": {
+            "impl": AutomaticSpeechRecognitionPipeline,
+            "class": (OVModelForSpeechSeq2Seq,),
+            "default": "openai/whisper-tiny.en",
+            "type": "multimodal",
+        },
+        "image-to-text": {
+            "impl": ImageToTextPipeline,
+            "class": (OVModelForVision2Seq,),
+            "default": "nlpconnect/vit-gpt2-image-captioning",
+            "type": "multimodal",
+        },
+        "audio-classification": {
+            "impl": AudioClassificationPipeline,
+            "class": (OVModelForAudioClassification,),
+            "default": "superb/hubert-base-superb-ks",
+            "type": "audio",
+        },
+    }
+else:
+    OPENVINO_SUPPORTED_TASKS = {}
+
+
+def load_openvino_model(
     model,
     targeted_task,
     SUPPORTED_TASKS,
-    model_kwargs: Optional[Dict[str, Any]] = None,
     hub_kwargs: Optional[Dict[str, Any]] = None,
+    model_kwargs: Optional[Dict[str, Any]] = None,
 ):
-    if model_kwargs is None:
-        model_kwargs = {}
+    hub_kwargs = hub_kwargs or {}
+    model_kwargs = model_kwargs or {}
+    ov_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
+
+    if model is None:
+        model_id = SUPPORTED_TASKS[targeted_task]["default"]
+        model = ov_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
+    elif isinstance(model, str):
+        model_id = model
+        pattern = r"(.*)?openvino(.*)?\_model.xml"
+        ov_files = _find_files_matching_pattern(
+            model,
+            pattern,
+            use_auth_token=hub_kwargs.get("token", None),
+            revision=hub_kwargs.get("revision", None),
+        )
+        export = len(ov_files) == 0
+        model = ov_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
+    elif isinstance(model, OVBaseModel):
+        model_id = model.model_save_dir
+    else:
+        raise ValueError(
+            f"""Model {model} is not supported. Please provide a valid model either as string or OVModel.
+            You can also provide no model, in which case a default one will be used for the task."""
+        )
+    return model, model_id
+
+def load_ipex_model(
+    model,
+    targeted_task,
+    SUPPORTED_TASKS,
+    hub_kwargs: Optional[Dict[str, Any]] = None,
+    model_kwargs: Optional[Dict[str, Any]] = None,
+):
+    hub_kwargs = hub_kwargs or {}
+    model_kwargs = model_kwargs or {}
 
     ipex_model_class = SUPPORTED_TASKS[targeted_task]["class"][0]
 
     if model is None:
         model_id = SUPPORTED_TASKS[targeted_task]["default"]
-        model = ipex_model_class.from_pretrained(model_id, export=True, **model_kwargs, **hub_kwargs)
+        model = ipex_model_class.from_pretrained(model_id, export=True, **hub_kwargs, **model_kwargs)
     elif isinstance(model, str):
         model_id = model
         try:
@@ -118,7 +275,7 @@ def load_ipex_model(
         except RuntimeError:
             logger.warning("We will use IPEXModel with export=True to export the model")
             export = True
-        model = ipex_model_class.from_pretrained(model, export=export, **model_kwargs, **hub_kwargs)
+        model = ipex_model_class.from_pretrained(model, export=export, **hub_kwargs, **model_kwargs)
     elif isinstance(model, IPEXModel):
         model_id = getattr(model.config, "name_or_path", None)
     else:
@@ -132,6 +289,7 @@
 
 MAPPING_LOADING_FUNC = {
     "ipex": load_ipex_model,
+    "openvino": load_openvino_model,
 }
 
 
@@ -154,8 +312,8 @@
     revision: Optional[str] = None,
     trust_remote_code: Optional[bool] = None,
     torch_dtype: Optional[Union[str, torch.dtype]] = None,
-    commit_hash: Optional[str] = None,
-    **model_kwargs,
+    model_kwargs: Optional[Dict[str, Any]] = None,
+    **kwargs,
 ) -> Pipeline:
     """
     Utility factory method to build a [`Pipeline`].
@@ -214,9 +372,12 @@
     >>> pipe = pipeline('text-generation', 'gpt2', torch_dtype=torch.bfloat16)
     >>> pipe("Describe a real-world application of AI in sustainable energy.")
     ```"""
+
    if model_kwargs is None:
        model_kwargs = {}

+    commit_hash = kwargs.pop("_commit_hash", None)
+
    if task is None and model is None:
        raise RuntimeError(
            "Impossible to instantiate a pipeline without either a task or a model "
@@ -240,12 +401,19 @@
        raise ValueError(msg + f" Supported list of `accelerator` is : {', '.join(MAPPING_LOADING_FUNC)}.")

    if accelerator == "ipex":
-        if task not in list(IPEX_SUPPORTED_TASKS.keys()):
-            raise ValueError(
-                f"Task {task} is not supported for the IPEX pipeline. Supported tasks are { list(IPEX_SUPPORTED_TASKS.keys())}"
-            )
+        if not is_ipex_available():
+            raise RuntimeError(IPEX_IMPORT_ERROR.format("`accelerator=ipex`"))
+        supported_tasks = IPEX_SUPPORTED_TASKS

-    supported_tasks = IPEX_SUPPORTED_TASKS if accelerator == "ipex" else None
+    if accelerator == "openvino":
+        if not is_openvino_available():
+            raise RuntimeError(OPENVINO_IMPORT_ERROR.format("`accelerator=openvino`"))
+        supported_tasks = OPENVINO_SUPPORTED_TASKS
+
+    if task not in supported_tasks:
+        raise ValueError(
+            f"Task {task} is not supported for the {accelerator} pipeline. Supported tasks are {', '.join(supported_tasks)}"
+        )
 
     no_feature_extractor_tasks = set()
     no_tokenizer_tasks = set()
@@ -272,6 +440,7 @@
     if isinstance(model, Path):
         model = str(model)
 
+    tokenizer_kwargs = model_kwargs.copy()
     if torch_dtype is not None:
         if "torch_dtype" in model_kwargs:
             raise ValueError(
@@ -280,20 +449,28 @@
             )
         model_kwargs["torch_dtype"] = torch_dtype
 
-    # Load the correct model if possible
-    # Infer the framework from the model if not already defined
-    model, model_id = MAPPING_LOADING_FUNC[accelerator](model, task, supported_tasks, model_kwargs, hub_kwargs)
+    # Load the correct model and convert it to the expected format if needed
+    model, model_id = MAPPING_LOADING_FUNC[accelerator](
+        model,
+        task,
+        SUPPORTED_TASKS=supported_tasks,
+        hub_kwargs=hub_kwargs,
+        model_kwargs=model_kwargs,
+        **kwargs,
+    )
 
     if load_tokenizer and tokenizer is None:
-        tokenizer = AutoTokenizer.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
+        tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=use_fast, **hub_kwargs, **tokenizer_kwargs)
     if load_feature_extractor and feature_extractor is None:
-        feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **model_kwargs)
+        try:
+            feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, **hub_kwargs, **tokenizer_kwargs)
+        except Exception:
+            feature_extractor = AutoImageProcessor.from_pretrained(model_id, **hub_kwargs, **tokenizer_kwargs)
 
     return transformers_pipeline(
         task,
         model=model,
         tokenizer=tokenizer,
         feature_extractor=feature_extractor,
-        use_fast=use_fast,
         torch_dtype=torch_dtype,
     )
diff --git a/optimum/intel/utils/modeling_utils.py b/optimum/intel/utils/modeling_utils.py
index 3541f4f933..49069778e5 100644
--- a/optimum/intel/utils/modeling_utils.py
+++ b/optimum/intel/utils/modeling_utils.py
@@ -12,9 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Tuple
+import re
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
 
 import torch
+from huggingface_hub import HfApi, HfFolder
 from transformers.modeling_utils import PreTrainedModel
 
 
@@ -191,3 +194,49 @@
         if k.startswith("__") or k.startswith("forward"):
             continue
         setattr(new_module.__class__, k, getattr(module.__class__, k))
+
+
+def _find_files_matching_pattern(
+    model_name_or_path: Union[str, Path],
+    pattern: str,
+    subfolder: str = "",
+    use_auth_token: Optional[Union[bool, str]] = None,
+    revision: Optional[str] = None,
+) -> List[Path]:
+    """
+    Scans either a model repo or a local directory to find filenames matching the pattern.
+
+    Args:
+        model_name_or_path (`Union[str, Path]`):
+            The name of the model repo on the Hugging Face Hub or the path to a local directory.
+        pattern (`str`):
+            The pattern to use to look for files.
+        subfolder (`str`, defaults to `""`):
+            In case the model files are located inside a subfolder of the model directory / repo on the Hugging
+            Face Hub, you can specify the subfolder name here.
+        use_auth_token (`Optional[Union[bool, str]]`, defaults to `None`):
+            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
+            when running `transformers-cli login` (stored in `~/.huggingface`).
+        revision (`Optional[str]`, defaults to `None`):
+            The specific model version to use. It can be a branch name, a tag name, or a commit id.
+
+    Returns:
+        `List[Path]`
+    """
+    model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path
+    pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern)
+    subfolder = subfolder or "."
+
+    if model_path.is_dir():
+        glob_pattern = subfolder + "/*"
+        files = model_path.glob(glob_pattern)
+        files = [p for p in files if re.search(pattern, str(p))]
+    else:
+        if isinstance(use_auth_token, bool):
+            token = HfFolder().get_token()
+        else:
+            token = use_auth_token
+        repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token))
+        files = [Path(p) for p in repo_files if re.match(pattern, str(p)) and str(p.parent) == subfolder]
+
+    return files
diff --git a/tests/openvino/test_modeling.py b/tests/openvino/test_modeling.py
index a36aae3c51..143a9f97b4 100644
--- a/tests/openvino/test_modeling.py
+++ b/tests/openvino/test_modeling.py
@@ -77,9 +77,11 @@
     OVStableDiffusionPipeline,
 )
 from optimum.intel.openvino import OV_DECODER_NAME, OV_DECODER_WITH_PAST_NAME, OV_ENCODER_NAME, OV_XML_FILE_NAME
+from optimum.intel.openvino.modeling_base import OVBaseModel
 from optimum.intel.openvino.modeling_seq2seq import OVDecoder, OVEncoder
 from optimum.intel.openvino.modeling_timm import TimmImageProcessor
 from optimum.intel.openvino.utils import _print_compiled_model_properties
+from optimum.intel.pipelines import pipeline as optimum_pipeline
 from optimum.intel.utils.import_utils import is_openvino_version, is_transformers_version
 from optimum.utils import (
     DIFFUSION_MODEL_TEXT_ENCODER_SUBFOLDER,
@@ -258,6 +260,49 @@ def test_load_model_from_hub_private_with_token(self):
         self.assertIsInstance(model.config, PretrainedConfig)
 
 
+class PipelineTest(unittest.TestCase):
+    def test_load_model_from_hub(self):
+        model_id = "echarlaix/tiny-random-PhiForCausalLM"
+
+        # verify that both the PyTorch and OpenVINO models can be loaded (the export argument should be inferred automatically)
+        ov_exported_pipe = optimum_pipeline("text-generation", model_id, revision="pt", accelerator="openvino")
+        ov_pipe = optimum_pipeline("text-generation", model_id, revision="ov", accelerator="openvino")
+        self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
+        self.assertIsInstance(ov_pipe.model, OVBaseModel)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ov_exported_pipe.save_pretrained(tmpdirname)
+            folder_contents = os.listdir(tmpdirname)
+            self.assertTrue(OV_XML_FILE_NAME in folder_contents)
+            self.assertTrue(OV_XML_FILE_NAME.replace(".xml", ".bin") in folder_contents)
+            ov_exported_pipe = optimum_pipeline("text-generation", tmpdirname, accelerator="openvino")
+            self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
+
+        del ov_exported_pipe
+        del ov_pipe
+        gc.collect()
+
+    def test_seq2seq_load_from_hub(self):
+        model_id = "echarlaix/tiny-random-t5"
+        # verify that both the PyTorch and OpenVINO models can be loaded (the export argument should be inferred automatically)
+        ov_exported_pipe = optimum_pipeline("text2text-generation", model_id, accelerator="openvino")
+        ov_pipe = optimum_pipeline("text2text-generation", model_id, revision="ov", accelerator="openvino")
+        self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
+        self.assertIsInstance(ov_pipe.model, OVBaseModel)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            ov_exported_pipe.save_pretrained(tmpdirname)
+            folder_contents = os.listdir(tmpdirname)
+            self.assertTrue(OV_DECODER_WITH_PAST_NAME in folder_contents)
+            self.assertTrue(OV_DECODER_WITH_PAST_NAME.replace(".xml", ".bin") in folder_contents)
+            ov_exported_pipe = optimum_pipeline("text2text-generation", tmpdirname, accelerator="openvino")
+            self.assertIsInstance(ov_exported_pipe.model, OVBaseModel)
+
+        del ov_exported_pipe
+        del ov_pipe
+        gc.collect()
+
+
 class OVModelForSequenceClassificationIntegrationTest(unittest.TestCase):
     SUPPORTED_ARCHITECTURES = (
         "albert",
@@ -304,29 +349,36 @@ def test_compare_to_transformers(self, model_arch):
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForSequenceClassification.from_pretrained(model_id, export=True, compile=False)
         model.eval()
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         pipe = pipeline("text-classification", model=model, tokenizer=tokenizer)
-        text = "This restaurant is awesome"
-        outputs = pipe(text)
+        inputs = "This restaurant is awesome"
+        outputs = pipe(inputs)
         self.assertTrue(model.is_dynamic)
         self.assertEqual(pipe.device, model.device)
         self.assertGreaterEqual(outputs[0]["score"], 0.0)
         self.assertIsInstance(outputs[0]["label"], str)
+
+        ov_pipe = optimum_pipeline("text-classification", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1]["score"], ov_outputs[-1]["score"])
+        del ov_pipe
+
         if model_arch == "bert":
             # Test FP16 conversion
             model.half()
             model.to("cpu")
             model.compile()
-            outputs = pipe(text)
+            outputs = pipe(inputs)
             self.assertGreaterEqual(outputs[0]["score"], 0.0)
             self.assertIsInstance(outputs[0]["label"], str)
             # Test static shapes
             model.reshape(1, 25)
             model.compile()
-            outputs = pipe(text)
+            outputs = pipe(inputs)
             self.assertTrue(not model.is_dynamic)
             self.assertGreaterEqual(outputs[0]["score"], 0.0)
             self.assertIsInstance(outputs[0]["label"], str)
@@ -380,6 +432,7 @@ def test_compare_to_transformers(self, model_arch):
     @pytest.mark.run_slow
     @slow
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForQuestionAnswering.from_pretrained(model_id, export=True)
         model.eval()
@@ -391,7 +444,11 @@
         self.assertEqual(pipe.device, model.device)
         self.assertGreaterEqual(outputs["score"], 0.0)
         self.assertIsInstance(outputs["answer"], str)
+        ov_pipe = optimum_pipeline("question-answering", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(question, context)
+        self.assertEqual(outputs["score"], ov_outputs["score"])
         del model
+        del ov_pipe
         gc.collect()
 
     @pytest.mark.run_slow
@@ -451,14 +508,20 @@
     @pytest.mark.run_slow
     @slow
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForTokenClassification.from_pretrained(model_id, export=True)
         model.eval()
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         pipe = pipeline("token-classification", model=model, tokenizer=tokenizer)
-        outputs = pipe("My Name is Arthur and I live in Lyon.")
+        inputs = "My Name is Arthur and I live in"
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all(item["score"] > 0.0 for item in outputs))
+        ov_pipe = optimum_pipeline("token-classification", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1]["score"], ov_outputs[-1]["score"])
+        del ov_pipe
         del model
         del pipe
         gc.collect()
@@ -503,14 +566,20 @@
     @pytest.mark.run_slow
     @slow
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForFeatureExtraction.from_pretrained(model_id, export=True)
         model.eval()
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         pipe = pipeline("feature-extraction", model=model, tokenizer=tokenizer)
-        outputs = pipe("My Name is Arthur and I live in Lyon.")
+        inputs = "My Name is Arthur and I live in"
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all(all(isinstance(item, float) for item in row) for row in outputs[0]))
+        ov_pipe = optimum_pipeline("feature-extraction", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1][-1][-1], ov_outputs[-1][-1][-1])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -674,6 +743,7 @@
     @pytest.mark.run_slow
     @slow
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_kwargs = {}
         model_id = MODEL_NAMES[model_arch]
         if model_arch in self.REMOTE_CODE_MODELS:
@@ -695,9 +765,22 @@
         model.half()
         model.compile()
         pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
-        outputs = pipe("This is a sample", max_length=20)
+        inputs = "My name is Arthur and I live in"
+        set_seed(SEED)
+        outputs = pipe(inputs, max_new_tokens=5)
         self.assertEqual(pipe.device, model.device)
-        self.assertTrue(all("This is a sample" in item["generated_text"] for item in outputs))
+        self.assertTrue(all(inputs in item["generated_text"] for item in outputs))
+        ov_pipe = optimum_pipeline(
+            "text-generation",
+            model_id,
+            accelerator="openvino",
+            trust_remote_code=model_arch in self.REMOTE_CODE_MODELS,
+            tokenizer=tokenizer if model_arch == "qwen" else None,
+        )
+        set_seed(SEED)
+        ov_outputs = ov_pipe(inputs, max_new_tokens=5)
+        self.assertEqual(outputs[-1]["generated_text"], ov_outputs[-1]["generated_text"])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -944,14 +1027,21 @@
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForMaskedLM.from_pretrained(model_id, export=True)
         model.eval()
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         pipe = pipeline("fill-mask", model=model, tokenizer=tokenizer)
-        outputs = pipe(f"This is a {tokenizer.mask_token}.")
+        inputs = f"This is a {tokenizer.mask_token}."
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all(item["score"] > 0.0 for item in outputs))
+
+        ov_pipe = optimum_pipeline("fill-mask", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1]["score"], ov_outputs[-1]["score"])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -1004,15 +1094,23 @@
     @pytest.mark.run_slow
     @slow
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForImageClassification.from_pretrained(model_id, export=True)
         model.eval()
         preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
         pipe = pipeline("image-classification", model=model, feature_extractor=preprocessor)
-        outputs = pipe("http://images.cocodataset.org/val2017/000000039769.jpg")
+        inputs = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        set_seed(SEED)
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertGreaterEqual(outputs[0]["score"], 0.0)
         self.assertTrue(isinstance(outputs[0]["label"], str))
+        ov_pipe = optimum_pipeline("image-classification", model_id, accelerator="openvino")
+        set_seed(SEED)
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1]["score"], ov_outputs[-1]["score"])
+        del ov_pipe
         del model
         del pipe
         gc.collect()
@@ -1100,34 +1198,38 @@
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
+        inputs = "This is a test"
         model = OVModelForSeq2SeqLM.from_pretrained(model_id, export=True, compile=False)
         model.eval()
         model.half()
         model.to("cpu")
         model.compile()
 
-        # Text2Text generation
-        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
-        text = "This is a test"
-        outputs = pipe(text)
-        self.assertEqual(pipe.device, model.device)
-        self.assertIsInstance(outputs[0]["generated_text"], str)
-
         # Summarization
         pipe = pipeline("summarization", model=model, tokenizer=tokenizer)
-        text = "This is a test"
-        outputs = pipe(text)
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertIsInstance(outputs[0]["summary_text"], str)
 
         # Translation
         pipe = pipeline("translation_en_to_fr", model=model, tokenizer=tokenizer)
-        text = "This is a test"
-        outputs = pipe(text)
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertIsInstance(outputs[0]["translation_text"], str)
+
+        # Text2Text generation
+        pipe = pipeline("text2text-generation", model=model, tokenizer=tokenizer)
+        outputs = pipe(inputs)
+        self.assertEqual(pipe.device, model.device)
+        self.assertIsInstance(outputs[0]["generated_text"], str)
+
+        ov_pipe = optimum_pipeline("text2text-generation", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1]["generated_text"], ov_outputs[-1]["generated_text"])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -1237,14 +1339,21 @@
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForAudioClassification.from_pretrained(model_id, export=True)
         model.eval()
         preprocessor = AutoFeatureExtractor.from_pretrained(model_id)
         pipe = pipeline("audio-classification", model=model, feature_extractor=preprocessor)
-        outputs = pipe([np.random.random(16000)])
+        inputs = [np.random.random(16000)]
+        outputs = pipe(inputs)
         self.assertEqual(pipe.device, model.device)
         self.assertTrue(all(item["score"] > 0.0 for item in outputs[0]))
+
+        ov_pipe = optimum_pipeline("audio-classification", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs[-1][-1]["score"], ov_outputs[-1][-1]["score"])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -1548,21 +1657,25 @@
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     def test_pipeline(self, model_arch):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         model = OVModelForSpeechSeq2Seq.from_pretrained(model_id, export=True)
         model.eval()
         processor = get_preprocessor(model_id)
-        GenerationConfig.from_pretrained(model_id)
         pipe = pipeline(
             "automatic-speech-recognition",
             model=model,
             tokenizer=processor.tokenizer,
             feature_extractor=processor.feature_extractor,
         )
 
-        data = self._generate_random_audio_data()
-        outputs = pipe(data)
+        inputs = self._generate_random_audio_data()
+        outputs = pipe(inputs)
         self.assertIsInstance(outputs["text"], str)
+        ov_pipe = optimum_pipeline("automatic-speech-recognition", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs)
+        self.assertEqual(outputs["text"], ov_outputs["text"])
+        del ov_pipe
         del pipe
         del model
         gc.collect()
@@ -1649,7 +1762,8 @@ def test_compare_to_transformers(self, model_arch: str):
         gc.collect()
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
-    def test_pipeline_image_to_text(self, model_arch: str):
+    def test_pipeline(self, model_arch: str):
+        set_seed(SEED)
         model_id = MODEL_NAMES[model_arch]
         ov_model = OVModelForVision2Seq.from_pretrained(model_id, export=True, compile=False)
         feature_extractor, tokenizer = self._get_preprocessors(model_id)
@@ -1663,10 +1777,13 @@
             tokenizer=tokenizer,
             feature_extractor=feature_extractor,
         )
 
-        data = self._get_sample_image()
-        outputs = pipe(data, max_new_tokens=3)
+        inputs = self._get_sample_image()
+        outputs = pipe(inputs, max_new_tokens=3)
         self.assertEqual(pipe.device, ov_model.device)
         self.assertIsInstance(outputs[0]["generated_text"], str)
+        ov_pipe = optimum_pipeline("image-to-text", model_id, accelerator="openvino")
+        ov_outputs = ov_pipe(inputs, max_new_tokens=3)
+        self.assertEqual(outputs[-1]["generated_text"], ov_outputs[-1]["generated_text"])
 
         gc.collect()
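
Usage note: a minimal sketch of the new `accelerator="openvino"` entry point added by this patch (the model ID is borrowed from the tests above; it assumes optimum-intel is installed with its OpenVINO extra):

    from optimum.intel.pipelines import pipeline

    # "openvino" dispatches to load_openvino_model through MAPPING_LOADING_FUNC.
    # When the repo contains no OpenVINO IR (openvino_model.xml), export=True is
    # inferred and the PyTorch weights are converted on the fly.
    ov_pipe = pipeline("text-generation", "echarlaix/tiny-random-PhiForCausalLM", accelerator="openvino")
    print(ov_pipe("My name is Arthur and I live in", max_new_tokens=5))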
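
And a sketch of how `_find_files_matching_pattern` drives the export decision inside `load_openvino_model` (the pattern string is the one used in the patch; listing files on the Hub requires network access):

    from optimum.intel.utils.modeling_utils import _find_files_matching_pattern

    pattern = r"(.*)?openvino(.*)?\_model.xml"
    # An empty result means no OpenVINO IR files were found, so the model still
    # has to be exported before it can be wrapped in an OVModel class.
    ov_files = _find_files_matching_pattern("echarlaix/tiny-random-PhiForCausalLM", pattern, revision="ov")
    export = len(ov_files) == 0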