-
Notifications
You must be signed in to change notification settings - Fork 259
refactor(infer.elbo): add type hints to elbo module #2028
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
This looks great! (I approved, but we need a true core dev to approve it). I think the error FAILED test/test_example_utils.py::test_mnist_data_load - urllib.error.HTTPError: HTTP Error 429: Too Many Requests Is not related with these changes. |
This is awesome! I'll need to look into details and learn from your designs. :) Please ignore the failing tests. I guess it's a github issue and will be resolved by github devs soon. |
Thanks @fehiepsi and @juanitorduz for the reviews here. Still getting the hang of these typing concepts myself as you can see! Hopefully it will get easier with time and practice. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks awesome, thanks @brendancooley!
from numpyro.distributions.distribution import DistributionLike | ||
from numpyro.util import find_stack_level, identity | ||
|
||
# Type aliases | ||
Message = dict[str, Any] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you keep Message = MessageT in case users already use Message in their libraries?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oops yes, forgot this was pre-existing. reverted.
pyproject.toml
Outdated
@@ -104,6 +112,7 @@ doctest_optionflags = [ | |||
] | |||
|
|||
[tool.mypy] | |||
python_version = 3.12 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does this mean that we only allow mypy to work with python 3.12?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for python 3.9, mypy complains about modern (3.10+) type annotations on our type aliases
LossT: TypeAlias = jax.Array | dict[str, jax.Array]
numpyro/infer/elbo.py:44: error: Invalid type alias: expression is not a valid type [valid-type]
numpyro/infer/elbo.py:44: error: Unsupported left operand type for | ("type[Array]") [operator]
Found 2 errors in 1 file (checked 88 source files)
Since python 3.9 is approaching end of life, I went ahead removed this target python_version
but excluded the lint workflow from CI for the 3.9 case, which will let us use X | Y
in lieu of typing.Union[X, Y]
across the codebase without having to add inline ignores for the typechecker.
Alternatively, I could add an inline ignore or adopt use the legacy Union[...]
type hint. Whatever you prefer!
Picking up on the work of @juanitorduz, this MR adds type hints to the
numpyro.infer.elbo
module. Along the way, I've made a few housekeeping changes to try and make it easier to maintain and extend type hints going forward.The most important change warranting discussion is that I've introduced a
ParamSpec
for the parameters of a model function.ELBO
s are now generics over the signature of the model/guide functions they operate on. I think that this is a nice way to (softly) enforce consistency in the*args
and**kwargs
passed toELBO.loss
, but there may be some subtleties that I'm missing here.Summary of Changes
_compute_downstream_costs
code and associated tests (replaced in 28e38d8)numpyro.primitives
withnoqa
numpyro._typing
module for cross-module shared type aliasesjax.Array
to address typing errors and following guidance in JEP 9263numpyro.handlers
, with requisite importfrom __future__ import annotations
to support python 3.9.Previous Typing MRs (for Reference)