Commit
add torch_dtype option for loading model
eaidova committed Jun 25, 2024
1 parent 1dbd9d6 commit 6a2e507
Showing 4 changed files with 81 additions and 9 deletions.
31 changes: 22 additions & 9 deletions optimum/exporters/openvino/__main__.py
@@ -27,8 +27,8 @@
from optimum.exporters.openvino.convert import export_from_model
from optimum.intel.utils.import_utils import (
is_openvino_tokenizers_available,
-    is_transformers_version,
    is_openvino_version,
+    is_transformers_version,
)
from optimum.utils.save_utils import maybe_load_preprocessors

@@ -88,6 +88,7 @@ def main_export(
stateful: bool = True,
convert_tokenizer: bool = False,
library_name: Optional[str] = None,
model_loading_kwargs: Optional[Dict[str, Any]] = None,
**kwargs_shapes,
):
"""
@@ -196,6 +197,9 @@
original_task = task
task = infer_task(task, model_name_or_path)
framework = TasksManager.determine_framework(model_name_or_path, subfolder=subfolder, framework=framework)
if framework == "pt":
import torch

library_name_is_not_provided = library_name is None
library_name = TasksManager.infer_library_from_model(
model_name_or_path, subfolder=subfolder, library_name=library_name
@@ -211,7 +215,7 @@
do_gptq_patching = False
custom_architecture = False
patch_16bit = False
-    loading_kwargs = {}
+    loading_kwargs = model_loading_kwargs or {}
if library_name == "transformers":
config = AutoConfig.from_pretrained(
model_name_or_path,
@@ -262,19 +266,28 @@
"Please provide custom export config if you want load model with remote code."
)
trust_remote_code = False
dtype = loading_kwargs.get("torch_dtype")
if isinstance(dtype, str):
dtype = config.torch_dtype if dtype == "auto" else getattr(torch, dtype)

if (
-            not do_gptq_patching
+            dtype is None
+            and framework == "pt"
+            and not do_gptq_patching
and task.startswith("text-generation")
-            and getattr(config, "torch_dtype", "float32") in ["float16", "bfloat16"]
+            and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
):
-            if is_openvino_version(">=", "2024.2") and config.torch_dtype == "float16":
-                loading_kwargs["torch_dtype"] = torch.float16
-                patch_16bit = True
-            if is_openvino_version(">=", "2024.3") and config.torch_dtype == "bfloat16":
-                loading_kwargs["torch_dtype"] = torch.bfloat16
+            if is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
+                dtype = torch.float16
+            if is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
+                dtype = torch.bfloat16

if dtype is not None:
if dtype in [torch.float16, torch.bfloat16]:
patch_16bit = True
loading_kwargs["torch_dtype"] = dtype

logger.warning(loading_kwargs)
# Patch the modules to export of GPTQ models w/o GPU
if do_gptq_patching:
import torch
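
As a side note (not part of the committed diff): a minimal standalone sketch that consolidates the dtype-resolution logic this hunk introduces, so the final behaviour can be read in one place. The helper name resolve_torch_dtype and its signature are hypothetical; the body mirrors the committed code, and is_openvino_version is the same helper imported at the top of the file.

import torch

from optimum.intel.utils.import_utils import is_openvino_version


def resolve_torch_dtype(loading_kwargs, config, framework="pt", task="text-generation", do_gptq_patching=False):
    # Hypothetical helper mirroring the logic added in this commit.
    dtype = loading_kwargs.get("torch_dtype")

    # A string is either "auto" (take the dtype stored in the model config)
    # or a torch dtype name such as "float16".
    if isinstance(dtype, str):
        dtype = config.torch_dtype if dtype == "auto" else getattr(torch, dtype)

    # No explicit request: keep the previous default of loading fp16/bf16
    # text-generation checkpoints in their native dtype when the installed
    # OpenVINO version supports it.
    if (
        dtype is None
        and framework == "pt"
        and not do_gptq_patching
        and task.startswith("text-generation")
        and getattr(config, "torch_dtype", torch.float32) in [torch.float16, torch.bfloat16]
    ):
        if is_openvino_version(">=", "2024.2") and config.torch_dtype == torch.float16:
            dtype = torch.float16
        if is_openvino_version(">=", "2024.3") and config.torch_dtype == torch.bfloat16:
            dtype = torch.bfloat16

    patch_16bit = False
    if dtype is not None:
        if dtype in [torch.float16, torch.bfloat16]:
            patch_16bit = True
        loading_kwargs["torch_dtype"] = dtype
    return dtype, patch_16bit
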
8 changes: 8 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
@@ -280,6 +280,13 @@ def _from_transformers(

stateful = kwargs.pop("stateful", ensure_stateful_is_available(warn=False) and use_cache)

torch_dtype = kwargs.pop("torch_dtype", None)

model_loading_kwargs = {}

if torch_dtype is not None:
model_loading_kwargs["torch_dtype"] = torch_dtype

main_export(
model_name_or_path=model_id,
output=save_dir_path,
@@ -293,6 +300,7 @@
trust_remote_code=trust_remote_code,
ov_config=ov_export_config,
stateful=stateful,
model_loading_kwargs=model_loading_kwargs,
)

config.is_decoder = True
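
As a side note (not part of the committed diff): a hedged usage sketch of what this wiring enables. A torch_dtype passed to from_pretrained(..., export=True) is now forwarded through model_loading_kwargs to main_export. The checkpoint name below is a placeholder, and the accepted values follow the new test in tests/openvino/test_modeling.py further down: None, "auto", a dtype name string such as "float32", or a torch.dtype such as torch.float16.

import torch

from optimum.intel import OVModelForCausalLM

# Placeholder checkpoint; any causal-LM model id works the same way.
ov_model = OVModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM",
    export=True,
    torch_dtype=torch.float16,  # or "auto", "float32"; "bfloat16" additionally needs OpenVINO > 2024.2 per the test below
)
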
27 changes: 27 additions & 0 deletions tests/openvino/test_export.py
@@ -78,6 +78,7 @@ def _openvino_export(
model_type: str,
compression_option: Optional[str] = None,
stateful: bool = True,
patch_16bit_model: bool = False,
):
auto_model = self.SUPPORTED_ARCHITECTURES[model_type]
task = auto_model.export_feature
@@ -171,6 +172,32 @@ def test_export_with_custom_gen_config(self, model_type):
self.assertIsInstance(ov_model.generation_config, GenerationConfig)
self.assertTrue(ov_model.generation_config.top_k == 42)

def test_export_fp16_model(self):
auto_model = self.SUPPORTED_ARCHITECTURES["gpt2"]
task = auto_model.export_feature
model_name = MODEL_NAMES["gpt2"]
model = auto_model.auto_model_class.from_pretrained(model_name, torch_dtype=torch.float16)
stateful = True

for supported_task in [task, task + "with-past"]:
with TemporaryDirectory() as tmpdirname:
export_from_model(
model=model,
output=Path(tmpdirname),
task=task,
preprocessors=None,
patch_16bit_model=True,
stateful=stateful,
)
use_cache = supported_task.endswith("-with-past")
ov_model = auto_model.from_pretrained(tmpdirname, use_cache=use_cache)
self.assertIsInstance(ov_model, OVBaseModel)
self.assertEqual(ov_model.use_cache, use_cache)
self.assertEqual(ov_model.stateful, stateful and use_cache)
self.assertEqual(
ov_model.model.get_rt_info()["optimum"]["transformers_version"], _transformers_version
)


class CustomExportModelTest(unittest.TestCase):
def test_custom_export_config_model(self):
24 changes: 24 additions & 0 deletions tests/openvino/test_modeling.py
@@ -1008,6 +1008,30 @@ def test_beam_search(self, model_arch):
f"generation config : {gen_config}, transformers output {transformers_outputs}, ov_model_stateless output {ov_stateless_outputs}",
)

def test_load_with_different_dtype(self):
set_seed(SEED)
model_id = MODEL_NAMES["llama"]
pt_model = AutoModelForCausalLM.from_pretrained(
model_id,
)
tokenizer = AutoTokenizer.from_pretrained(model_id)

texts = ["this is a simple input"]
test_input = tokenizer(texts, return_tensors="pt")

ref_logits = pt_model(**test_input).logits
torch_dtypes = [None, "auto", "float32", torch.float16]
if is_openvino_version(">", "2024.2.0"):
torch_dtypes.append("bfloat16")

for dtype in torch_dtypes:
ov_model = OVModelForCausalLM.from_pretrained(model_id=model_id, export=True, torch_dtype=dtype)
ov_logits = ov_model(**test_input).logits
self.assertTrue(
torch.allclose(torch.Tensor(ov_logits), ref_logits, atol=5e-3),
f"values are not close for {dtype if dtype is not None else 'None'}, max diff = {torch.abs(ov_logits - ref_logits).max()}",
)


class OVModelForMaskedLMIntegrationTest(unittest.TestCase):
SUPPORTED_ARCHITECTURES = (
