Skip to content

Commit

Permalink
Merge pull request #88 from pymc-devs/add-pre-commit
Browse files Browse the repository at this point in the history
Add .pre-commit-config.yaml
  • Loading branch information
maresb authored Feb 19, 2024
2 parents 802278c + b8d098d commit 4dfd066
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 49 deletions.
23 changes: 23 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
ci:
autofix_prs: false

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
hooks:
- id: debug-statements
- id: check-merge-conflict
- id: check-toml
- id: check-yaml
- id: debug-statements
- id: end-of-file-fixer
- id: no-commit-to-branch
args: [--branch, main]
- id: trailing-whitespace

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
hooks:
- id: ruff
args: ["--fix", "--output-format=full"]
- id: ruff-format
30 changes: 15 additions & 15 deletions notebooks/pytensor_logp.md
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,24 @@ plt.axvline(0, color="grey", alpha=0.5, zorder=-100)
%%file radon_model.stan
data {
int<lower=0> n_counties;
int<lower=0> n_observed;
int<lower=0> n_observed;
array[n_observed] int<lower=1,upper=n_counties> county_idx;
vector[n_observed] is_floor;
vector[n_observed] log_radon;
}
}
parameters {
real intercept;

vector[n_counties] county_raw;
real<lower=0> county_sd;

real floor_effect;

vector[n_counties] county_floor_raw;
real<lower=0> county_floor_sd;

real<lower=0> sigma;
}
}
transformed parameters {
vector[n_counties] county_effect;
vector[n_counties] county_floor_effect;
Expand All @@ -155,17 +155,17 @@ transformed parameters {
}
model {
intercept ~ normal(0, 10);

county_raw ~ normal(0, 1);
county_sd ~ normal(0, 1);

floor_effect ~ normal(0, 2);

county_floor_raw ~ normal(0, 1);
county_floor_sd ~ normal(0, 1);

sigma ~ normal(0, 1.5);

log_radon ~ normal(mu, sigma);
}
```
Expand Down Expand Up @@ -305,13 +305,13 @@ for name, val in data_stan.items():
if isinstance(val, int):
data_json[name] = int(val)
continue

if val.dtype == np.int64:
data_json[name] = list(int(x) for x in val)
continue

data_json[name] = list(val)

with open("radon.json", "w") as file:
json.dump(data_json, file)
```
Expand Down
46 changes: 46 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,49 @@ all = [
"numba >= 0.57.1",
]

[tool.ruff]
line-length = 88
target-version = "py39"
show-fixes = true
output-format = "full"

[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"I", # isort
"C4", # flake8-comprehensions
"B", # flake8-bugbear
"UP", # pyupgrade
"RUF", # Ruff-specific rules
"TID", # flake8-tidy-imports
"BLE", # flake8-blind-except
"PTH", # flake8-pathlib
"A", # flake8-builtins
]
ignore = [
"C408", # unnecessary-collection-call (allow dict(a=1, b=2); clarity over speed!)
# The following list is recommended to disable these when using ruff's formatter.
# (Not all of the following are actually enabled.)
"W191", # tab-indentation
"E111", # indentation-with-invalid-multiple
"E114", # indentation-with-invalid-multiple-comment
"E117", # over-indented
"D206", # indent-with-spaces
"D300", # triple-single-quotes
"Q000", # bad-quotes-inline-string
"Q001", # bad-quotes-multiline-string
"Q002", # bad-quotes-docstring
"Q003", # avoidable-escaped-quote
"COM812", # missing-trailing-comma
"COM819", # prohibited-trailing-comma
"ISC001", # single-line-implicit-string-concatenation
"ISC002", # multi-line-implicit-string-concatenation
]

[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"

[tool.ruff.lint.isort]
known-first-party = ["nutpie"]
5 changes: 2 additions & 3 deletions python/nutpie/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from nutpie import _lib
from nutpie.compile_pymc import compile_pymc_model
from nutpie.compile_stan import compile_stan_model
from nutpie.sample import sample

from .compile_pymc import compile_pymc_model
from .compile_stan import compile_stan_model

__version__: str = _lib.__version__
__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"]
14 changes: 7 additions & 7 deletions python/nutpie/compile_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from dataclasses import dataclass
from importlib.util import find_spec
from math import prod
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from typing import TYPE_CHECKING, Any, Optional

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -42,15 +42,15 @@ def codegen(cgctx, builder, sig, args):
class CompiledPyMCModel(CompiledModel):
compiled_logp_func: "numba.core.ccallback.CFunc"
compiled_expand_func: "numba.core.ccallback.CFunc"
shared_data: Dict[str, NDArray]
shared_data: dict[str, NDArray]
user_data: NDArray
n_expanded: int
shape_info: Any
logp_func: Any
expand_func: Any
_n_dim: int
_shapes: Dict[str, Tuple[int, ...]]
_coords: Optional[Dict[str, Any]]
_shapes: dict[str, tuple[int, ...]]
_coords: Optional[dict[str, Any]]

@property
def n_dim(self):
Expand Down Expand Up @@ -269,7 +269,7 @@ def _compute_shapes(model):
mode=pytensor.compile.mode.FAST_COMPILE,
on_unused_input="ignore",
)
return {name: shape for name, shape in zip(trace_vars.keys(), shape_func())}
return dict(zip(trace_vars.keys(), shape_func()))


def _make_functions(model):
Expand Down Expand Up @@ -517,7 +517,7 @@ def logp_numba(dim, x_, out_, logp_, user_data_):
return 4
# if np.any(out == 0):
# return 4
except Exception:
except Exception: # noqa: BLE001
return 1
return 0

Expand Down Expand Up @@ -552,7 +552,7 @@ def expand_numba(dim, expanded, x_, out_, user_data_):
(values,) = extract(x, user_data_)
out[...] = values

except Exception:
except Exception: # noqa: BLE001
return -2
return 0

Expand Down
22 changes: 11 additions & 11 deletions python/nutpie/compile_stan.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import json
import pathlib
import tempfile
from dataclasses import dataclass, replace
from importlib.util import find_spec
from typing import Any, Dict, List, Optional
from pathlib import Path
from typing import Any, Optional

import numpy as np
import pandas as pd
Expand All @@ -22,9 +22,9 @@ def default(self, obj):

@dataclass(frozen=True)
class CompiledStanModel(CompiledModel):
_coords: Optional[Dict[str, Any]]
_coords: Optional[dict[str, Any]]
code: str
data: Optional[Dict[str, NDArray]]
data: Optional[dict[str, NDArray]]
library: Any
model: Any
model_name: Optional[str] = None
Expand Down Expand Up @@ -107,10 +107,10 @@ def compile_stan_model(
*,
code: Optional[str] = None,
filename: Optional[str] = None,
extra_compile_args: Optional[List[str]] = None,
extra_stanc_args: Optional[List[str]] = None,
dims: Optional[Dict[str, int]] = None,
coords: Optional[Dict[str, Any]] = None,
extra_compile_args: Optional[list[str]] = None,
extra_stanc_args: Optional[list[str]] = None,
dims: Optional[dict[str, int]] = None,
coords: Optional[dict[str, Any]] = None,
model_name: Optional[str] = None,
cleanup: bool = True,
) -> CompiledStanModel:
Expand All @@ -133,7 +133,7 @@ def compile_stan_model(
if code is None:
if filename is None:
raise ValueError("Either code or filename have to be specified")
with open(filename, "r") as file:
with Path(filename).open() as file:
code = file.read()

if model_name is None:
Expand All @@ -142,7 +142,7 @@ def compile_stan_model(
basedir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
try:
model_path = (
pathlib.Path(basedir.name)
Path(basedir.name)
.joinpath("name")
.with_name(model_name) # This verifies that it is a valid filename
.with_suffix(".stan")
Expand All @@ -164,7 +164,7 @@ def compile_stan_model(
try:
if cleanup:
basedir.cleanup()
except Exception:
except Exception: # noqa: BLE001
pass

return CompiledStanModel(
Expand Down
16 changes: 8 additions & 8 deletions python/nutpie/sample.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from typing import Optional

import arviz
import fastprogress
Expand All @@ -12,14 +12,14 @@

@dataclass(frozen=True)
class CompiledModel:
dims: Optional[Dict[str, Tuple[str, ...]]]
dims: Optional[dict[str, tuple[str, ...]]]

@property
def n_dim(self) -> int:
raise NotImplementedError()

@property
def shapes(self) -> Optional[Dict[str, Tuple[int, ...]]]:
def shapes(self) -> Optional[dict[str, tuple[int, ...]]]:
raise NotImplementedError()

@property
Expand Down Expand Up @@ -76,10 +76,10 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
dtype = col.chunks[0].values.to_numpy().dtype
if dtype in [np.float64, np.float32]:
data = np.full(
(n_chains, length) + tuple(shapes[name]), np.nan, dtype=dtype
(n_chains, length, *tuple(shapes[name])), np.nan, dtype=dtype
)
else:
data = np.zeros((n_chains, length) + tuple(shapes[name]), dtype=dtype)
data = np.zeros((n_chains, length, *tuple(shapes[name])), dtype=dtype)
for i, chunk in enumerate(col.chunks):
data[i, : len(chunk)] = chunk.values.to_numpy().reshape(
(len(chunk),) + shapes[name]
Expand All @@ -103,16 +103,16 @@ def _trace_to_arviz(traces, n_tune, shapes, **kwargs):
length = max(lengths)

if dtype in [np.float64, np.float32]:
data = np.full((n_chains, length) + last_shape, np.nan, dtype=dtype)
data = np.full((n_chains, length, *last_shape), np.nan, dtype=dtype)
else:
data = np.zeros((n_chains, length) + last_shape, dtype=dtype)
data = np.zeros((n_chains, length, *last_shape), dtype=dtype)

for i, chunk in enumerate(col.chunks):
if hasattr(chunk, "values"):
values = chunk.values.to_numpy(False)
else:
values = chunk.to_numpy(False)
data[i, : len(chunk)] = values.reshape((len(chunk),) + last_shape)
data[i, : len(chunk)] = values.reshape((len(chunk), *last_shape))
stats_dict[name] = data[:, n_tune:]
stats_dict_tune[name] = data[:, :n_tune]

Expand Down
6 changes: 3 additions & 3 deletions tests/test_pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_pymc_model():

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled, chains=1)
trace.posterior.a
trace.posterior.a # noqa: B018


def test_pymc_model_with_coordinate():
Expand All @@ -22,7 +22,7 @@ def test_pymc_model_with_coordinate():

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled, chains=1)
trace.posterior.a
trace.posterior.a # noqa: B018


def test_trafo():
Expand All @@ -31,7 +31,7 @@ def test_trafo():

compiled = nutpie.compile_pymc_model(model)
trace = nutpie.sample(compiled, chains=1)
trace.posterior.a
trace.posterior.a # noqa: B018


def test_det():
Expand Down
4 changes: 2 additions & 2 deletions tests/test_stan.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_stan_model():

compiled_model = nutpie.compile_stan_model(code=model)
trace = nutpie.sample(compiled_model)
trace.posterior.a
trace.posterior.a # noqa: B018


def test_stan_model_data():
Expand All @@ -37,4 +37,4 @@ def test_stan_model_data():
with pytest.raises(RuntimeError):
trace = nutpie.sample(compiled_model)
trace = nutpie.sample(compiled_model.with_data(x=np.array(3.0)))
trace.posterior.a
trace.posterior.a # noqa: B018

0 comments on commit 4dfd066

Please sign in to comment.