diff --git a/bentoml/_internal/frameworks/keras.py b/bentoml/_internal/frameworks/keras.py index 33a56688bbe..9f6e8b4873c 100644 --- a/bentoml/_internal/frameworks/keras.py +++ b/bentoml/_internal/frameworks/keras.py @@ -54,14 +54,6 @@ class KerasOptions(ModelOptions): include_optimizer: bool partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict) - @classmethod - def with_options(cls, **kwargs: t.Any) -> ModelOptions: - return cls(**kwargs) - - @staticmethod - def to_dict(options: ModelOptions) -> dict[str, t.Any]: - return attr.asdict(options) - def get(tag_like: str | Tag) -> bentoml.Model: model = bentoml.models.get(tag_like) @@ -256,9 +248,7 @@ def get_runnable( Private API: use :obj:`~bentoml.Model.to_runnable` instead. """ - partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.get( - "partial_kwargs", dict() - ) + partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs # type: ignore class KerasRunnable(Runnable): SUPPORT_NVIDIA_GPU = True diff --git a/bentoml/_internal/frameworks/tensorflow_v2.py b/bentoml/_internal/frameworks/tensorflow_v2.py index 938507dab94..3127c96e045 100644 --- a/bentoml/_internal/frameworks/tensorflow_v2.py +++ b/bentoml/_internal/frameworks/tensorflow_v2.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import pickle import typing as t import logging @@ -9,10 +8,13 @@ import contextlib from typing import TYPE_CHECKING +import attr + import bentoml from bentoml import Tag from bentoml import Runnable from bentoml.models import ModelContext +from bentoml.models import ModelOptions from bentoml.exceptions import NotFound from bentoml.exceptions import MissingDependencyException @@ -49,6 +51,13 @@ API_VERSION = "v1" +@attr.define +class TensorflowOptions(ModelOptions): + """Options for the TensorFlow model.""" + + partial_kwargs: t.Dict[str, t.Any] = attr.field(factory=dict) + + def get(tag_like: str | Tag) -> bentoml.Model: model = bentoml.models.get(tag_like) if model.info.module not in 
(MODULE_NAME, __name__): @@ -219,9 +228,7 @@ def get_runnable( Private API: use :obj:`~bentoml.Model.to_runnable` instead. """ - partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.get( - "partial_kwargs", dict() - ) + partial_kwargs: t.Dict[str, t.Any] = bento_model.info.options.partial_kwargs class TensorflowRunnable(Runnable): SUPPORT_NVIDIA_GPU = True diff --git a/bentoml/_internal/frameworks/transformers.py b/bentoml/_internal/frameworks/transformers.py index cf4edbea3cb..24eb9232b48 100644 --- a/bentoml/_internal/frameworks/transformers.py +++ b/bentoml/_internal/frameworks/transformers.py @@ -91,20 +91,8 @@ class TransformersOptions(ModelOptions): ] ) - pipeline: bool = attr.field( - default=True, validator=attr.validators.instance_of(bool) - ) - kwargs: t.Dict[str, t.Any] = attr.field(factory=dict) - @classmethod - def with_options(cls, **kwargs: t.Any) -> ModelOptions: - return cls(**kwargs) - - @staticmethod - def to_dict(options: ModelOptions) -> dict[str, t.Any]: - return attr.asdict(options) - def get(tag_like: str | Tag) -> Model: model = bentoml.models.get(tag_like) @@ -146,8 +134,6 @@ def load_model( f"Model {bento_model.tag} was saved with module {bento_model.info.module}, failed loading with {MODULE_NAME}." 
) - bento_model.info.parse_options(TransformersOptions) - pipeline_task: str = bento_model.info.options.task # type: ignore pipeline_kwargs: t.Dict[str, t.Any] = bento_model.info.options.kwargs # type: ignore pipeline_kwargs.update(kwargs) diff --git a/bentoml/_internal/models/model.py b/bentoml/_internal/models/model.py index 3a7908ca3da..ecd70c646fe 100644 --- a/bentoml/_internal/models/model.py +++ b/bentoml/_internal/models/model.py @@ -9,7 +9,6 @@ from typing import TYPE_CHECKING from datetime import datetime from datetime import timezone -from collections import UserDict import fs import attr @@ -57,25 +56,17 @@ class ModelSignatureDict(t.TypedDict, total=False): CUSTOM_OBJECTS_FILENAME = "custom_objects.pkl" -if TYPE_CHECKING: - ModelOptionsSuper = UserDict[str, t.Any] -else: - ModelOptionsSuper = UserDict - - -class ModelOptions(ModelOptionsSuper): - @classmethod - def with_options(cls, **kwargs: t.Any) -> ModelOptions: - return cls(**kwargs) +@attr.define +class ModelOptions: + def with_options(self, **kwargs: t.Any) -> "ModelOptions": + return attr.evolve(self, **kwargs) @staticmethod def to_dict(options: ModelOptions) -> dict[str, t.Any]: - return dict(options) + return attr.asdict(options) -bentoml_cattr.register_structure_hook_func( - lambda cls: issubclass(cls, ModelOptions), lambda d, cls: cls.with_options(**d) # type: ignore -) +bentoml_cattr.register_structure_hook(ModelOptions, lambda d, cls: cls(**d)) bentoml_cattr.register_unstructure_hook(ModelOptions, lambda v: v.to_dict(v)) # type: ignore # pylint: disable=unnecessary-lambda # lambda required @@ -545,9 +536,6 @@ def with_options(self, **kwargs: t.Any) -> ModelInfo: def to_dict(self) -> t.Dict[str, t.Any]: return bentoml_cattr.unstructure(self) # type: ignore (incomplete cattr types) - def parse_options(self, options_class: type[ModelOptions]) -> None: - object.__setattr__(self, "options", options_class.with_options(**self.options)) - @overload def dump(self, stream: io.StringIO) -> 
io.BytesIO: ... @@ -588,6 +576,20 @@ def from_yaml_file(stream: t.IO[t.Any]): del yaml_content["context"]["pip_dependencies"] yaml_content["context"]["framework_versions"] = {} + # register hook for model options + module_name: str = yaml_content["module"] + try: + module = importlib.import_module(module_name) + except (ValueError, ModuleNotFoundError) as e: + raise BentoMLException( + f"Module '{module_name}' defined in {MODEL_YAML_FILENAME} is not found." + ) from e + if hasattr(module, "ModelOptions"): + bentoml_cattr.register_structure_hook( + ModelOptions, + lambda d, _: module.ModelOptions(**d), + ) + try: model_info = bentoml_cattr.structure(yaml_content, ModelInfo) except TypeError as e: # pragma: no cover - simple error handling diff --git a/bentoml/keras.py b/bentoml/keras.py index 11c90e0dd7c..5ae10f78137 100644 --- a/bentoml/keras.py +++ b/bentoml/keras.py @@ -4,6 +4,7 @@ from ._internal.frameworks.keras import load_model from ._internal.frameworks.keras import save_model from ._internal.frameworks.keras import get_runnable +from ._internal.frameworks.keras import KerasOptions as ModelOptions # type: ignore # noqa logger = logging.getLogger(__name__) diff --git a/bentoml/tensorflow.py b/bentoml/tensorflow.py index 8a2c6b24875..77906949744 100644 --- a/bentoml/tensorflow.py +++ b/bentoml/tensorflow.py @@ -4,6 +4,7 @@ from ._internal.frameworks.tensorflow_v2 import load_model from ._internal.frameworks.tensorflow_v2 import save_model from ._internal.frameworks.tensorflow_v2 import get_runnable +from ._internal.frameworks.tensorflow_v2 import TensorflowOptions as ModelOptions # type: ignore # noqa logger = logging.getLogger(__name__) diff --git a/bentoml/transformers.py b/bentoml/transformers.py index abe0e9f9275..c2971a56bbb 100644 --- a/bentoml/transformers.py +++ b/bentoml/transformers.py @@ -4,6 +4,7 @@ from ._internal.frameworks.transformers import load_model from ._internal.frameworks.transformers import save_model from 
._internal.frameworks.transformers import get_runnable +from ._internal.frameworks.transformers import TransformersOptions as ModelOptions # type: ignore # noqa logger = logging.getLogger(__name__) diff --git a/tests/conftest.py b/tests/conftest.py index 43e0242a773..be94e4052ec 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -22,15 +22,27 @@ def fixture_change_test_dir(request: pytest.FixtureRequest): def fixture_dummy_model_store(tmpdir_factory: "pytest.TempPathFactory") -> ModelStore: store = ModelStore(tmpdir_factory.mktemp("models")) with bentoml.models.create( - "testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ): pass with bentoml.models.create( - "testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ): pass with bentoml.models.create( - "anothermodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "anothermodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ): pass diff --git a/tests/integration/frameworks/test_transformers_impl.py b/tests/integration/frameworks/test_transformers_impl.py index 53c9b923811..1fa6912775c 100644 --- a/tests/integration/frameworks/test_transformers_impl.py +++ b/tests/integration/frameworks/test_transformers_impl.py @@ -39,36 +39,36 @@ def pt_gpt2_pipeline(): ( "text-generation", transformers.pipeline(task="text-generation"), # type: ignore - {"pipeline": True, "task": "text-generation"}, - {"pipeline": True, "task": "text-generation"}, + {}, + {"task": "text-generation", "kwargs": {}}, "A Bento box is a ", ), ( "text-generation", transformers.pipeline(task="text-generation"), # type: ignore - {"pipeline": True, "task": "text-generation", "kwargs": {"a": 1}}, - {"pipeline": True, "task": "text-generation", 
"kwargs": {"a": 1}}, + {"kwargs": {"a": 1}}, + {"task": "text-generation", "kwargs": {"a": 1}}, "A Bento box is a ", ), ( "text-generation", tf_gpt2_pipeline(), - {"pipeline": True, "task": "text-generation"}, - {"pipeline": True, "task": "text-generation"}, + {}, + {"task": "text-generation", "kwargs": {}}, "A Bento box is a ", ), ( "text-generation", pt_gpt2_pipeline(), - {"pipeline": True, "task": "text-generation"}, - {"pipeline": True, "task": "text-generation"}, + {}, + {"task": "text-generation", "kwargs": {}}, "A Bento box is a ", ), ( "image-classification", transformers.pipeline("image-classification"), # type: ignore - {"pipeline": True, "task": "image-classification"}, - {"pipeline": True, "task": "image-classification"}, + {}, + {"task": "image-classification", "kwargs": {}}, Image.open( requests.get( "http://images.cocodataset.org/val2017/000000039769.jpg", @@ -79,8 +79,8 @@ def pt_gpt2_pipeline(): ( "text-classification", transformers.pipeline("text-classification"), # type: ignore - {"pipeline": True, "task": "text-classification"}, - {"pipeline": True, "task": "text-classification"}, + {}, + {"task": "text-classification", "kwargs": {}}, "BentoML is an awesome library for machine learning.", ), ], @@ -101,7 +101,8 @@ def test_transformers( ) assert bento_model.tag == tag assert bento_model.info.context.framework_name == "transformers" - assert dict(bento_model.info.options) == expected_options + assert bento_model.info.options.task == expected_options["task"] # type: ignore + assert bento_model.info.options.kwargs == expected_options["kwargs"] # type: ignore runnable: bentoml.Runnable = bentoml.transformers.get_runnable(bento_model)() output_data = runnable(input_data) # type: ignore diff --git a/tests/unit/_internal/models/test_model.py b/tests/unit/_internal/models/test_model.py index f61b230b951..2b7672d9b99 100644 --- a/tests/unit/_internal/models/test_model.py +++ b/tests/unit/_internal/models/test_model.py @@ -7,6 +7,7 @@ from datetime import 
timezone import fs +import attr import numpy as np import pytest import fs.errors @@ -14,7 +15,7 @@ from bentoml import Tag from bentoml.exceptions import BentoMLException from bentoml._internal.models import ModelContext -from bentoml._internal.models import ModelOptions +from bentoml._internal.models import ModelOptions as InternalModelOptions from bentoml._internal.models.model import Model from bentoml._internal.models.model import ModelInfo from bentoml._internal.models.model import ModelStore @@ -32,7 +33,7 @@ expected_yaml = """\ name: test version: v1 -module: testmodule +module: test_model labels: label: stringvalue options: @@ -77,12 +78,16 @@ """ -class TestModelOption(ModelOptions): +@attr.define +class TestModelOptions(InternalModelOptions): option_a: int option_b: str option_c: list[float] +ModelOptions = TestModelOptions + + def test_model_info(tmpdir: "Path"): start = datetime.now(timezone.utc) modelinfo_a = ModelInfo( @@ -90,7 +95,7 @@ def test_model_info(tmpdir: "Path"): module="module", api_version="v1", labels={}, - options=ModelOptions(), + options=TestModelOptions(option_a=42, option_b="foo", option_c=[0.1, 0.2]), metadata={}, context=TEST_MODEL_CONTEXT, signatures={"predict": {"batchable": True}}, @@ -102,9 +107,9 @@ def test_model_info(tmpdir: "Path"): assert start <= modelinfo_a.creation_time <= end tag = Tag("test", "v1") - module = "testmodule" + module = __name__ labels = {"label": "stringvalue"} - options = TestModelOption(option_a=1, option_b="foo", option_c=[0.1, 0.2]) + options = TestModelOptions(option_a=1, option_b="foo", option_c=[0.1, 0.2]) metadata = {"a": 0.1, "b": 1, "c": np.array([2, 3, 4], dtype=np.uint32)} # TODO: add test cases for input_spec and output_spec signatures = { @@ -183,10 +188,11 @@ def __call__(self, y: int) -> int: def fixture_bento_model(): model = Model.create( "testmodel", - module="foo", + module=__name__, api_version="v1", signatures={}, context=TEST_MODEL_CONTEXT, + options=TestModelOptions(option_a=1, 
option_b="foo", option_c=[0.1, 0.2]), custom_objects={ "add": AdditionClass(add_num_1), }, diff --git a/tests/unit/test_models.py b/tests/unit/test_models.py index ead110cb219..91727647c40 100644 --- a/tests/unit/test_models.py +++ b/tests/unit/test_models.py @@ -41,19 +41,31 @@ def test_models(tmpdir: "Path"): store = ModelStore(os.path.join(tmpdir, "models")) with bentoml.models.create( - "testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ) as testmodel: testmodel1tag = testmodel.tag with bentoml.models.create( - "testmodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "testmodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ) as testmodel: testmodel2tag = testmodel.tag testmodel_file_content = createfile(testmodel.path_of("file")) testmodel_infolder_content = createfile(testmodel.path_of("folder/file")) with bentoml.models.create( - "anothermodel", signatures={}, context=TEST_MODEL_CONTEXT, _model_store=store + "anothermodel", + module=__name__, + signatures={}, + context=TEST_MODEL_CONTEXT, + _model_store=store, ) as anothermodel: anothermodeltag = anothermodel.tag anothermodel_file_content = createfile(anothermodel.path_of("file"))