Allow decomposition methods in MvNormal
ricardoV94 committed Feb 13, 2025
1 parent 2823dfc commit 2aecb95
Showing 6 changed files with 113 additions and 18 deletions.
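In short: `MvNormalRV` gains a `method` prop selecting how the covariance is factored when sampling: `"cholesky"` (the default, and the previously hard-coded behavior), `"svd"`, or `"eigh"`. A rough usage sketch, constructing the Op directly as the new tests below do (any public wrapper around it may differ):

import numpy as np
from pytensor.tensor.random.basic import MvNormalRV

mean = np.zeros(3)
cov = np.eye(3)

# each Op instance carries its decomposition method as a prop
draws = MvNormalRV(method="eigh")(mean, cov, size=(1_000,))
print(draws.eval().shape)  # (1000, 3)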
15 changes: 14 additions & 1 deletion pytensor/link/jax/dispatch/random.py
@@ -128,7 +128,6 @@ def jax_sample_fn(op, node):
 @jax_sample_fn.register(ptr.BetaRV)
 @jax_sample_fn.register(ptr.DirichletRV)
 @jax_sample_fn.register(ptr.PoissonRV)
-@jax_sample_fn.register(ptr.MvNormalRV)
 def jax_sample_fn_generic(op, node):
     """Generic JAX implementation of random variables."""
     name = op.name
@@ -173,6 +172,20 @@ def sample_fn(rng, size, dtype, *parameters):
     return sample_fn


+@jax_sample_fn.register(ptr.MvNormalRV)
+def jax_sample_mvnormal(op, node):
+    def sample_fn(rng, size, dtype, mean, cov):
+        rng_key = rng["jax_state"]
+        rng_key, sampling_key = jax.random.split(rng_key, 2)
+        sample = jax.random.multivariate_normal(
+            sampling_key, mean, cov, shape=size, dtype=dtype, method=op.method
+        )
+        rng["jax_state"] = rng_key
+        return (rng, sample)
+
+    return sample_fn


 @jax_sample_fn.register(ptr.BernoulliRV)
 def jax_sample_fn_bernoulli(op, node):
     """JAX implementation of `BernoulliRV`."""
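The JAX branch simply forwards the prop to `jax.random.multivariate_normal`, which natively accepts the same three method names. A standalone sketch of the underlying JAX call, independent of PyTensor:

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
mean = jnp.zeros(3)
cov = jnp.eye(3)

# method must be one of "cholesky", "svd", or "eigh", matching op.method above
draws = jax.random.multivariate_normal(key, mean, cov, shape=(1_000,), method="svd")
print(draws.shape)  # (1000, 3)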
19 changes: 16 additions & 3 deletions pytensor/link/numba/dispatch/random.py
@@ -144,11 +144,24 @@ def random_fn(rng, p):

 @numba_core_rv_funcify.register(ptr.MvNormalRV)
 def core_MvNormalRV(op, node):
+    method = op.method
+
     @numba_basic.numba_njit
     def random_fn(rng, mean, cov):
-        chol = np.linalg.cholesky(cov)
-        stdnorm = rng.normal(size=cov.shape[-1])
-        return np.dot(chol, stdnorm) + mean
+        if method == "cholesky":
+            A = np.linalg.cholesky(cov)
+        elif method == "svd":
+            A, s, _ = np.linalg.svd(cov)
+            A *= np.sqrt(s)[None, :]
+        else:
+            w, A = np.linalg.eigh(cov)
+            A *= np.sqrt(w)[None, :]
+
+        out = rng.normal(size=cov.shape[-1])
+        # out argument not working correctly: https://github.com/numba/numba/issues/9924
+        out[:] = np.dot(A, out)
+        out += mean
+        return out
 
     random_fn.handles_out = True
     return random_fn
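All three branches construct a factor A with A @ A.T equal to cov, so that A @ z + mean with z ~ N(0, I) has the requested covariance: Cholesky gives such a factor directly, while for symmetric PSD cov the SVD and eigendecomposition give an orthogonal U with cov = U diag(s) U.T, and scaling the columns of U by sqrt(s) yields the needed A. A quick NumPy check of that identity, using the PSD covariance from the tests below:

import numpy as np

cov = np.array([[1.0, 0.5, -1.0], [0.5, 2.0, 0.0], [-1.0, 0.0, 3.0]])

factors = {"cholesky": np.linalg.cholesky(cov)}
u, s, _ = np.linalg.svd(cov)
factors["svd"] = u * np.sqrt(s)[None, :]
w, q = np.linalg.eigh(cov)
factors["eigh"] = q * np.sqrt(w)[None, :]

for name, a in factors.items():
    # every factor reproduces the covariance: A @ A.T == cov
    np.testing.assert_allclose(a @ a.T, cov, atol=1e-10)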
41 changes: 27 additions & 14 deletions pytensor/tensor/random/basic.py
@@ -1,13 +1,16 @@
import abc
import warnings
from typing import Literal

import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
from numpy.linalg import cholesky as np_cholesky
from numpy.linalg import eigh as np_eigh
from numpy.linalg import svd as np_svd

import pytensor
from pytensor.tensor import get_vector_length, specify_shape
from pytensor.tensor.basic import as_tensor_variable
from pytensor.tensor.math import sqrt
@@ -852,8 +855,17 @@ class MvNormalRV(RandomVariable):
     signature = "(n),(n,n)->(n)"
     dtype = "floatX"
     _print_name = ("MultivariateNormal", "\\operatorname{MultivariateNormal}")
+    __props__ = ("name", "signature", "dtype", "inplace", "method")
 
-    def __call__(self, mean=None, cov=None, size=None, **kwargs):
+    def __init__(self, *args, method: Literal["cholesky", "svd", "eigh"], **kwargs):
+        super().__init__(*args, **kwargs)
+        if method not in ("cholesky", "svd", "eigh"):
+            raise ValueError(
+                f"Unknown method {method}. The method must be one of 'cholesky', 'svd', or 'eigh'."
+            )
+        self.method = method
+
+    def __call__(self, mean, cov, size=None, **kwargs):
r""" "Draw samples from a multivariate normal distribution.
Signature
@@ -876,33 +888,34 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
         is specified, a single `N`-dimensional sample is returned.
         """
-        dtype = pytensor.config.floatX if self.dtype == "floatX" else self.dtype
-
-        if mean is None:
-            mean = np.array([0.0], dtype=dtype)
-        if cov is None:
-            cov = np.array([[1.0]], dtype=dtype)
         return super().__call__(mean, cov, size=size, **kwargs)

-    @classmethod
-    def rng_fn(cls, rng, mean, cov, size):
+    def rng_fn(self, rng, mean, cov, size):
         if size is None:
             size = np_broadcast_shapes(mean.shape[:-1], cov.shape[:-2])
 
-        chol = np_cholesky(cov)
+        if self.method == "cholesky":
+            A = np_cholesky(cov)
+        elif self.method == "svd":
+            A, s, _ = np_svd(cov)
+            A *= np_sqrt(s, out=s)[..., None, :]
+        else:
+            w, A = np_eigh(cov)
+            A *= np_sqrt(w, out=w)[..., None, :]
 
         out = rng.normal(size=(*size, mean.shape[-1]))
         np_einsum(
             "...ij,...j->...i",  # numpy doesn't have a batch matrix-vector product
-            chol,
+            A,
             out,
-            out=out,
             optimize=False,  # Nothing to optimize with two operands, skip costly setup
+            out=out,
         )
         out += mean
         return out


-multivariate_normal = MvNormalRV()
+multivariate_normal = MvNormalRV(method="cholesky")


class DirichletRV(RandomVariable):
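The practical payoff of the non-default methods: Cholesky requires a positive definite covariance, whereas the SVD and eigendecomposition paths also cope with positive semi-definite (singular) matrices, which the new test below exercises. A minimal sketch, assuming the default backend:

import numpy as np
from pytensor.tensor.random.basic import MvNormalRV

mean = np.zeros(2)
cov = np.array([[1.0, 1.0], [1.0, 1.0]])  # rank-1: PSD but singular

draws = MvNormalRV(method="svd")(mean, cov, size=(10_000,)).eval()
# method="cholesky" would raise np.linalg.LinAlgError on this covariance
print(np.cov(draws, rowvar=False))  # close to [[1, 1], [1, 1]]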
6 changes: 6 additions & 0 deletions tests/link/jax/test_random.py
@@ -18,6 +18,7 @@
     batched_permutation_tester,
     batched_unweighted_choice_without_replacement_tester,
     batched_weighted_choice_without_replacement_tester,
+    create_mvnormal_cov_decomposition_method_test,
 )


@@ -547,6 +548,11 @@ def test_random_mvnormal():
     np.testing.assert_allclose(samples.mean(axis=0), mu, atol=0.1)


+test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
+    "JAX"
+)


 @pytest.mark.parametrize(
     "parameter, size",
     [
6 changes: 6 additions & 0 deletions tests/link/numba/test_random.py
@@ -22,6 +22,7 @@
     batched_permutation_tester,
     batched_unweighted_choice_without_replacement_tester,
     batched_weighted_choice_without_replacement_tester,
+    create_mvnormal_cov_decomposition_method_test,
 )


@@ -147,6 +148,11 @@ def test_multivariate_normal():
     )


+test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
+    "NUMBA"
+)


 @pytest.mark.parametrize(
     "rv_op, dist_args, size",
     [
44 changes: 44 additions & 0 deletions tests/tensor/random/test_basic.py
@@ -19,6 +19,7 @@
 from pytensor.tensor import ones, stack
 from pytensor.tensor.random.basic import (
     ChoiceWithoutReplacement,
+    MvNormalRV,
     PermutationRV,
     _gamma,
     bernoulli,
@@ -686,6 +687,49 @@ def test_mvnormal_ShapeFeature():
     assert s4.get_test_value() == 3


+def create_mvnormal_cov_decomposition_method_test(mode):
+    @pytest.mark.parametrize("psd", (True, False))
+    @pytest.mark.parametrize("method", ("cholesky", "svd", "eigh"))
+    def test_mvnormal_cov_decomposition_method(method, psd):
+        mean = 2 ** np.arange(3)
+        if psd:
+            cov = [
+                [1, 0.5, -1],
+                [0.5, 2, 0],
+                [-1, 0, 3],
+            ]
+        else:
+            cov = [
+                [1, 0.5, 0],
+                [0.5, 2, 0],
+                [0, 0, 0],
+            ]
+        rng = shared(np.random.default_rng(675))
+        draws = MvNormalRV(method=method)(mean, cov, rng=rng, size=(10_000,))
+        assert draws.owner.op.method == method
+
+        if not psd and method == "cholesky":
+            if mode == "JAX":
+                # JAX doesn't raise errors at runtime; it returns nan instead
+                assert np.isnan(draws.eval(mode=mode)).all()
+            else:
+                with pytest.raises(np.linalg.LinAlgError):
+                    draws.eval(mode=mode)
+        else:
+            draws_eval = draws.eval(mode=mode)
+            np.testing.assert_allclose(np.mean(draws_eval, axis=0), mean, rtol=0.02)
+            np.testing.assert_allclose(np.cov(draws_eval, rowvar=False), cov, atol=0.1)
+
+    return test_mvnormal_cov_decomposition_method


+test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_method_test(
+    None
+)


 @pytest.mark.parametrize(
     "alphas, size",
     [
