
Releases: pyro-ppl/numpyro

0.7.1

11 Jul 19:42
ed4e956

In the 0.7.0 release, the wheel file uploaded to PyPI contained some stale files. This release fixes that issue.

0.7.0

11 Jul 02:07
e398f19

Starting with this release, NumPyro can be installed alongside the latest jax and jaxlib releases (their version restrictions have been relaxed). In addition, NumPyro now uses the default JAX platform, so if you installed JAX with GPU/TPU support, those devices will be used by default, as sketched below.
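
The sketch below shows how to check which devices JAX exposes and, if needed, override the platform; jax.devices and numpyro.set_platform are the standard utilities for this, and the "cpu" override is only illustrative.

```python
# Minimal sketch: inspect the default JAX devices and optionally override
# the platform NumPyro runs on.
import jax
import numpyro

print(jax.devices())          # lists GPU/TPU devices if JAX was installed with support for them
numpyro.set_platform("cpu")   # optionally force a specific platform
```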

New Features

Enhancements and Bug Fixes

  • Documentation and examples have been greatly enhanced to make features more accessible
  • Fix chain detection for various CPU device strings #1077
  • Fix AutoNormal's quantiles method for models with non-scalar latent sites #1066
  • Fix LocScaleReparam with center=1 #1059
  • Enhance auto guides to support models with deterministic sites #1022
  • Support mutable states in Flax and Haiku modules #1016
  • Fix a bug in auto guides that occurs when using the guide in Predictive #1013
  • Support decorator syntax for effect handlers #1009 (see the sketch after this list)
  • Implement sparse Poisson log probability #1003
  • Support total_count=0 in the Multinomial distribution #1000
  • Add a flag to control the mass matrix regularization behavior in mass matrix adaptation #998
  • Add experimental Dockerfiles #996
  • Allow setting the max tree depth of the NUTS sampler during the warmup phase #984
  • Fix mixed-up dimensions in the ExpandedDistribution.sample method #972
  • MCMC objects can now be pickled #968
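
Below is a minimal sketch of the decorator syntax added in #1009; the model body is purely illustrative.

```python
# Minimal sketch: apply an effect handler as a decorator instead of a
# `with` block.
import numpyro
import numpyro.distributions as dist
from numpyro import handlers

@handlers.seed(rng_seed=0)           # the handler wraps the function directly ...
def model():
    return numpyro.sample("x", dist.Normal(0.0, 1.0))

x = model()                          # ... so no explicit `with handlers.seed(...)` is needed
```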

This release is built on great contributions and feedback from the Pyro community: @ahoho, @kpj, @gustavehug, @AndrewCSQ, @jatentaki, @tcbegley, @dominikstrb, @justinrporter, @dirmeier, @irustandi, @MarcoGorelli, @lumip, and many others. Thank you!

0.6.0

16 Mar 17:57
ecd6255

New Features

Enhancements and Bug Fixes

  • Improve precision for Dirichlet distributions with small concentration #943
  • Make it easy to use softplus transforms in autoguides #941
  • Improve compilation time in MCMC samplers - compilation is now about 2x faster than before #924
  • Reduce memory requirement for AutoLowRankMultivariateNormal.quantiles #921
  • Add an example of how to use Distribution.mask #917 (see the sketch after this list)
  • Add goodness of fit helpers for testing distributions #916
  • Enable sampling with intermediates for ExpandedDistribution #909
  • Fix DiscreteHMCGibbs to work with multiple chains #908
  • Fix missing infer key in handlers.lift #892
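
Below is a minimal sketch of masking out invalid observations with Distribution.mask; the model and the data and is_valid arguments are illustrative placeholders.

```python
# Minimal sketch: entries of `data` where `is_valid` is False contribute
# zero to the log density via Distribution.mask.
import numpyro
import numpyro.distributions as dist

def model(data, is_valid):
    loc = numpyro.sample("loc", dist.Normal(0.0, 1.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0).mask(is_valid), obs=data)
```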

Thanks @loopylangur, Dominik Straub @dominikstrb, Jeremie Coullon @jeremiecoullon, Ola Rønning @OlaRonning, Lukas Prediger @lumip, Raúl Peralta Lozada @RaulPL, Vitalii Kleshchevnikov @vitkl, Matt Ludkin @ludkinm, and many others for your contributions and feedback!

0.5.0

24 Jan 20:44
6a1f522

A new documentation page with galleries of tutorials and examples is available at num.pyro.ai.

New Features

  • New primitive: prng_key to draw a random key under the seed handler.
  • New autoguide: AutoDelta
  • New samplers:
    • HMCGibbs: a general HMC/NUTS-within-Gibbs interface.
    • DiscreteHMCGibbs: HMC/NUTS-within-Gibbs for models with discrete latent variables.
    • HMCECS: HMC/NUTS with energy conserving subsampling.
  • New example:
  • New kernels module in numpyro.contrib.einstein, in preparation for (Ein)Stein VI inference in future releases.
  • New user-friendly SVI.run method to simplify the training phase of SVI inference (see the sketch after this list).
  • New feasible_like method in constraints.
  • New methods forward_shape and inverse_shape in Transform to infer output shape given input shape.
  • Transform.inv now returns an inverted transform, enabling many new (inverse) transforms.
  • Support thinning in MCMC.
  • Add post_warmup_state and last_state to allow a sequential sampling strategy in MCMC: keep calling the .run method to draw more samples.
  • New history argument to support Markov models with history > 1 in scan.
  • New forward_mode_differentiation argument in HMC/NUTS kernels to allow using forward-mode differentiation.
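
Below is a minimal sketch of SVI.run on a toy model; the AutoNormal guide and optimizer settings are illustrative, and the exact fields of the returned result object may differ slightly across versions.

```python
# Minimal sketch: SVI.run packages the init/update/evaluate loop into a
# single call and returns the optimized parameters and the loss trace.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoNormal

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

guide = AutoNormal(model)
svi = SVI(model, guide, numpyro.optim.Adam(step_size=1e-3), Trace_ELBO())
data = jnp.array([1.2, 0.8, 1.1])
svi_result = svi.run(random.PRNGKey(0), 2000, data)   # run 2000 optimization steps
params, losses = svi_result.params, svi_result.losses
```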

Enhancements and Bug Fixes

  • #886 Make TransformReparam compatible with .to_event()
  • #883 Improve gradient computation of Euclidean kinetic energy.
  • #872 Enhance masked distributions so gradients propagate properly when using the mask handler for invalid data.
  • #865 Make subsample faster on CPU.
  • #860 Fix a memory leak in MCMC.
  • #849 Expose the logits attribute on some discrete distributions
  • #848 Add has_rsample and rsample attributes to distributions (see the sketch after this list)
  • #832 Allow a callable to return an init value in the param primitive
  • #824 Fix being unable to use the sample method of TFP distributions in the sample primitive.
  • #823 Demonstrate how to use various init strategies in the Gaussian Process example.
  • #822 Allow haiku/flax modules to take general args/kwargs in init.
  • #821 Better error messages when rng_key is missing.
  • #818 Better error messages when an error happens in the middle of inference.
  • #805 Display correct progress bar message after running MCMC.warmup.
  • #801 Raise an error early if plates are missing for models with discrete latent variables.
  • #797 MCMC vectorized chain method works for models with deterministic sites.
  • #796 Bernoulli distribution returns an int instead of a boolean.
  • #795 Reveal signature for help(Distribution).
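
Below is a minimal sketch of the attributes added in #848; the Normal distribution here is only an example of a reparameterizable distribution.

```python
# Minimal sketch: has_rsample flags reparameterizable distributions, and
# rsample draws a sample that is differentiable w.r.t. the parameters.
from jax import random
import numpyro.distributions as dist

d = dist.Normal(0.0, 1.0)
if d.has_rsample:
    x = d.rsample(random.PRNGKey(0))
```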

Thanks Ola Ronning @OlaRonning, Armin Stepanjan @ab-10, @cerbelaut, Xi Wang @xidulu, Wouter van Amsterdam @vanAmsterdam, @loopylangur, and many others for your contributions and helpful feedback!

0.4.1

19 Oct 21:53
42ed07f

New Features

Enhancements and Bug Fixes

  • #764 Make exception chaining more user-friendly. Thanks, @akihironitta!
  • #766 Relax interval constraint.
  • #776 Fix bugs in methods log_prob and sample of VonMises distribution.
  • #775 Make the validation mechanism compatible with omnistaging, which is enabled since JAX 0.2.
  • #780 Fix named dimensions of sample sites under contrib.funsor's plate handler.

0.4.0

30 Sep 17:30
8ce9fd3

This release adds experimental integrations with JAX-based TensorFlow Probability and the neural network libraries Flax and Haiku, along with new high-quality tutorials written by NumPyro contributors. JAX 0.2 enables "omnistaging" by default (see the JAX omnistaging guide for what omnistaging means and how to update your code if it breaks after the upgrade; you can also disable this new behavior with jax.config.disable_omnistaging()).

New Features

New Examples

Deprecation

Changes to match the Pyro API (see the sketch after this list):

  • ELBO objective is renamed to Trace_ELBO.
  • value argument in Delta distribution is replaced by v.
  • init_strategy argument in autoguides is replaced by init_loc_fn.
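
Below is a minimal before/after sketch of these renames; the toy model, the AutoDiagonalNormal guide, and the init_to_median import path are illustrative and assume the post-rename API.

```python
# Minimal sketch of the renamed names: Trace_ELBO (formerly ELBO),
# Delta's `v` (formerly `value`), and autoguides' `init_loc_fn`
# (formerly `init_strategy`).
import numpyro
import numpyro.distributions as dist
from numpyro.infer import Trace_ELBO, init_to_median
from numpyro.infer.autoguide import AutoDiagonalNormal

def model():
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    numpyro.sample("y", dist.Delta(v=x))                        # `value=...` is now `v=...`

guide = AutoDiagonalNormal(model, init_loc_fn=init_to_median)   # `init_strategy=` is now `init_loc_fn=`
elbo = Trace_ELBO()                                             # `ELBO` is now `Trace_ELBO`
```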

Enhancements and Bug Fixes

  • Relax simplex constraint. #725 #737
  • Fix the init_strategy argument not being respected in HMC and SA kernels. #728
  • Validate the model when valid initial params cannot be found. #733
  • Avoid nan acceptance probability in SA kernel. #740

Thanks @xidulu, @vanAmsterdam, @TuanNguyen27, @ucals, @elchorro, @RaulPL, and many others for your contributions and helpful feedback!

0.3.0

27 Jul 17:39
e1433ff

Breaking Changes

  • HMC's find_heuristic_step_size (this functionality is distinct from the step size adaptation scheme) is disabled by default to improve compilation time. The previous behavior can be restored by setting find_heuristic_step_size=True (see the sketch after this list).
  • The automatic reparameterization mechanism introduced in NumPyro 0.2 is removed in favor of the reparam handler. See the eight schools example for the new usage pattern.
  • The Automatic Guide Generation module is moved from numpyro.contrib.autoguide to the main inference module numpyro.infer.autoguide.
  • Various API changes to match the Pyro API.
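
Below is a minimal sketch of the new patterns: opting back in to the heuristic step size search and applying the reparam handler explicitly. The toy model mirrors the non-centered pattern from the eight schools example; all site names and values are illustrative.

```python
# Minimal sketch: explicit reparameterization via the `reparam` handler
# plus re-enabling the heuristic initial step size search on NUTS.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro import handlers
from numpyro.distributions.transforms import AffineTransform
from numpyro.infer import NUTS
from numpyro.infer.reparam import TransformReparam

def model():
    mu = numpyro.sample("mu", dist.Normal(0.0, 5.0))
    tau = numpyro.sample("tau", dist.HalfCauchy(5.0))
    # a TransformedDistribution site that TransformReparam can decompose
    numpyro.sample(
        "theta",
        dist.TransformedDistribution(
            dist.Normal(jnp.zeros(8), 1.0), AffineTransform(mu, tau)
        ),
    )

reparam_model = handlers.reparam(model, config={"theta": TransformReparam()})
kernel = NUTS(reparam_model, find_heuristic_step_size=True)  # opt back in to the heuristic search
```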

New Features

New Examples

Enhancements and Bug Fixes

  • HMC/NUTS compilation time is greatly improved, especially for large models.
  • More efficient BTRS algorithm for sampling from Binomial distribution. #537
  • Allow arbitrary order of plate statements. #555
  • Fix KeyError with scale handler and deterministic primitive. #577
  • Fix the Poisson sampler entering an infinite loop under vmap. #582
  • Fix the double compilation issue in numpyro.optim classes. #603
  • Use ExpandedDistribution in numpyro.plate. #616
  • The time series forecasting tutorial is updated to use the scan primitive and Predictive for forecasting. #608 #657
  • Tweak sparse regression example to bring the model into exact alignment with the reference. #669
  • Add MetropolisHastings algorithm as an example of MCMCKernel. #680

Thanks Nikolaos @daydreamt, Daniel Sheldon @dsheldon, Lukas Prediger @lumip, Freddy Boulton @freddyaboulton, Wouter van Amsterdam @vanAmsterdam, and many others for their contributions and helpful feedback!

0.2.4

23 Jan 19:40

New Features

New Examples

Deprecation / Breaking Changes

  • Predictive's get_samples method is deprecated in favor of the __call__ method (see the sketch after this list).
  • MCMC's constrain_fn is renamed to postprocess_fn.
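
Below is a minimal sketch of the updated Predictive usage on a toy model; argument names follow the current Predictive/MCMC API and are assumed to carry over to this release.

```python
# Minimal sketch: call the Predictive object directly instead of using
# its deprecated get_samples method.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive

def model(data=None):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.9, 1.1, 1.0])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), data)

predictive = Predictive(model, posterior_samples=mcmc.get_samples())
samples = predictive(random.PRNGKey(1))   # __call__ replaces .get_samples(...)
```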

Enhancements and Bug Fixes

  • Change the init scale of Auto*Normal guides from 1.0 to 0.1; this helps stability during the early training phase.
  • Resolve an overflow issue in the Poisson sampler.

0.2.3

05 Dec 06:42

Patches 0.2.2 with the following changes:

  • Restore compatibility with Python 3.7 for MCMC.
  • Impose a cache size limit in MCMC utilities.

0.2.2

04 Dec 20:49

Breaking changes

  • Minor interface changes to MCMC utility functions. All experimental interfaces are marked as such in the documentation.

New Features

  • A numpyro.factor primitive that adds an arbitrary log probability factor to a probabilistic model (see the sketch below).
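
Below is a minimal sketch of the factor primitive; the penalty term is purely illustrative.

```python
# Minimal sketch: numpyro.factor adds an arbitrary term to the model's
# joint log probability.
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist

def model(data):
    x = numpyro.sample("x", dist.Normal(0.0, 1.0))
    numpyro.factor("penalty", -jnp.sum((x - data) ** 2))
```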

Enhancements and Bug Fixes

  • Addressed a bug where multiple invocations of MCMC.run would wrongly use the previously cached arguments.
  • MCMC reuses compiled model code whenever possible, e.g. when re-running with different but identically sized model arguments.
  • Ability to reuse the same warmup state for subsequent MCMC runs using MCMC.warmup (see the sketch below).
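
Below is a minimal sketch of running warmup once and then sampling; the call pattern follows the later MCMC API, and the exact flags available in 0.2.2 may differ.

```python
# Minimal sketch: run the adaptation phase once with MCMC.warmup, then
# draw samples with MCMC.run, which reuses the adapted state.
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    loc = numpyro.sample("loc", dist.Normal(0.0, 10.0))
    numpyro.sample("obs", dist.Normal(loc, 1.0), obs=data)

data = jnp.array([0.9, 1.1, 1.0])
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.warmup(random.PRNGKey(0), data)   # adaptation phase only
mcmc.run(random.PRNGKey(1), data)      # sampling reuses the warmup state
```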