diff --git a/optimum/onnxruntime/__init__.py b/optimum/onnxruntime/__init__.py
index 09a48ec955..1cb5b7c47b 100644
--- a/optimum/onnxruntime/__init__.py
+++ b/optimum/onnxruntime/__init__.py
@@ -44,6 +44,7 @@
         "ORTModelForSemanticSegmentation",
         "ORTModelForSequenceClassification",
         "ORTModelForTokenClassification",
+        "ORTModelForImageToImage",
     ],
     "modeling_seq2seq": [
         "ORTModelForSeq2SeqLM",
@@ -112,6 +113,7 @@
         ORTModelForCustomTasks,
         ORTModelForFeatureExtraction,
         ORTModelForImageClassification,
+        ORTModelForImageToImage,
         ORTModelForMaskedLM,
         ORTModelForMultipleChoice,
         ORTModelForQuestionAnswering,
diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py
index 254b771e33..9166f7c2cb 100644
--- a/optimum/onnxruntime/modeling_ort.py
+++ b/optimum/onnxruntime/modeling_ort.py
@@ -34,6 +34,7 @@
     AutoModelForAudioXVector,
     AutoModelForCTC,
     AutoModelForImageClassification,
+    AutoModelForImageToImage,
     AutoModelForMaskedLM,
     AutoModelForMultipleChoice,
     AutoModelForQuestionAnswering,
@@ -47,6 +48,7 @@
     BaseModelOutput,
     CausalLMOutput,
     ImageClassifierOutput,
+    ImageSuperResolutionOutput,
     MaskedLMOutput,
     ModelOutput,
     MultipleChoiceModelOutput,
@@ -2183,6 +2185,77 @@ def forward(
         return TokenClassifierOutput(logits=logits)


+IMAGE_TO_IMAGE_EXAMPLE = r"""
+    Example of image-to-image (Super Resolution):
+
+    ```python
+    >>> from transformers import {processor_class}
+    >>> from optimum.onnxruntime import {model_class}
+    >>> from PIL import Image
+    >>> import torch
+
+    >>> image = Image.open("path/to/image.jpg")
+
+    >>> image_processor = {processor_class}.from_pretrained("{checkpoint}")
+    >>> model = {model_class}.from_pretrained("{checkpoint}")
+
+    >>> inputs = image_processor(images=image, return_tensors="pt")
+
+    >>> with torch.no_grad():
+    ...     reconstruction = model(**inputs).reconstruction
+    ```
+"""
+
+
+@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
+class ORTModelForImageToImage(ORTModel):
+    """
+    ONNX Model for image-to-image tasks. This class officially supports swin2sr.
+ """ + + auto_model_class = AutoModelForImageToImage + + @add_start_docstrings_to_model_forward( + ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width") + + IMAGE_TO_IMAGE_EXAMPLE.format( + processor_class=_PROCESSOR_FOR_DOC, + model_class="ORTModelForImgageToImage", + checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr", + ) + ) + def forward( + self, + pixel_values: Union[torch.Tensor, np.ndarray], + **kwargs, + ): + use_torch = isinstance(pixel_values, torch.Tensor) + self.raise_on_numpy_input_io_binding(use_torch) + if self.device.type == "cuda" and self.use_io_binding: + input_shapes = pixel_values.shape + io_binding, output_shapes, output_buffers = self.prepare_io_binding( + pixel_values, + ordered_input_names=self._ordered_input_names, + known_output_shapes={ + "reconstruction": [ + input_shapes[0], + input_shapes[1], + input_shapes[2] * self.config.upscale, + input_shapes[3] * self.config.upscale, + ] + }, + ) + io_binding.synchronize_inputs() + self.model.run_with_iobinding(io_binding) + io_binding.synchronize_outputs() + reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"]) + else: + model_inputs = {"pixel_values": pixel_values} + onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs) + onnx_outputs = self.model.run(None, onnx_inputs) + model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs) + reconstruction = model_outputs["reconstruction"] + return ImageSuperResolutionOutput(reconstruction=reconstruction) + + CUSTOM_TASKS_EXAMPLE = r""" Example of custom tasks(e.g. a sentence transformers taking `pooler_output` as output): diff --git a/optimum/pipelines/pipelines_base.py b/optimum/pipelines/pipelines_base.py index a08ab8782a..7690143f13 100644 --- a/optimum/pipelines/pipelines_base.py +++ b/optimum/pipelines/pipelines_base.py @@ -24,6 +24,7 @@ FillMaskPipeline, ImageClassificationPipeline, ImageSegmentationPipeline, + ImageToImagePipeline, ImageToTextPipeline, Pipeline, PreTrainedTokenizer, @@ -55,6 +56,7 @@ ORTModelForCausalLM, ORTModelForFeatureExtraction, ORTModelForImageClassification, + ORTModelForImageToImage, ORTModelForMaskedLM, ORTModelForQuestionAnswering, ORTModelForSemanticSegmentation, @@ -157,6 +159,12 @@ "default": "superb/hubert-base-superb-ks", "type": "audio", }, + "image-to-image": { + "impl": ImageToImagePipeline, + "class": (ORTModelForImageToImage,), + "default": "caidas/swin2SR-classical-sr-x2-64", + "type": "image", + }, } else: ORT_SUPPORTED_TASKS = {} diff --git a/tests/onnxruntime/test_modeling.py b/tests/onnxruntime/test_modeling.py index 199b96342e..f6771ce761 100644 --- a/tests/onnxruntime/test_modeling.py +++ b/tests/onnxruntime/test_modeling.py @@ -42,6 +42,7 @@ AutoModelForCausalLM, AutoModelForCTC, AutoModelForImageClassification, + AutoModelForImageToImage, AutoModelForMaskedLM, AutoModelForMultipleChoice, AutoModelForQuestionAnswering, @@ -57,7 +58,9 @@ PretrainedConfig, set_seed, ) +from transformers.modeling_outputs import ImageSuperResolutionOutput from transformers.modeling_utils import no_init_weights +from transformers.models.swin2sr.configuration_swin2sr import Swin2SRConfig from transformers.onnx.utils import get_preprocessor from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin @@ -79,6 +82,7 @@ ORTModelForCustomTasks, ORTModelForFeatureExtraction, ORTModelForImageClassification, + ORTModelForImageToImage, ORTModelForMaskedLM, 
     ORTModelForMultipleChoice,
     ORTModelForPix2Struct,
@@ -4704,6 +4708,136 @@ def test_compare_generation_to_io_binding(
         gc.collect()


+class ORTModelForImageToImageIntegrationTest(ORTModelTestMixin):
+    SUPPORTED_ARCHITECTURES = ["swin2sr"]
+
+    ORTMODEL_CLASS = ORTModelForImageToImage
+
+    TASK = "image-to-image"
+
+    def _get_sample_image(self):
+        url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        image = Image.open(requests.get(url, stream=True).raw)
+        return image
+
+    def _get_preprocessors(self, model_id):
+        image_processor = AutoImageProcessor.from_pretrained(model_id)
+
+        return image_processor
+
+    def test_load_vanilla_transformers_which_is_not_supported(self):
+        with self.assertRaises(Exception) as context:
+            _ = ORTModelForImageToImage.from_pretrained(MODEL_NAMES["bert"], export=True)
+
+        self.assertIn("only supports the tasks", str(context.exception))
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_compare_to_transformers(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
+        self.assertIsInstance(onnx_model.config, Swin2SRConfig)
+        set_seed(SEED)
+
+        transformers_model = AutoModelForImageToImage.from_pretrained(model_id)
+        image_processor = self._get_preprocessors(model_id)
+
+        data = self._get_sample_image()
+        features = image_processor(data, return_tensors="pt")
+
+        with torch.no_grad():
+            transformers_outputs = transformers_model(**features)
+
+        onnx_outputs = onnx_model(**features)
+        self.assertIsInstance(onnx_outputs, ImageSuperResolutionOutput)
+        self.assertTrue("reconstruction" in onnx_outputs)
+        self.assertIsInstance(onnx_outputs.reconstruction, torch.Tensor)
+        self.assertTrue(torch.allclose(onnx_outputs.reconstruction, transformers_outputs.reconstruction, atol=1e-4))
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_generate_utils(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
+        image_processor = self._get_preprocessors(model_id)
+
+        data = self._get_sample_image()
+        features = image_processor(data, return_tensors="pt")
+
+        outputs = onnx_model(**features)
+        self.assertIsInstance(outputs, ImageSuperResolutionOutput)
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    def test_pipeline_image_to_image(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
+        image_processor = self._get_preprocessors(model_id)
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+        )
+        data = self._get_sample_image()
+        outputs = pipe(data)
+        self.assertEqual(pipe.device, onnx_model.device)
+        self.assertIsInstance(outputs, Image.Image)
+
+        gc.collect()
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @require_torch_gpu
+    @pytest.mark.cuda_ep_test
+    def test_pipeline_on_gpu(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
+        image_processor = self._get_preprocessors(model_id)
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+            device=0,
+        )
+
+        data = self._get_sample_image()
+        outputs = pipe(data)
+
+        self.assertEqual(pipe.model.device.type.lower(), "cuda")
+        self.assertIsInstance(outputs, Image.Image)
+
+    @parameterized.expand(SUPPORTED_ARCHITECTURES)
+    @require_torch_gpu
+    @require_ort_rocm
+    @pytest.mark.rocm_ep_test
+    def test_pipeline_on_rocm(self, model_arch: str):
+        model_args = {"test_name": model_arch, "model_arch": model_arch}
+        self._setup(model_args)
+        model_id = MODEL_NAMES[model_arch]
+        onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
+        image_processor = self._get_preprocessors(model_id)
+        pipe = pipeline(
+            "image-to-image",
+            model=onnx_model,
+            feature_extractor=image_processor,
+            device=0,
+        )
+
+        data = self._get_sample_image()
+        outputs = pipe(data)
+
+        self.assertEqual(pipe.model.device.type.lower(), "cuda")
+        self.assertIsInstance(outputs, Image.Image)
+
+
 class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):
     SUPPORTED_ARCHITECTURES = ["vision-encoder-decoder", "trocr", "donut"]

@@ -4831,7 +4965,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
                 len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
             )
             for i in range(len(onnx_outputs["past_key_values"])):
-                print(onnx_outputs["past_key_values"][i])
                 for ort_pkv, trfs_pkv in zip(
                     onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
                 ):
@@ -5517,6 +5650,7 @@ class TestBothExportersORTModel(unittest.TestCase):
             ["automatic-speech-recognition", ORTModelForCTCIntegrationTest],
             ["audio-xvector", ORTModelForAudioXVectorIntegrationTest],
             ["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest],
+            ["image-to-image", ORTModelForImageToImageIntegrationTest],
         ]
     )
     def test_find_untested_architectures(self, task: str, test_class):
diff --git a/tests/onnxruntime/utils_onnxruntime_tests.py b/tests/onnxruntime/utils_onnxruntime_tests.py
index bb6935461d..0790f6329d 100644
--- a/tests/onnxruntime/utils_onnxruntime_tests.py
+++ b/tests/onnxruntime/utils_onnxruntime_tests.py
@@ -144,6 +144,7 @@
     "stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
     "swin": "hf-internal-testing/tiny-random-SwinModel",
     "swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
+    "swin2sr": "hf-internal-testing/tiny-random-Swin2SRForImageSuperResolution",
     "t5": "hf-internal-testing/tiny-random-t5",
     "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
     "trocr": "microsoft/trocr-small-handwritten",
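For reviewers who want to exercise the new class end to end, here is a minimal usage sketch (not part of the diff). It assumes `optimum` with the `onnxruntime` extra and a recent `transformers` are installed; the `caidas/swin2SR-classical-sr-x2-64` checkpoint is the pipeline default registered above, and the sample image URL is the same one used in the tests.

```python
# Usage sketch only -- not part of the changes above.
# Runs ORTModelForImageToImage directly and through the optimum pipeline factory.
import requests
from PIL import Image
from transformers import AutoImageProcessor

from optimum.onnxruntime import ORTModelForImageToImage
from optimum.pipelines import pipeline

checkpoint = "caidas/swin2SR-classical-sr-x2-64"  # pipeline default registered in this diff
url = "http://images.cocodataset.org/val2017/000000039769.jpg"  # sample image from the tests
image = Image.open(requests.get(url, stream=True).raw)

# Direct use: export the PyTorch checkpoint to ONNX on the fly and run a forward pass.
image_processor = AutoImageProcessor.from_pretrained(checkpoint)
model = ORTModelForImageToImage.from_pretrained(checkpoint, export=True)
inputs = image_processor(images=image, return_tensors="pt")
reconstruction = model(**inputs).reconstruction  # ImageSuperResolutionOutput.reconstruction

# Pipeline use: "image-to-image" now resolves to ImageToImagePipeline backed by the ORT model.
upscaler = pipeline("image-to-image", model=model, feature_extractor=image_processor)
upscaled_image = upscaler(image)  # returns a PIL.Image.Image, as asserted in the tests
print(upscaled_image.size)
```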