From 084c9e244537b10497f14d4e34e3e10bfca23f57 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:01:07 +0100 Subject: [PATCH 1/6] Add .pre-commit-config.yaml --- .pre-commit-config.yaml | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..7a141fc --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +ci: + autofix_prs: false + +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: debug-statements + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: no-commit-to-branch + args: [--branch, main] + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff + args: ["--fix", "--output-format=full"] + - id: ruff-format From 2d303cc79e06608abbeef0fe907d77dc650e68f5 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:03:30 +0100 Subject: [PATCH 2/6] Lint files --- notebooks/pytensor_logp.md | 30 +++++++++++++++--------------- pyproject.toml | 1 - 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/notebooks/pytensor_logp.md b/notebooks/pytensor_logp.md index e115fea..5cab8dc 100644 --- a/notebooks/pytensor_logp.md +++ b/notebooks/pytensor_logp.md @@ -120,24 +120,24 @@ plt.axvline(0, color="grey", alpha=0.5, zorder=-100) %%file radon_model.stan data { int n_counties; - int n_observed; + int n_observed; array[n_observed] int county_idx; vector[n_observed] is_floor; vector[n_observed] log_radon; -} +} parameters { real intercept; - + vector[n_counties] county_raw; real county_sd; - + real floor_effect; - + vector[n_counties] county_floor_raw; real county_floor_sd; - + real sigma; -} +} transformed parameters { vector[n_counties] county_effect; vector[n_counties] county_floor_effect; @@ -155,17 +155,17 @@ 
transformed parameters { } model { intercept ~ normal(0, 10); - + county_raw ~ normal(0, 1); county_sd ~ normal(0, 1); - + floor_effect ~ normal(0, 2); - + county_floor_raw ~ normal(0, 1); county_floor_sd ~ normal(0, 1); - + sigma ~ normal(0, 1.5); - + log_radon ~ normal(mu, sigma); } ``` @@ -305,13 +305,13 @@ for name, val in data_stan.items(): if isinstance(val, int): data_json[name] = int(val) continue - + if val.dtype == np.int64: data_json[name] = list(int(x) for x in val) continue - + data_json[name] = list(val) - + with open("radon.json", "w") as file: json.dump(data_json, file) ``` diff --git a/pyproject.toml b/pyproject.toml index b1a6402..b2f3bce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,4 +36,3 @@ all = [ "pymc >= 5.5.0", "numba >= 0.57.1", ] - From 541100294679e28e02b4d768af3e916f9c139854 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:14:26 +0100 Subject: [PATCH 3/6] Simple linting pass --- python/nutpie/compile_pymc.py | 12 ++++++------ python/nutpie/compile_stan.py | 16 ++++++++-------- python/nutpie/sample.py | 6 +++--- tests/test_pymc.py | 6 +++--- tests/test_stan.py | 4 ++-- 5 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index e78ed61..30e6f1a 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from importlib.util import find_spec from math import prod -from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple +from typing import TYPE_CHECKING, Any, Optional import numpy as np import pandas as pd @@ -42,15 +42,15 @@ def codegen(cgctx, builder, sig, args): class CompiledPyMCModel(CompiledModel): compiled_logp_func: "numba.core.ccallback.CFunc" compiled_expand_func: "numba.core.ccallback.CFunc" - shared_data: Dict[str, NDArray] + shared_data: dict[str, NDArray] user_data: NDArray n_expanded: int shape_info: Any logp_func: Any expand_func: Any _n_dim: int - 
_shapes: Dict[str, Tuple[int, ...]] - _coords: Optional[Dict[str, Any]] + _shapes: dict[str, tuple[int, ...]] + _coords: Optional[dict[str, Any]] @property def n_dim(self): @@ -517,7 +517,7 @@ def logp_numba(dim, x_, out_, logp_, user_data_): return 4 # if np.any(out == 0): # return 4 - except Exception: + except Exception: # noqa: BLE001 return 1 return 0 @@ -552,7 +552,7 @@ def expand_numba(dim, expanded, x_, out_, user_data_): (values,) = extract(x, user_data_) out[...] = values - except Exception: + except Exception: # noqa: BLE001 return -2 return 0 diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index ad413c9..b4d8bf5 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -3,7 +3,7 @@ import tempfile from dataclasses import dataclass, replace from importlib.util import find_spec -from typing import Any, Dict, List, Optional +from typing import Any, Optional import numpy as np import pandas as pd @@ -22,9 +22,9 @@ def default(self, obj): @dataclass(frozen=True) class CompiledStanModel(CompiledModel): - _coords: Optional[Dict[str, Any]] + _coords: Optional[dict[str, Any]] code: str - data: Optional[Dict[str, NDArray]] + data: Optional[dict[str, NDArray]] library: Any model: Any model_name: Optional[str] = None @@ -107,10 +107,10 @@ def compile_stan_model( *, code: Optional[str] = None, filename: Optional[str] = None, - extra_compile_args: Optional[List[str]] = None, - extra_stanc_args: Optional[List[str]] = None, - dims: Optional[Dict[str, int]] = None, - coords: Optional[Dict[str, Any]] = None, + extra_compile_args: Optional[list[str]] = None, + extra_stanc_args: Optional[list[str]] = None, + dims: Optional[dict[str, int]] = None, + coords: Optional[dict[str, Any]] = None, model_name: Optional[str] = None, cleanup: bool = True, ) -> CompiledStanModel: @@ -164,7 +164,7 @@ def compile_stan_model( try: if cleanup: basedir.cleanup() - except Exception: + except Exception: # noqa: BLE001 pass return 
CompiledStanModel( diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index b886b7d..4a8d270 100644 --- a/python/nutpie/sample.py +++ b/python/nutpie/sample.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Dict, Optional, Tuple +from typing import Optional import arviz import fastprogress @@ -12,14 +12,14 @@ @dataclass(frozen=True) class CompiledModel: - dims: Optional[Dict[str, Tuple[str, ...]]] + dims: Optional[dict[str, tuple[str, ...]]] @property def n_dim(self) -> int: raise NotImplementedError() @property - def shapes(self) -> Optional[Dict[str, Tuple[int, ...]]]: + def shapes(self) -> Optional[dict[str, tuple[int, ...]]]: raise NotImplementedError() @property diff --git a/tests/test_pymc.py b/tests/test_pymc.py index 80b504a..101f1a0 100644 --- a/tests/test_pymc.py +++ b/tests/test_pymc.py @@ -12,7 +12,7 @@ def test_pymc_model(): compiled = nutpie.compile_pymc_model(model) trace = nutpie.sample(compiled, chains=1) - trace.posterior.a + trace.posterior.a # noqa: B018 def test_pymc_model_with_coordinate(): @@ -22,7 +22,7 @@ def test_pymc_model_with_coordinate(): compiled = nutpie.compile_pymc_model(model) trace = nutpie.sample(compiled, chains=1) - trace.posterior.a + trace.posterior.a # noqa: B018 def test_trafo(): @@ -31,7 +31,7 @@ def test_trafo(): compiled = nutpie.compile_pymc_model(model) trace = nutpie.sample(compiled, chains=1) - trace.posterior.a + trace.posterior.a # noqa: B018 def test_det(): diff --git a/tests/test_stan.py b/tests/test_stan.py index 39c4411..d1d0f42 100644 --- a/tests/test_stan.py +++ b/tests/test_stan.py @@ -17,7 +17,7 @@ def test_stan_model(): compiled_model = nutpie.compile_stan_model(code=model) trace = nutpie.sample(compiled_model) - trace.posterior.a + trace.posterior.a # noqa: B018 def test_stan_model_data(): @@ -37,4 +37,4 @@ def test_stan_model_data(): with pytest.raises(RuntimeError): trace = nutpie.sample(compiled_model) trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0))) 
- trace.posterior.a + trace.posterior.a # noqa: B018 From d28cc4e3de5e43017cd2425445bd92e3ecf21c6b Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:29:46 +0100 Subject: [PATCH 4/6] Apply ruff unsafe fixes --- python/nutpie/__init__.py | 5 ++--- python/nutpie/compile_pymc.py | 2 +- python/nutpie/compile_stan.py | 2 +- python/nutpie/sample.py | 10 +++++----- 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 2e02550..980b7e5 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,8 +1,7 @@ from nutpie import _lib +from nutpie.compile_pymc import compile_pymc_model +from nutpie.compile_stan import compile_stan_model from nutpie.sample import sample -from .compile_pymc import compile_pymc_model -from .compile_stan import compile_stan_model - __version__: str = _lib.__version__ __all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"] diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 30e6f1a..e8d3016 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -269,7 +269,7 @@ def _compute_shapes(model): mode=pytensor.compile.mode.FAST_COMPILE, on_unused_input="ignore", ) - return {name: shape for name, shape in zip(trace_vars.keys(), shape_func())} + return dict(zip(trace_vars.keys(), shape_func())) def _make_functions(model): diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index b4d8bf5..25f7e92 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -133,7 +133,7 @@ def compile_stan_model( if code is None: if filename is None: raise ValueError("Either code or filename have to be specified") - with open(filename, "r") as file: + with open(filename) as file: code = file.read() if model_name is None: diff --git a/python/nutpie/sample.py b/python/nutpie/sample.py index 4a8d270..3cbcf96 100644 --- a/python/nutpie/sample.py +++ 
b/python/nutpie/sample.py @@ -76,10 +76,10 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): dtype = col.chunks[0].values.to_numpy().dtype if dtype in [np.float64, np.float32]: data = np.full( - (n_chains, length) + tuple(shapes[name]), np.nan, dtype=dtype + (n_chains, length, *tuple(shapes[name])), np.nan, dtype=dtype ) else: - data = np.zeros((n_chains, length) + tuple(shapes[name]), dtype=dtype) + data = np.zeros((n_chains, length, *tuple(shapes[name])), dtype=dtype) for i, chunk in enumerate(col.chunks): data[i, : len(chunk)] = chunk.values.to_numpy().reshape( (len(chunk),) + shapes[name] @@ -103,16 +103,16 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs): length = max(lengths) if dtype in [np.float64, np.float32]: - data = np.full((n_chains, length) + last_shape, np.nan, dtype=dtype) + data = np.full((n_chains, length, *last_shape), np.nan, dtype=dtype) else: - data = np.zeros((n_chains, length) + last_shape, dtype=dtype) + data = np.zeros((n_chains, length, *last_shape), dtype=dtype) for i, chunk in enumerate(col.chunks): if hasattr(chunk, "values"): values = chunk.values.to_numpy(False) else: values = chunk.to_numpy(False) - data[i, : len(chunk)] = values.reshape((len(chunk),) + last_shape) + data[i, : len(chunk)] = values.reshape((len(chunk), *last_shape)) stats_dict[name] = data[:, n_tune:] stats_dict_tune[name] = data[:, :n_tune] From cbcb817e9e3abe8ee1eece0fc4f2a9b0486b9383 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:34:59 +0100 Subject: [PATCH 5/6] Use pathlib.Path --- python/nutpie/compile_stan.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 25f7e92..2d0aef4 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -1,8 +1,8 @@ import json -import pathlib import tempfile from dataclasses import dataclass, replace from importlib.util import find_spec +from pathlib import Path from typing import 
Any, Optional import numpy as np @@ -133,7 +133,7 @@ def compile_stan_model( if code is None: if filename is None: raise ValueError("Either code or filename have to be specified") - with open(filename) as file: + with Path(filename).open() as file: code = file.read() if model_name is None: @@ -142,7 +142,7 @@ def compile_stan_model( basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True) try: model_path = ( - pathlib.Path(basedir.name) + Path(basedir.name) .joinpath("name") .with_name(model_name) # This verifies that it is a valid filename .with_suffix(".stan") From b8d098d0073268bfa2b83c73790d69b7b64eb6a3 Mon Sep 17 00:00:00 2001 From: Ben Mares Date: Mon, 19 Feb 2024 22:35:16 +0100 Subject: [PATCH 6/6] Add ruff config --- pyproject.toml | 47 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b2f3bce..1394880 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,3 +36,50 @@ all = [ "pymc >= 5.5.0", "numba >= 0.57.1", ] + +[tool.ruff] +line-length = 88 +target-version = "py39" +show-fixes = true +output-format = "full" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "C4", # flake8-comprehensions + "B", # flake8-bugbear + "UP", # pyupgrade + "RUF", # Ruff-specific rules + "TID", # flake8-tidy-imports + "BLE", # flake8-blind-except + "PTH", # flake8-pathlib + "A", # flake8-builtins +] +ignore = [ + "C408", # unnecessary-collection-call (allow dict(a=1, b=2); clarity over speed!) + # The following rules are recommended to be disabled when using ruff's formatter. + # (Not all of the following are actually enabled.) 
+ "W191", # tab-indentation + "E111", # indentation-with-invalid-multiple + "E114", # indentation-with-invalid-multiple-comment + "E117", # over-indented + "D206", # indent-with-spaces + "D300", # triple-single-quotes + "Q000", # bad-quotes-inline-string + "Q001", # bad-quotes-multiline-string + "Q002", # bad-quotes-docstring + "Q003", # avoidable-escaped-quote + "COM812", # missing-trailing-comma + "COM819", # prohibited-trailing-comma + "ISC001", # single-line-implicit-string-concatenation + "ISC002", # multi-line-implicit-string-concatenation +] + +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + +[tool.ruff.lint.isort] +known-first-party = ["nutpie"]