Skip to content

refactor(infer.elbo): add type hints to elbo module #2028

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 20, 2025

Conversation

brendancooley
Copy link
Contributor

Picking up on the work of @juanitorduz, this MR adds type hints to the numpyro.infer.elbo module. Along the way, I've made a few housekeeping changes to try and make it easier to maintain and extend type hints going forward.

The most important change warranting discussion is that I've introduced a ParamSpec for the parameters of a model function. ELBOs are now generics over the signature of the model/guide functions they operate on. I think that this is a nice way to (softly) enforce consistency in the *args and **kwargs passed to ELBO.loss, but there may be some subtleties that I'm missing here.

Summary of Changes

  • Added comprehensive type annotations to elbo module
  • Remove dormant _compute_downstream_costs code and associated tests (replaced in 28e38d8)
  • Use Ruff ANN rule to require annotations in typed modules
    • exempt nested functions in numpyro.primitives with noqa
  • Added numpyro._typing module for cross-module shared type aliases
  • Converted some rng_key annotations to jax.Array to address typing errors and following guidance in JEP 9263
  • Use shared type aliases and python 3.10+ union annotations in numpyro.handlers, with requisite import from __future__ import annotations to support python 3.9.

Previous Typing MRs (for Reference)

  1. primitives
  2. handlers
  3. optim

@juanitorduz
Copy link
Contributor

This looks great! (I approved, but we need a true core dev to approve it). I think the error

FAILED test/test_example_utils.py::test_mnist_data_load - urllib.error.HTTPError: HTTP Error 429: Too Many Requests

Is not related with these changes.

@fehiepsi
Copy link
Member

This is awesome! I'll need to look into details and learn from your designs. :)

Please ignore the failing tests. I guess it's a github issue and will be resolved by github devs soon.

@brendancooley
Copy link
Contributor Author

Thanks @fehiepsi and @juanitorduz for the reviews here. Still getting the hang of these typing concepts myself as you can see! Hopefully it will get easier with time and practice.

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks awesome, thanks @brendancooley!

from numpyro.distributions.distribution import DistributionLike
from numpyro.util import find_stack_level, identity

# Type aliases
Message = dict[str, Any]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you keep Message = MessageT in case users already use Message in their libraries?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oops yes, forgot this was pre-existing. reverted.

pyproject.toml Outdated
@@ -104,6 +112,7 @@ doctest_optionflags = [
]

[tool.mypy]
python_version = 3.12
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this mean that we only allow mypy to work with python 3.12?

Copy link
Contributor Author

@brendancooley brendancooley May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for python 3.9, mypy complains about modern (3.10+) type annotations on our type aliases

LossT: TypeAlias = jax.Array | dict[str, jax.Array]
numpyro/infer/elbo.py:44: error: Invalid type alias: expression is not a valid type  [valid-type]
numpyro/infer/elbo.py:44: error: Unsupported left operand type for | ("type[Array]")  [operator]
Found 2 errors in 1 file (checked 88 source files)

Since python 3.9 is approaching end of life, I went ahead removed this target python_version but excluded the lint workflow from CI for the 3.9 case, which will let us use X | Y in lieu of typing.Union[X, Y] across the codebase without having to add inline ignores for the typechecker.

Alternatively, I could add an inline ignore or adopt use the legacy Union[...] type hint. Whatever you prefer!

@fehiepsi fehiepsi merged commit c7f522b into pyro-ppl:master May 20, 2025
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants