Skip to content

Commit

Permalink
Load the generation config if it is part of the model
Browse files Browse the repository at this point in the history
  • Loading branch information
eaidova committed Jun 6, 2024
1 parent dad18d1 commit ce8d1bf
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
19 changes: 18 additions & 1 deletion optimum/intel/openvino/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ def __init__(

self.model = model
self.request = None
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if self.can_generate():
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
else:
self.generation_config = None

self._openvino_config = None
if quantization_config:
Expand Down Expand Up @@ -240,6 +243,20 @@ def _from_pretrained(
quantization_config = cls._prepare_weight_quantization_config(quantization_config, load_in_8bit)

model = cls.load_model(model_cache_path, quantization_config=quantization_config)

try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
subfolder=subfolder,
force_download=force_download,
cache_dir=cache_dir
)
kwargs["generation_config"] = generation_config
except Exception:
pass

return cls(
model,
config=config,
Expand Down
18 changes: 17 additions & 1 deletion optimum/intel/openvino/modeling_base_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,10 @@ def __init__(
self.encoder_model = encoder
self.decoder_model = decoder
self.decoder_with_past_model = decoder_with_past
self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
if self.can_generate():
self.generation_config = kwargs.get("generation_config", GenerationConfig.from_model_config(config))
else:
self.generation_config = None
self._openvino_config = None
if quantization_config:
self._openvino_config = OVConfig(quantization_config=quantization_config)
Expand Down Expand Up @@ -218,6 +221,19 @@ def _from_pretrained(
if use_cache:
decoder_with_past = cls.load_model(file_names["decoder_with_past"], quantization_config)

try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only
)
kwargs["generation_config"] = generation_config
except Exception:
pass

return cls(
encoder=encoder,
decoder=decoder,
Expand Down
12 changes: 12 additions & 0 deletions optimum/intel/openvino/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -763,6 +763,18 @@ def _from_pretrained(
init_cls = cls

enable_compilation = kwargs.pop("compile", True) and not load_in_4bit
try:
generation_config = GenerationConfig.from_pretrained(
model_id,
token=token,
revision=revision,
cache_dir=cache_dir,
force_download=force_download,
local_files_only=local_files_only
)
kwargs["generation_config"] = generation_config
except Exception:
pass
causal_model = init_cls(
model=model,
config=config,
Expand Down

0 comments on commit ce8d1bf

Please sign in to comment.