diff --git a/.github/workflows/onpush.yml b/.github/workflows/onpush.yml new file mode 100644 index 0000000..c577489 --- /dev/null +++ b/.github/workflows/onpush.yml @@ -0,0 +1,36 @@ +name: Giza CI + +on: + pull_request: + types: [ opened, synchronize ] + push: + branches: [main] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install poetry + poetry config virtualenvs.create false + poetry install --all-extras + - name: Lint with ruff + run: | + poetry run ruff auto_zkml + - name: Pre-commit check + run: | + poetry run pre-commit run --all-files + - name: Testing + run: | + poetry run pytest --cov=auto_zkml --cov-report term-missing diff --git a/.github/workflows/onrelease.yml b/.github/workflows/onrelease.yml new file mode 100644 index 0000000..82ade6c --- /dev/null +++ b/.github/workflows/onrelease.yml @@ -0,0 +1,40 @@ +name: release + +on: + push: + tags: + - 'v*' + +jobs: + release: + runs-on: ubuntu-latest + strategy: + max-parallel: 1 + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 0 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install poetry + poetry config virtualenvs.create false + poetry install + - name: Lint with ruff + run: | + poetry run ruff auto_zkml + - name: Build dist + run: poetry build + - name: Publish a Python distribution to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + with: + user: __token__ + password: ${{ secrets.GIZA_AUTOZKML_PYPI_TOKEN }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 065b910..fcf6496 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -35,8 +35,3 @@ repos: entry: ruff language: system files: "py$" - - id: mypy - name: mypy - entry: mypy - language: system - files: "py$" diff --git a/auto_zkml/__init__.py b/auto_zkml/__init__.py index c679b09..4f5a4a9 100644 --- a/auto_zkml/__init__.py +++ b/auto_zkml/__init__.py @@ -1,5 +1,5 @@ -from giza_mlutils.model_reducer import mcr -from giza_mlutils.serializer.serialize import serialize_model +from auto_zkml.model_reducer import mcr +from auto_zkml.serializer.serialize import serialize_model __all__ = ["mcr", "serialize_model"] diff --git a/auto_zkml/model_reducer.py b/auto_zkml/model_reducer.py index 3182412..7d6d898 100644 --- a/auto_zkml/model_reducer.py +++ b/auto_zkml/model_reducer.py @@ -1,12 +1,12 @@ from skopt import gp_minimize from skopt.utils import use_named_args -from giza_mlutils.model_toolkit.data_transformer import DataTransformer -from giza_mlutils.model_toolkit.feature_models_space import FeatureSpaceConstants -from giza_mlutils.model_toolkit.metrics import check_metric_optimization -from giza_mlutils.model_toolkit.model_evaluator import ModelEvaluator -from giza_mlutils.model_toolkit.model_info import ModelParameterExtractor -from giza_mlutils.model_toolkit.model_trainer import ModelTrainer +from auto_zkml.model_toolkit.data_transformer import DataTransformer +from auto_zkml.model_toolkit.feature_models_space import FeatureSpaceConstants +from auto_zkml.model_toolkit.metrics import check_metric_optimization +from auto_zkml.model_toolkit.model_evaluator import ModelEvaluator +from auto_zkml.model_toolkit.model_info import ModelParameterExtractor +from auto_zkml.model_toolkit.model_trainer import ModelTrainer def mcr(model, X_train, y_train, X_eval, y_eval, eval_metric, transform_features=False): diff --git a/auto_zkml/model_toolkit/data_transformer.py b/auto_zkml/model_toolkit/data_transformer.py index 3999d26..60971ba 100644 --- a/auto_zkml/model_toolkit/data_transformer.py +++ b/auto_zkml/model_toolkit/data_transformer.py @@ -2,8 +2,8 @@ from sklearn.pipeline import Pipeline from sklearn.preprocessing import StandardScaler -from giza_mlutils.model_toolkit.custom_transformers.customPCA import CustomPCA -from giza_mlutils.model_toolkit.custom_transformers.customRFE import CustomRFE +from auto_zkml.model_toolkit.custom_transformers.customPCA import CustomPCA +from auto_zkml.model_toolkit.custom_transformers.customRFE import CustomRFE class DataTransformer(BaseEstimator, TransformerMixin): diff --git a/auto_zkml/model_toolkit/model_info.py b/auto_zkml/model_toolkit/model_info.py index 5759f3f..fe1f7db 100644 --- a/auto_zkml/model_toolkit/model_info.py +++ b/auto_zkml/model_toolkit/model_info.py @@ -2,7 +2,6 @@ class ModelParameterExtractor: - # TODO: creo que se puede cambiar por un método solo por tipo de algoritmo. def __init__(self): self.model_extractors_by_name = { "XGBRegressor": self.extract_params_from_xgb, diff --git a/auto_zkml/model_toolkit/model_trainer.py b/auto_zkml/model_toolkit/model_trainer.py index e63214b..19372cb 100644 --- a/auto_zkml/model_toolkit/model_trainer.py +++ b/auto_zkml/model_toolkit/model_trainer.py @@ -58,7 +58,6 @@ def train_catboost( params["eval_metric"] = eval_metric model = self.model_class(**params) - # Entrenamiento con parámetros específicos de fit model.fit(X_train, y_train, eval_set=[(X_eval, y_eval)]) return model diff --git a/auto_zkml/serializer/serialize.py b/auto_zkml/serializer/serialize.py index 64dc5f1..4938cf1 100644 --- a/auto_zkml/serializer/serialize.py +++ b/auto_zkml/serializer/serialize.py @@ -1,5 +1,5 @@ -from giza_mlutils.model_toolkit.model_info import ModelParameterExtractor -from giza_mlutils.serializer import lgbm, xg +from auto_zkml.model_toolkit.model_info import ModelParameterExtractor +from auto_zkml.serializer import lgbm, xg def serialize_model(model, output_path): diff --git a/poetry.lock b/poetry.lock index 17bf4e6..f6e2b85 100644 --- a/poetry.lock +++ b/poetry.lock @@ -201,6 +201,70 @@ traitlets = ">=4" [package.extras] test = ["pytest"] +[[package]] +name = "coverage" +version = "7.5.1" +description = "Code coverage measurement for Python" +optional = false +python-versions = ">=3.8" +files = [ + {file = "coverage-7.5.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:c0884920835a033b78d1c73b6d3bbcda8161a900f38a488829a83982925f6c2e"}, + {file = "coverage-7.5.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:39afcd3d4339329c5f58de48a52f6e4e50f6578dd6099961cf22228feb25f38f"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4a7b0ceee8147444347da6a66be737c9d78f3353b0681715b668b72e79203e4a"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a9ca3f2fae0088c3c71d743d85404cec8df9be818a005ea065495bedc33da35"}, + {file = "coverage-7.5.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5fd215c0c7d7aab005221608a3c2b46f58c0285a819565887ee0b718c052aa4e"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:4bf0655ab60d754491004a5efd7f9cccefcc1081a74c9ef2da4735d6ee4a6223"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:61c4bf1ba021817de12b813338c9be9f0ad5b1e781b9b340a6d29fc13e7c1b5e"}, + {file = "coverage-7.5.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:db66fc317a046556a96b453a58eced5024af4582a8dbdc0c23ca4dbc0d5b3146"}, + {file = "coverage-7.5.1-cp310-cp310-win32.whl", hash = "sha256:b016ea6b959d3b9556cb401c55a37547135a587db0115635a443b2ce8f1c7228"}, + {file = "coverage-7.5.1-cp310-cp310-win_amd64.whl", hash = "sha256:df4e745a81c110e7446b1cc8131bf986157770fa405fe90e15e850aaf7619bc8"}, + {file = "coverage-7.5.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:796a79f63eca8814ca3317a1ea443645c9ff0d18b188de470ed7ccd45ae79428"}, + {file = "coverage-7.5.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4fc84a37bfd98db31beae3c2748811a3fa72bf2007ff7902f68746d9757f3746"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6175d1a0559986c6ee3f7fccfc4a90ecd12ba0a383dcc2da30c2b9918d67d8a3"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1fc81d5878cd6274ce971e0a3a18a8803c3fe25457165314271cf78e3aae3aa2"}, + {file = "coverage-7.5.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:556cf1a7cbc8028cb60e1ff0be806be2eded2daf8129b8811c63e2b9a6c43bca"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:9981706d300c18d8b220995ad22627647be11a4276721c10911e0e9fa44c83e8"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d7fed867ee50edf1a0b4a11e8e5d0895150e572af1cd6d315d557758bfa9c057"}, + {file = "coverage-7.5.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:ef48e2707fb320c8f139424a596f5b69955a85b178f15af261bab871873bb987"}, + {file = "coverage-7.5.1-cp311-cp311-win32.whl", hash = "sha256:9314d5678dcc665330df5b69c1e726a0e49b27df0461c08ca12674bcc19ef136"}, + {file = "coverage-7.5.1-cp311-cp311-win_amd64.whl", hash = "sha256:5fa567e99765fe98f4e7d7394ce623e794d7cabb170f2ca2ac5a4174437e90dd"}, + {file = "coverage-7.5.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b6cf3764c030e5338e7f61f95bd21147963cf6aa16e09d2f74f1fa52013c1206"}, + {file = "coverage-7.5.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ec92012fefebee89a6b9c79bc39051a6cb3891d562b9270ab10ecfdadbc0c34"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:16db7f26000a07efcf6aea00316f6ac57e7d9a96501e990a36f40c965ec7a95d"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:beccf7b8a10b09c4ae543582c1319c6df47d78fd732f854ac68d518ee1fb97fa"}, + {file = "coverage-7.5.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8748731ad392d736cc9ccac03c9845b13bb07d020a33423fa5b3a36521ac6e4e"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7352b9161b33fd0b643ccd1f21f3a3908daaddf414f1c6cb9d3a2fd618bf2572"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:7a588d39e0925f6a2bff87154752481273cdb1736270642aeb3635cb9b4cad07"}, + {file = "coverage-7.5.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:68f962d9b72ce69ea8621f57551b2fa9c70509af757ee3b8105d4f51b92b41a7"}, + {file = "coverage-7.5.1-cp312-cp312-win32.whl", hash = "sha256:f152cbf5b88aaeb836127d920dd0f5e7edff5a66f10c079157306c4343d86c19"}, + {file = "coverage-7.5.1-cp312-cp312-win_amd64.whl", hash = "sha256:5a5740d1fb60ddf268a3811bcd353de34eb56dc24e8f52a7f05ee513b2d4f596"}, + {file = "coverage-7.5.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:e2213def81a50519d7cc56ed643c9e93e0247f5bbe0d1247d15fa520814a7cd7"}, + {file = "coverage-7.5.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5037f8fcc2a95b1f0e80585bd9d1ec31068a9bcb157d9750a172836e98bc7a90"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3721c2c9e4c4953a41a26c14f4cef64330392a6d2d675c8b1db3b645e31f0e"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca498687ca46a62ae590253fba634a1fe9836bc56f626852fb2720f334c9e4e5"}, + {file = "coverage-7.5.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0cdcbc320b14c3e5877ee79e649677cb7d89ef588852e9583e6b24c2e5072661"}, + {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:57e0204b5b745594e5bc14b9b50006da722827f0b8c776949f1135677e88d0b8"}, + {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fe7502616b67b234482c3ce276ff26f39ffe88adca2acf0261df4b8454668b4"}, + {file = "coverage-7.5.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:9e78295f4144f9dacfed4f92935fbe1780021247c2fabf73a819b17f0ccfff8d"}, + {file = "coverage-7.5.1-cp38-cp38-win32.whl", hash = "sha256:1434e088b41594baa71188a17533083eabf5609e8e72f16ce8c186001e6b8c41"}, + {file = "coverage-7.5.1-cp38-cp38-win_amd64.whl", hash = "sha256:0646599e9b139988b63704d704af8e8df7fa4cbc4a1f33df69d97f36cb0a38de"}, + {file = "coverage-7.5.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4cc37def103a2725bc672f84bd939a6fe4522310503207aae4d56351644682f1"}, + {file = "coverage-7.5.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:fc0b4d8bfeabd25ea75e94632f5b6e047eef8adaed0c2161ada1e922e7f7cece"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0d0a0f5e06881ecedfe6f3dd2f56dcb057b6dbeb3327fd32d4b12854df36bf26"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9735317685ba6ec7e3754798c8871c2f49aa5e687cc794a0b1d284b2389d1bd5"}, + {file = "coverage-7.5.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d21918e9ef11edf36764b93101e2ae8cc82aa5efdc7c5a4e9c6c35a48496d601"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c3e757949f268364b96ca894b4c342b41dc6f8f8b66c37878aacef5930db61be"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:79afb6197e2f7f60c4824dd4b2d4c2ec5801ceb6ba9ce5d2c3080e5660d51a4f"}, + {file = "coverage-7.5.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:d1d0d98d95dd18fe29dc66808e1accf59f037d5716f86a501fc0256455219668"}, + {file = "coverage-7.5.1-cp39-cp39-win32.whl", hash = "sha256:1cc0fe9b0b3a8364093c53b0b4c0c2dd4bb23acbec4c9240b5f284095ccf7981"}, + {file = "coverage-7.5.1-cp39-cp39-win_amd64.whl", hash = "sha256:dde0070c40ea8bb3641e811c1cfbf18e265d024deff6de52c5950677a8fb1e0f"}, + {file = "coverage-7.5.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:6537e7c10cc47c595828b8a8be04c72144725c383c4702703ff4e42e44577312"}, + {file = "coverage-7.5.1.tar.gz", hash = "sha256:54de9ef3a9da981f7af93eafde4ede199e0846cd819eb27c88e2b712aae9708c"}, +] + +[package.extras] +toml = ["tomli"] + [[package]] name = "debugpy" version = "1.8.1" @@ -1006,6 +1070,24 @@ pluggy = ">=0.12,<2.0" [package.extras] testing = ["argcomplete", "attrs (>=19.2.0)", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] +[[package]] +name = "pytest-cov" +version = "5.0.0" +description = "Pytest plugin for measuring coverage." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest-cov-5.0.0.tar.gz", hash = "sha256:5837b58e9f6ebd335b0f8060eecce69b662415b16dc503883a02f45dfeb14857"}, + {file = "pytest_cov-5.0.0-py3-none-any.whl", hash = "sha256:4f0764a1219df53214206bf1feea4633c3b558a2925c8b59f144f682861ce652"}, +] + +[package.dependencies] +coverage = {version = ">=5.2.1", extras = ["toml"]} +pytest = ">=4.6" + +[package.extras] +testing = ["fields", "hunter", "process-tests", "pytest-xdist", "virtualenv"] + [[package]] name = "python-dateutil" version = "2.9.0.post0" @@ -1514,4 +1596,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.11" -content-hash = "722eb26fa0e06734b02f7467afa40b6280d0fbe455fb2b67802029309aeba821" +content-hash = "2b903317846656bb8488e3d0457b92603af7eaa96ca5ddec0fc20136ad46c210" diff --git a/pyproject.toml b/pyproject.toml index dc090ec..794b8ed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,6 +23,7 @@ ruff = "^0.1.11" black = "^23.12.1" isort = "^5.13.2" ipykernel = "^6.29.0" +pytest-cov = "^5.0.0" [build-system] requires = ["poetry-core"] diff --git a/tutorials/serialize_my_model.ipynb b/tutorials/serialize_my_model.ipynb index 5f92f82..284986c 100644 --- a/tutorials/serialize_my_model.ipynb +++ b/tutorials/serialize_my_model.ipynb @@ -6,7 +6,7 @@ "source": [ "## How to serialize my model\n", "\n", - "Giza_mlutils offers various functionalities that help us have a model with the necessary characteristics to be transpilable, and therefore, able to generate proofs of its inferences.\n", + "auto_zkml offers various functionalities that help us have a model with the necessary characteristics to be transpilable, and therefore, able to generate proofs of its inferences.\n", "In this case, we will talk about the serialization process, which involves saving your model in a format that can be interpreted by other Giza tools.\n", "\n", "Currently, the two supported models are XGBoost and LightGBM for both classification and regression. It is preferable that the training is done using the scikit-learn API.\n", @@ -123,7 +123,7 @@ "metadata": {}, "source": [ "That simple! We now have our models saved in the correct format to use the rest of the Giza stack! But not so fast...\n", - "In this example, the models are very simple (few trees and shallow depth), but for other problems, the optimal architecture might be much more complex and not compatible with our current technology. In this case, we will have to use another of the functionalities offered by Giza_mlutils beforehand: our model_complexity_reducer.\n", + "In this example, the models are very simple (few trees and shallow depth), but for other problems, the optimal architecture might be much more complex and not compatible with our current technology. In this case, we will have to use another of the functionalities offered by auto_zkml beforehand: our model_complexity_reducer.\n", "\n", "To understand how the model_complexity_reducer (mcr) works, in this same folder you will find the notebook reduce_model_complexity.ipynb with a detailed explanation of its operation and how to run it before serializing your model." ]