From 80b1e7a7910eaa27c82850ef672206dc5510de7c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Mon, 27 Jan 2025 16:29:03 -0500 Subject: [PATCH 01/19] implement sizesof for named and positional distributions --- effectful/handlers/pyro.py | 23 +++++++++++++++++++++++ effectful/handlers/torch.py | 20 +++++++++++++++----- tests/test_handlers_pyro.py | 34 ++++++++++++++++++++++++++++++++++ 3 files changed, 72 insertions(+), 5 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index d24a5cfb..beb5abd2 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -342,6 +342,9 @@ def log_prob(self, value): def enumerate_support(self, expand=True): return self._to_positional(self.base_dist.enumerate_support(expand)) + def _sizesof(self): + return {} + class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution): """A distribution wrapper that lazily names leftmost dimensions.""" @@ -437,6 +440,26 @@ def log_prob(self, value): def enumerate_support(self, expand=True): return self._to_named(self.base_dist.enumerate_support(expand)) + def _sizesof(self): + base_shape = self.base_dist.shape() + return { + v: base_shape[d] for (v, d) in self.naming.name_to_dim.items() + } | sizesof(self.base_dist) + + +@sizesof.register(PositionalDistribution) +def _sizesof_positional_distribution( + value: PositionalDistribution, +) -> Mapping[Operation[[], int], int]: + return value._sizesof() + + +@sizesof.register(NamedDistribution) +def _sizesof_named_distribution( + value: NamedDistribution, +) -> Mapping[Operation[[], int], int]: + return value._sizesof() + def pyro_module_shim( module: type[pyro.nn.module.PyroModule], diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 81eabf68..15385a73 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -75,7 +75,8 @@ def _getitem_ellipsis_and_none( return torch.reshape(x, new_shape), new_key -def sizesof(value: Expr) -> Mapping[Operation[[], int], int]: +@functools.singledispatch +def sizesof(value) -> Mapping[Operation[[], int], int]: """Return the sizes of named dimensions in a tensor expression. Sizes are inferred from the tensor shape. @@ -89,11 +90,11 @@ def sizesof(value: Expr) -> Mapping[Operation[[], int], int]: >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()]) {a: 2, b: 3} """ - if isinstance(value, torch.distributions.Distribution) and not isinstance( - value, Term - ): - return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()} + return {} + +@sizesof.register(Term) +def _sizesof_term(value: Term) -> Mapping[Operation[[], int], int]: sizes: dict[Operation[[], int], int] = {} def _torch_getitem_sizeof( @@ -128,6 +129,15 @@ def _torch_getitem_sizeof( return sizes +@sizesof.register(torch.distributions.Distribution) +def _sizesof_distribution( + value: torch.distributions.Distribution, +) -> Mapping[Operation[[], int], int]: + if isinstance(value, Term): + return _sizesof_term(value) + return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()} + + def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T: """Partially evaluate a term with respect to its sized free variables. 
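[Aside on the dispatch pattern introduced above, not part of the patch itself: converting `sizesof` to `functools.singledispatch` means the implementation is selected by the runtime type of the argument, with the generic function serving as the fallback for plain values. A minimal self-contained sketch of the same mechanism, using hypothetical stand-ins (`sizes`, `Leaf`, `Node`) rather than effectful's `Term` and `torch.distributions.Distribution`:

import functools


@functools.singledispatch
def sizes(value) -> dict:
    # Fallback case: plain values carry no named dimensions.
    return {}


class Leaf:
    def __init__(self, name: str, size: int):
        self.name, self.size = name, size


class Node:
    def __init__(self, *children):
        self.children = children


@sizes.register(Leaf)
def _sizes_leaf(value: Leaf) -> dict:
    return {value.name: value.size}


@sizes.register(Node)
def _sizes_node(value: Node) -> dict:
    # Merge the sizes of all children, the way the Distribution overload
    # in the patch merges sizes over the distribution's parameters.
    out: dict = {}
    for child in value.children:
        out.update(sizes(child))
    return out


assert sizes(42) == {}
assert sizes(Node(Leaf("a", 2), Leaf("b", 3))) == {"a": 2, "b": 3}

Registering overloads per type, as the patch does for `Term` and `Distribution`, keeps the base `sizesof` free of isinstance chains.]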
diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index 783d8b68..bfdc5432 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -238,3 +238,37 @@ def test_simple_distribution(): dist.Beta(t, t, validate_args=False) dist.Bernoulli(t, validate_args=False) + + +def test_sizesof_named_distribution(): + # Create base distribution with known batch shape + base_dist = dist.Normal(torch.zeros(3, 4, 5), torch.ones(3, 4, 5)) + + # Create names for the first two dimensions + dim0 = defop(int, name="dim0") + dim1 = defop(int, name="dim1") + names = [dim0, dim1] + + # Create named distribution + named_dist = NamedDistribution(base_dist, names) + + # Get sizes + sizes = sizesof(named_dist) + + # Check that the sizes match expected values + assert sizes[dim0] == 3 + assert sizes[dim1] == 4 + assert len(sizes) == 2 # Only named dimensions should be included + + +def test_sizesof_positional_distribution(): + dim0 = defop(int, name="dim0") + dim1 = defop(int, name="dim1") + + mean = Indexable(torch.zeros(3, 4, 5))[dim0(), dim1()] + var = Indexable(torch.ones(3, 4, 5))[dim0(), dim1()] + base_dist = dist.Normal(mean, var) + + pos_dist = PositionalDistribution(base_dist) + + assert sizesof(pos_dist) == {} From f93d5488cd608710ebbc54e252b8da9359b6a565 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 28 Jan 2025 15:38:06 -0500 Subject: [PATCH 02/19] first pass at replacing dist wrappers with term subclasses --- effectful/handlers/indexed.py | 10 +- effectful/handlers/pyro.py | 203 ++++++++++++++++++++-------------- effectful/handlers/torch.py | 14 --- tests/test_handlers_pyro.py | 14 +-- 4 files changed, 133 insertions(+), 108 deletions(-) diff --git a/effectful/handlers/indexed.py b/effectful/handlers/indexed.py index e97f0726..598b5174 100644 --- a/effectful/handlers/indexed.py +++ b/effectful/handlers/indexed.py @@ -32,13 +32,13 @@ class IndexSet(Dict[str, Set[int]]): for which a value is defined:: >>> IndexSet(x={0, 1}, y={2, 3}) - IndexSet({x: {0, 1}, y: {2, 3}}) + IndexSet({'x': {0, 1}, 'y': {2, 3}}) :class:`IndexSet` 's constructor will automatically drop empty entries and attempt to convert input values to :class:`set` s:: >>> IndexSet(x=[0, 0, 1], y=set(), z=2) - IndexSet({x: {0, 1}, z: {2}}) + IndexSet({'x': {0, 1}, 'z': {2}}) :class:`IndexSet` s are also hashable and can be used as keys in :class:`dict` s:: @@ -263,8 +263,10 @@ def cond(fst: torch.Tensor, snd: torch.Tensor, case_: torch.Tensor) -> torch.Ten Unlike a Python conditional expression, however, the case may be a tensor, and both branches are evaluated, as with :func:`torch.where` :: - >>> from effectful.internals.sugar import gensym - >>> b = gensym(int, name="b") + >>> from effectful.ops.syntax import defop + >>> from effectful.handlers.torch import to_tensor + + >>> b = defop(int, name="b") >>> fst, snd = Indexable(torch.randn(2, 3))[b()], Indexable(torch.randn(2, 3))[b()] >>> case = (fst < snd).all(-1) >>> x = cond(fst, snd, case) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index beb5abd2..b89b4188 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -1,12 +1,17 @@ import typing import warnings -from typing import Any, Collection, List, Mapping, Optional, Tuple +from typing import Any, Collection, List, Mapping, Optional, Tuple, Annotated, TypeVar try: import pyro except ImportError: raise ImportError("Pyro is required to use effectful.handlers.pyro.") +from pyro.distributions.torch_distribution import ( + TorchDistributionMixin, + 
TorchDistribution, +) + try: import torch except ImportError: @@ -15,16 +20,18 @@ from typing_extensions import ParamSpec from effectful.handlers.torch import Indexable, sizesof, to_tensor -from effectful.ops.syntax import defop -from effectful.ops.types import Operation +from effectful.ops.syntax import defop, Scoped, defterm, defdata +from effectful.ops.types import Operation, Term P = ParamSpec("P") +A = TypeVar("A") +B = TypeVar("B") @defop def pyro_sample( name: str, - fn: pyro.distributions.torch_distribution.TorchDistributionMixin, + fn: TorchDistributionMixin, *args, obs: Optional[torch.Tensor] = None, obs_mask: Optional[torch.BoolTensor] = None, @@ -124,9 +131,7 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: assert msg["type"] == "sample" assert msg["name"] is not None assert msg["infer"] is not None - assert isinstance( - msg["fn"], pyro.distributions.torch_distribution.TorchDistributionMixin - ) + assert isinstance(msg["fn"], TorchDistributionMixin) if pyro.poutine.util.site_is_subsample(msg) or pyro.poutine.util.site_is_factor( msg @@ -142,7 +147,7 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: # pdist shape: | named1 | batch_shape | event_shape | # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1 - pdist = PositionalDistribution(dist) + pdist = positional_distribution(dist) naming = pdist.naming if msg["mask"] is None: @@ -209,9 +214,7 @@ def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: if value is not None: # note: is it safe to assume that msg['fn'] is a distribution? - assert isinstance( - msg["fn"], pyro.distributions.torch_distribution.TorchDistribution - ) + assert isinstance(msg["fn"], TorchDistribution) dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape if len(value.shape) < len(dist_shape): value = value.broadcast_to( @@ -252,25 +255,103 @@ def __repr__(self): return f"Naming({self.name_to_dim})" -class PositionalDistribution(pyro.distributions.torch_distribution.TorchDistribution): +@defop +def named_distribution( + dist: Annotated[TorchDistribution, Scoped[A]], + *names: Annotated[Operation[[], int], Scoped[B]], +) -> Annotated[TorchDistribution, Scoped[A | B]]: + raise NotImplementedError + + +@defop +def positional_distribution( + dist: Annotated[TorchDistribution, Scoped[A]] +) -> TorchDistribution: + raise NotImplementedError + + +@Term.register +class _DistributionTerm(TorchDistribution): + """A distribution wrapper that satisfies the Term interface.""" + + op: Operation[..., TorchDistribution] = defop(TorchDistribution) + args: Tuple[torch.Tensor, ...] 
+ kwargs: Mapping[str, object] = {} + + __match_args__ = ("op", "args", "kwargs") + + def __init__(self, base_dist: TorchDistribution): + self.args = tuple(base_dist.__dict__.values()) + self.base_dist = base_dist + + def __getitem__(self, key: Collection[Operation[[], int]]): + return named_distribution(self, *key) + + @property + def has_rsample(self): + return self.base_dist.has_rsample + + @property + def batch_shape(self): + return self.base_dist.batch_shape + + @property + def event_shape(self): + return self.base_dist.event_shape + + @property + def has_enumerate_support(self): + return self.base_dist.has_enumerate_support + + @property + def arg_constraints(self): + return self.base_dist.arg_constraints + + @property + def support(self): + return self.base_dist.support + + def sample(self, sample_shape=torch.Size()): + return self.base_dist.sample(sample_shape) + + def rsample(self, sample_shape=torch.Size()): + return self.base_dist.rsample(sample_shape) + + def log_prob(self, value): + return self.base_dist.log_prob(value) + + def enumerate_support(self, expand=True): + return self.base_dist.enumerate_support(expand) + + +@defterm.register(TorchDistribution) +def _embed_dist(dist: TorchDistribution) -> Term[TorchDistribution]: + return _DistributionTerm(dist) + + +class _PositionalDistributionTerm(_DistributionTerm): """A distribution wrapper that lazily converts indexed dimensions to positional. """ + op: Operation[..., TorchDistribution] = positional_distribution + args: Tuple[TorchDistribution] + kwargs: Mapping[str, object] = {} + + __match_args__ = ("op", "args", "kwargs") + indices: Mapping[Operation[[], int], int] - def __init__( - self, base_dist: pyro.distributions.torch_distribution.TorchDistribution - ): - self.base_dist = base_dist + def __init__(self, base_dist: TorchDistribution): + super().__init__(base_dist) + + self.args = (base_dist,) self.indices = sizesof(base_dist) n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) self.naming = Naming.from_shape(self.indices.keys(), n_base) - super().__init__() - def _to_positional(self, value: torch.Tensor) -> torch.Tensor: # self.base_dist has shape: | batch_shape | event_shape | & named # assume value comes from base_dist with shape: @@ -325,9 +406,6 @@ def arg_constraints(self): def support(self): return self.base_dist.support - def __repr__(self): - return f"PositionalDistribution({self.base_dist})" - def sample(self, sample_shape=torch.Size()): return self._to_positional(self.base_dist.sample(sample_shape)) @@ -342,18 +420,17 @@ def log_prob(self, value): def enumerate_support(self, expand=True): return self._to_positional(self.base_dist.enumerate_support(expand)) - def _sizesof(self): - return {} - -class NamedDistribution(pyro.distributions.torch_distribution.TorchDistribution): +class _NamedDistributionTerm(_DistributionTerm): """A distribution wrapper that lazily names leftmost dimensions.""" - def __init__( - self, - base_dist: pyro.distributions.torch_distribution.TorchDistribution, - names: Collection[Operation[[], int]], - ): + op: Operation[..., TorchDistribution] = named_distribution + args: tuple + kwargs: Mapping[str, object] = {} + + __match_args__ = ("op", "args", "kwargs") + + def __init__(self, base_dist: TorchDistribution, *names: Operation[[], int]): """ :param base_dist: A distribution with batch dimensions. 
@@ -391,74 +468,34 @@ def _from_named(self, value: torch.Tensor) -> torch.Tensor: return pos_tensor_r - @property - def has_rsample(self): - return self.base_dist.has_rsample - @property def batch_shape(self): - return self.base_dist.batch_shape[len(self.names) :] - - @property - def event_shape(self): - return self.base_dist.event_shape - - @property - def has_enumerate_support(self): - return self.base_dist.has_enumerate_support - - @property - def arg_constraints(self): - return self.base_dist.arg_constraints - - @property - def support(self): - return self.base_dist.support - - def __repr__(self): - return f"NamedDistribution({self.base_dist}, {self.names})" + return super().batch_shape[len(self.names) :] def sample(self, sample_shape=torch.Size()): - t = self._to_named( - self.base_dist.sample(sample_shape), offset=len(sample_shape) - ) + t = self._to_named(super().sample(sample_shape), offset=len(sample_shape)) assert set(sizesof(t).keys()) == set(self.names) assert t.shape == self.shape() + sample_shape return t def rsample(self, sample_shape=torch.Size()): - return self._to_named( - self.base_dist.rsample(sample_shape), offset=len(sample_shape) - ) + return self._to_named(super().rsample(sample_shape), offset=len(sample_shape)) def log_prob(self, value): - v1 = self._from_named(value) - v2 = self.base_dist.log_prob(v1) - v3 = self._to_named(v2) - return v3 + return self._to_named(super().log_prob(self._from_named(value))) def enumerate_support(self, expand=True): - return self._to_named(self.base_dist.enumerate_support(expand)) - - def _sizesof(self): - base_shape = self.base_dist.shape() - return { - v: base_shape[d] for (v, d) in self.naming.name_to_dim.items() - } | sizesof(self.base_dist) - - -@sizesof.register(PositionalDistribution) -def _sizesof_positional_distribution( - value: PositionalDistribution, -) -> Mapping[Operation[[], int], int]: - return value._sizesof() + return self._to_named(super().enumerate_support(expand)) -@sizesof.register(NamedDistribution) -def _sizesof_named_distribution( - value: NamedDistribution, -) -> Mapping[Operation[[], int], int]: - return value._sizesof() +@defdata.register(TorchDistribution) +def _(op, *args, **kwargs): + if op is named_distribution: + return _NamedDistributionTerm(*args, **kwargs) + elif op is positional_distribution: + return _PositionalDistributionTerm(*args, **kwargs) + else: + Term(op, *args, **kwargs) def pyro_module_shim( diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 15385a73..50292ec6 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -90,11 +90,6 @@ def sizesof(value) -> Mapping[Operation[[], int], int]: >>> sizesof(Indexable(torch.ones(2, 3))[a(), b()]) {a: 2, b: 3} """ - return {} - - -@sizesof.register(Term) -def _sizesof_term(value: Term) -> Mapping[Operation[[], int], int]: sizes: dict[Operation[[], int], int] = {} def _torch_getitem_sizeof( @@ -129,15 +124,6 @@ def _torch_getitem_sizeof( return sizes -@sizesof.register(torch.distributions.Distribution) -def _sizesof_distribution( - value: torch.distributions.Distribution, -) -> Mapping[Operation[[], int], int]: - if isinstance(value, Term): - return _sizesof_term(value) - return {v: s for a in value.__dict__.values() for v, s in sizesof(a).items()} - - def _partial_eval(t: T, order: Optional[Sequence[Operation[[], int]]] = None) -> T: """Partially evaluate a term with respect to its sized free variables. 
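[Aside on the Term interface these wrapper classes now satisfy (illustrative sketch only; `Box` and `scale` below are made-up stand-ins, not effectful or Pyro APIs): exposing `op`, `args`, and `kwargs` together with `__match_args__` is what lets client code destructure a wrapped distribution with structural pattern matching, as later commits in this series do with `case Term(op=_call, args=(dist_constr, *args))`.

class Box:
    # Mirrors the (op, args, kwargs) surface of effectful's Term protocol.
    __match_args__ = ("op", "args", "kwargs")

    def __init__(self, op, *args, **kwargs):
        self.op, self.args, self.kwargs = op, args, kwargs


def scale(x: int, factor: int = 1) -> int:
    # Stand-in for an effectful Operation.
    return x * factor


term = Box(scale, 3, factor=2)

match term:
    case Box(op=f, args=(x,), kwargs=kw):
        # Destructure the term and replay the call it represents.
        assert f(x, **kw) == 6

Because the match protocol only needs those three attributes, a class can participate in term rewriting without inheriting from a concrete Term base, which is the property the `@Term.register` / virtual-subclass approach above relies on.]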
diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index bfdc5432..95e10da3 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -8,8 +8,8 @@ import torch from effectful.handlers.pyro import ( - NamedDistribution, - PositionalDistribution, + named_distribution, + positional_distribution, PyroShim, pyro_sample, ) @@ -179,7 +179,7 @@ def _pyro_sample(self, msg): def test_named_dist(): x, y = defop(int, name="x"), defop(int, name="y") - d = NamedDistribution(dist.Normal(0.0, 1.0).expand((2, 3)), [x, y]) + d = named_distribution(dist.Normal(0.0, 1.0).expand((2, 3)), [x, y]) expected_indices = {x: 2, y: 3} @@ -203,7 +203,7 @@ def test_positional_dist(): expected_indices = {x: 2, y: 3} - d = PositionalDistribution(dist.Normal(loc, scale)) + d = positional_distribution(dist.Normal(loc, scale)) assert d.shape() == torch.Size([2, 3]) @@ -224,7 +224,7 @@ def test_positional_dist(): loc = Indexable(torch.tensor(0.0).expand((2, 3, 4, 5)))[x(), y()] scale = Indexable(torch.tensor(1.0).expand((2, 3, 4, 5)))[x(), y()] - d = PositionalDistribution(dist.Normal(loc, scale)) + d = positional_distribution(dist.Normal(loc, scale)) assert sizesof(d._from_positional(d.sample((6, 7)))) == expected_indices assert d.sample().shape == torch.Size([2, 3, 4, 5]) @@ -250,7 +250,7 @@ def test_sizesof_named_distribution(): names = [dim0, dim1] # Create named distribution - named_dist = NamedDistribution(base_dist, names) + named_dist = named_distribution(base_dist, names) # Get sizes sizes = sizesof(named_dist) @@ -269,6 +269,6 @@ def test_sizesof_positional_distribution(): var = Indexable(torch.ones(3, 4, 5))[dim0(), dim1()] base_dist = dist.Normal(mean, var) - pos_dist = PositionalDistribution(base_dist) + pos_dist = positional_distribution(base_dist) assert sizesof(pos_dist) == {} From 68a96e4b3da2446421d0ca415cd2530ed1edea6d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Tue, 28 Jan 2025 17:47:33 -0500 Subject: [PATCH 03/19] wip --- effectful/handlers/pyro.py | 145 ++++++++++++++++++++---------------- effectful/handlers/torch.py | 14 ++-- effectful/ops/semantics.py | 4 +- effectful/ops/syntax.py | 4 +- tests/test_handlers_pyro.py | 4 +- 5 files changed, 95 insertions(+), 76 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index b89b4188..493aa121 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -1,6 +1,18 @@ +from abc import ABC, abstractmethod +import functools import typing import warnings -from typing import Any, Collection, List, Mapping, Optional, Tuple, Annotated, TypeVar +from typing import ( + Any, + Collection, + List, + Mapping, + Optional, + Tuple, + Annotated, + TypeVar, + Callable, +) try: import pyro @@ -21,6 +33,7 @@ from effectful.handlers.torch import Indexable, sizesof, to_tensor from effectful.ops.syntax import defop, Scoped, defterm, defdata +from effectful.ops.semantics import call from effectful.ops.types import Operation, Term P = ParamSpec("P") @@ -270,66 +283,89 @@ def positional_distribution( raise NotImplementedError -@Term.register -class _DistributionTerm(TorchDistribution): - """A distribution wrapper that satisfies the Term interface.""" - - op: Operation[..., TorchDistribution] = defop(TorchDistribution) - args: Tuple[torch.Tensor, ...] 
- kwargs: Mapping[str, object] = {} - - __match_args__ = ("op", "args", "kwargs") - - def __init__(self, base_dist: TorchDistribution): - self.args = tuple(base_dist.__dict__.values()) - self.base_dist = base_dist - - def __getitem__(self, key: Collection[Operation[[], int]]): - return named_distribution(self, *key) +class _TorchDistributionWrapperMixin(ABC): + @property + @abstractmethod + def _base_dist(self): + pass @property def has_rsample(self): - return self.base_dist.has_rsample + return self._base_dist.has_rsample @property def batch_shape(self): - return self.base_dist.batch_shape + return self._base_dist.batch_shape @property def event_shape(self): - return self.base_dist.event_shape + return self._base_dist.event_shape @property def has_enumerate_support(self): - return self.base_dist.has_enumerate_support + return self._base_dist.has_enumerate_support @property def arg_constraints(self): - return self.base_dist.arg_constraints + return self._base_dist.arg_constraints @property def support(self): - return self.base_dist.support + return self._base_dist.support def sample(self, sample_shape=torch.Size()): - return self.base_dist.sample(sample_shape) + return self._base_dist.sample(sample_shape) def rsample(self, sample_shape=torch.Size()): - return self.base_dist.rsample(sample_shape) + return self._base_dist.rsample(sample_shape) def log_prob(self, value): - return self.base_dist.log_prob(value) + return self._base_dist.log_prob(value) def enumerate_support(self, expand=True): - return self.base_dist.enumerate_support(expand) + return self._base_dist.enumerate_support(expand) + + def __getitem__(self, key: Collection[Operation[[], int]]): + return named_distribution(self, *key) + + +@functools.cache +def _register_dist_constr(dist_constr: Callable[P, TorchDistribution]): + return defop(dist_constr) + + +@Term.register +class _DistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): + """A distribution wrapper that satisfies the Term interface. + + Represented as a term of the form call(D, *args, **kwargs) where D is the + distribution constructor. + + """ + + op: Operation = call + args: tuple + kwargs: Mapping[str, Any] = {} + + __match_args__ = ("op", "args", "kwargs") + + def __init__(self, base_dist: TorchDistribution): + self.args = (_register_dist_constr(type(base_dist)),) + tuple( + base_dist.__dict__.values() + ) + + @property + def _base_dist(self): + return self.args[0](*self.args[1:]) @defterm.register(TorchDistribution) +@defterm.register(TorchDistributionMixin) def _embed_dist(dist: TorchDistribution) -> Term[TorchDistribution]: return _DistributionTerm(dist) -class _PositionalDistributionTerm(_DistributionTerm): +class _PositionalDistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): """A distribution wrapper that lazily converts indexed dimensions to positional. 
@@ -344,23 +380,25 @@ class _PositionalDistributionTerm(_DistributionTerm): indices: Mapping[Operation[[], int], int] def __init__(self, base_dist: TorchDistribution): - super().__init__(base_dist) - self.args = (base_dist,) self.indices = sizesof(base_dist) n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) self.naming = Naming.from_shape(self.indices.keys(), n_base) + @property + def _base_dist(self): + return self.args[0] + def _to_positional(self, value: torch.Tensor) -> torch.Tensor: - # self.base_dist has shape: | batch_shape | event_shape | & named + # self._base_dist has shape: | batch_shape | event_shape | & named # assume value comes from base_dist with shape: # | sample_shape | batch_shape | event_shape | & named # return a tensor of shape | sample_shape | named | batch_shape | event_shape | n_named = len(self.indices) dims = list(range(n_named + len(value.shape))) - n_base = len(self.event_shape) + len(self.base_dist.batch_shape) + n_base = len(self.event_shape) + len(self._base_dist.batch_shape) n_sample = len(value.shape) - n_base base_dims = dims[len(dims) - n_base :] @@ -380,48 +418,24 @@ def _from_positional(self, value: torch.Tensor) -> torch.Tensor: # maximal value shape: | sample_shape | named | batch_shape | event_shape | return self.naming.apply(value) - @property - def has_rsample(self): - return self.base_dist.has_rsample - @property def batch_shape(self): - return ( - torch.Size([s for s in self.indices.values()]) + self.base_dist.batch_shape - ) - - @property - def event_shape(self): - return self.base_dist.event_shape - - @property - def has_enumerate_support(self): - return self.base_dist.has_enumerate_support - - @property - def arg_constraints(self): - return self.base_dist.arg_constraints - - @property - def support(self): - return self.base_dist.support + return torch.Size([s for s in self.indices.values()]) + super().batch_shape def sample(self, sample_shape=torch.Size()): - return self._to_positional(self.base_dist.sample(sample_shape)) + return self._to_positional(super().sample(sample_shape)) def rsample(self, sample_shape=torch.Size()): - return self._to_positional(self.base_dist.rsample(sample_shape)) + return self._to_positional(super().rsample(sample_shape)) def log_prob(self, value): - return self._to_positional( - self.base_dist.log_prob(self._from_positional(value)) - ) + return self._to_positional(super().log_prob(self._from_positional(value))) def enumerate_support(self, expand=True): - return self._to_positional(self.base_dist.enumerate_support(expand)) + return self._to_positional(super().enumerate_support(expand)) -class _NamedDistributionTerm(_DistributionTerm): +class _NamedDistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): """A distribution wrapper that lazily names leftmost dimensions.""" op: Operation[..., TorchDistribution] = named_distribution @@ -437,7 +451,7 @@ def __init__(self, base_dist: TorchDistribution, *names: Operation[[], int]): :param names: A list of names. 
""" - self.base_dist = base_dist + self.args = (base_dist,) + tuple(names) self.names = names assert 1 <= len(names) <= len(base_dist.batch_shape) @@ -446,7 +460,10 @@ def __init__(self, base_dist: TorchDistribution, *names: Operation[[], int]): n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) self.naming = Naming.from_shape(names, n_base - len(names)) - super().__init__() + + @property + def _base_dist(self): + return self.args[0] def _to_named(self, value: torch.Tensor, offset=0) -> torch.Tensor: return self.naming.apply(value) diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 0064666f..6f783c08 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -23,7 +23,7 @@ import effectful.handlers.numbers # noqa: F401 from effectful.internals.runtime import interpreter from effectful.ops.semantics import apply, evaluate, fvsof, typeof -from effectful.ops.syntax import defdata, defop +from effectful.ops.syntax import defdata, defop, defterm from effectful.ops.types import Expr, Operation, Term P = ParamSpec("P") @@ -122,12 +122,12 @@ def _torch_getitem_sizeof( return defdata(torch_getitem, x, key) - with interpreter( - { - torch_getitem: _torch_getitem_sizeof, - apply: lambda _, op, *a, **k: defdata(op, *a, **k), - } - ): + def _apply(_, op, *args, **kwargs): + args, kwargs = tree.map_structure(defterm, (args, kwargs)) + return defdata(op, *args, **kwargs) + + value = defterm(value) + with interpreter({torch_getitem: _torch_getitem_sizeof, apply: _apply}): evaluate(value) return sizes diff --git a/effectful/ops/semantics.py b/effectful/ops/semantics.py index 1d429daa..5ad1e131 100644 --- a/effectful/ops/semantics.py +++ b/effectful/ops/semantics.py @@ -37,7 +37,9 @@ def apply(intp: Interpretation, op: Operation, *args, **kwargs) -> Any: By installing an :func:`apply` handler, we capture the term instead: - >>> with handler({apply: lambda _, op, *args, **kwargs: op.__free_rule__(*args, **kwargs) }): + >>> def default(*args, **kwargs): + ... raise NotImplementedError + >>> with handler({apply: default }): ... term = mul(add(1, 2), 3) >>> term mul(add(1, 2), 3) diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 3bf6eec4..19196c8f 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -89,7 +89,7 @@ class Scoped(Annotation): >>> @defop ... def LambdaN( - ... body: Annotated[T, Scoped[A | B]] + ... body: Annotated[T, Scoped[A | B]], ... *args: Annotated[Operation[[], S], Scoped[A]], ... **kwargs: Annotated[Operation[[], S], Scoped[A]] ... 
) -> Annotated[Callable[..., T], Scoped[B]]: @@ -280,7 +280,7 @@ def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature: # pre-conditions assert cls._check_has_single_scope(sig) assert cls._check_no_typevar_overlap(sig) - assert cls._check_no_boundvars_in_result(sig) + # assert cls._check_no_boundvars_in_result(sig) root_ordinal = cls._get_root_ordinal(sig) if not root_ordinal: diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index 95e10da3..d626e6c6 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -179,7 +179,7 @@ def _pyro_sample(self, msg): def test_named_dist(): x, y = defop(int, name="x"), defop(int, name="y") - d = named_distribution(dist.Normal(0.0, 1.0).expand((2, 3)), [x, y]) + d = named_distribution(dist.Normal(0.0, 1.0).expand((2, 3)), x, y) expected_indices = {x: 2, y: 3} @@ -250,7 +250,7 @@ def test_sizesof_named_distribution(): names = [dim0, dim1] # Create named distribution - named_dist = named_distribution(base_dist, names) + named_dist = named_distribution(base_dist, *names) # Get sizes sizes = sizesof(named_dist) From 330ac0df2d55ec5435fbd1b8010c57a34b8fe845 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 29 Jan 2025 11:03:49 -0500 Subject: [PATCH 04/19] switch to argument rebuilding approach --- docs/source/named_tensor_notation.ipynb | 4 +- effectful/handlers/pyro.py | 339 ++++++++++++++---------- tests/test_handlers_indexed.py | 2 + tests/test_handlers_pyro.py | 12 +- 4 files changed, 204 insertions(+), 153 deletions(-) diff --git a/docs/source/named_tensor_notation.ipynb b/docs/source/named_tensor_notation.ipynb index f7995287..d4bffc38 100644 --- a/docs/source/named_tensor_notation.ipynb +++ b/docs/source/named_tensor_notation.ipynb @@ -1181,7 +1181,9 @@ ], "source": [ "P = Indexable(\n", - " torch.tensor([[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]),\n", + " torch.tensor(\n", + " [[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]\n", + " ),\n", ")[vocab(), seq()]\n", "\n", "subst(P, {vocab: I0})" diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 493aa121..f3fbd276 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -1,17 +1,16 @@ -from abc import ABC, abstractmethod import functools import typing import warnings from typing import ( + Annotated, Any, Collection, List, Mapping, Optional, Tuple, - Annotated, + Type, TypeVar, - Callable, ) try: @@ -19,9 +18,10 @@ except ImportError: raise ImportError("Pyro is required to use effectful.handlers.pyro.") +import pyro.distributions as dist from pyro.distributions.torch_distribution import ( - TorchDistributionMixin, TorchDistribution, + TorchDistributionMixin, ) try: @@ -32,8 +32,8 @@ from typing_extensions import ParamSpec from effectful.handlers.torch import Indexable, sizesof, to_tensor -from effectful.ops.syntax import defop, Scoped, defterm, defdata from effectful.ops.semantics import call +from effectful.ops.syntax import Scoped, defop, defterm from effectful.ops.types import Operation, Term P = ParamSpec("P") @@ -160,8 +160,8 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: # pdist shape: | named1 | batch_shape | event_shape | # obs shape: | batch_shape | event_shape |, | named2 | where named2 may overlap named1 - pdist = positional_distribution(dist) - naming = pdist.naming + indices = sizesof(dist) + pdist, naming = positional_distribution(dist) if msg["mask"] is None: mask = torch.tensor(True) @@ -170,14 +170,12 @@ def 
_pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: else: mask = msg["mask"] - pos_mask, _ = PyroShim._broadcast_to_named( - mask, dist.batch_shape, pdist.indices - ) + pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices) pos_obs: Optional[torch.Tensor] = None if obs is not None: pos_obs, naming = PyroShim._broadcast_to_named( - obs, dist.shape(), pdist.indices + obs, dist.shape(), indices ) for var, dim in naming.name_to_dim.items(): @@ -273,21 +271,59 @@ def named_distribution( dist: Annotated[TorchDistribution, Scoped[A]], *names: Annotated[Operation[[], int], Scoped[B]], ) -> Annotated[TorchDistribution, Scoped[A | B]]: - raise NotImplementedError + match defterm(dist): + case Term(op=_call, args=(dist_constr, *args)) if _call is call: + named_args = [] + for a in args: + assert isinstance(a, torch.Tensor) + named_args.append( + Indexable(typing.cast(torch.Tensor, a))[tuple(n() for n in names)] + ) + assert callable(dist_constr) + return defterm(dist_constr(*named_args)) + case _: + raise NotImplementedError @defop def positional_distribution( dist: Annotated[TorchDistribution, Scoped[A]] -) -> TorchDistribution: - raise NotImplementedError +) -> Tuple[TorchDistribution, Naming]: + match defterm(dist): + case Term(op=_call, args=(dist_constr, *args)) if _call is call: + assert callable(dist_constr) + base_dist = dist_constr(*args) + indices = sizesof(base_dist).keys() + n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) + naming = Naming.from_shape(indices, n_base) + pos_args = [to_tensor(a, indices) for a in args] + pos_dist = dist_constr(*pos_args) + return defterm(pos_dist), naming + case _: + raise NotImplementedError + + +@Term.register +class _DistributionTerm(TorchDistribution): + """A distribution wrapper that satisfies the Term interface. + + Represented as a term of the form call(D, *args, **kwargs) where D is the + distribution constructor. + + """ + + op: Operation = call + args: tuple + kwargs: Mapping[str, Any] = {} + + __match_args__ = ("op", "args", "kwargs") + def __init__(self, dist_constr: Type[TorchDistribution], *args): + self.args = (dist_constr,) + tuple(defterm(a) for a in args) -class _TorchDistributionWrapperMixin(ABC): @property - @abstractmethod def _base_dist(self): - pass + return self.args[0](*self.args[1:]) @property def has_rsample(self): @@ -329,190 +365,201 @@ def __getitem__(self, key: Collection[Operation[[], int]]): return named_distribution(self, *key) -@functools.cache -def _register_dist_constr(dist_constr: Callable[P, TorchDistribution]): - return defop(dist_constr) +@defterm.register(TorchDistribution) +@defterm.register(TorchDistributionMixin) +@functools.singledispatch +def _embed_dist(dist: TorchDistribution) -> Term[TorchDistribution]: + raise ValueError( + "No embedding provided for distribution of type {type(dist).__name__}." + ) -@Term.register -class _DistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): - """A distribution wrapper that satisfies the Term interface. +@_embed_dist.register(dist.Bernoulli) +def _embed_bernoulli(d: dist.Bernoulli) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Bernoulli, d.probs) - Represented as a term of the form call(D, *args, **kwargs) where D is the - distribution constructor. 
- """ +@_embed_dist.register(dist.Beta) +def _embed_beta(d: dist.Beta) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Beta, d.concentration1, d.concentration0) - op: Operation = call - args: tuple - kwargs: Mapping[str, Any] = {} - __match_args__ = ("op", "args", "kwargs") +@_embed_dist.register(dist.Binomial) +def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Binomial, d.total_count, d.probs) - def __init__(self, base_dist: TorchDistribution): - self.args = (_register_dist_constr(type(base_dist)),) + tuple( - base_dist.__dict__.values() - ) - @property - def _base_dist(self): - return self.args[0](*self.args[1:]) +@_embed_dist.register(dist.Categorical) +def _embed_categorical(d: dist.Categorical) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Categorical, d.probs) -@defterm.register(TorchDistribution) -@defterm.register(TorchDistributionMixin) -def _embed_dist(dist: TorchDistribution) -> Term[TorchDistribution]: - return _DistributionTerm(dist) +@_embed_dist.register(dist.Cauchy) +def _embed_cauchy(d: dist.Cauchy) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Cauchy, d.loc, d.scale) -class _PositionalDistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): - """A distribution wrapper that lazily converts indexed dimensions to - positional. +@_embed_dist.register(dist.Chi2) +def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Chi2, d.df) - """ - op: Operation[..., TorchDistribution] = positional_distribution - args: Tuple[TorchDistribution] - kwargs: Mapping[str, object] = {} +@_embed_dist.register(dist.ContinuousBernoulli) +def _embed_continuous_bernoulli( + d: dist.ContinuousBernoulli, +) -> Term[TorchDistribution]: + return _DistributionTerm(dist.ContinuousBernoulli, d.probs) - __match_args__ = ("op", "args", "kwargs") - indices: Mapping[Operation[[], int], int] +@_embed_dist.register(dist.Dirichlet) +def _embed_dirichlet(d: dist.Dirichlet) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Dirichlet, d.concentration) - def __init__(self, base_dist: TorchDistribution): - self.args = (base_dist,) - self.indices = sizesof(base_dist) - n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) - self.naming = Naming.from_shape(self.indices.keys(), n_base) +@_embed_dist.register(dist.Exponential) +def _embed_exponential(d: dist.Exponential) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Exponential, d.rate) - @property - def _base_dist(self): - return self.args[0] - def _to_positional(self, value: torch.Tensor) -> torch.Tensor: - # self._base_dist has shape: | batch_shape | event_shape | & named - # assume value comes from base_dist with shape: - # | sample_shape | batch_shape | event_shape | & named - # return a tensor of shape | sample_shape | named | batch_shape | event_shape | - n_named = len(self.indices) - dims = list(range(n_named + len(value.shape))) +@_embed_dist.register(dist.FisherSnedecor) +def _embed_fisher_snedecor(d: dist.FisherSnedecor) -> Term[TorchDistribution]: + return _DistributionTerm(dist.FisherSnedecor, d.df1, d.df2) - n_base = len(self.event_shape) + len(self._base_dist.batch_shape) - n_sample = len(value.shape) - n_base - base_dims = dims[len(dims) - n_base :] - named_dims = dims[:n_named] - sample_dims = dims[n_named : n_named + n_sample] +@_embed_dist.register(dist.Gamma) +def _embed_gamma(d: dist.Gamma) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Gamma, d.concentration, d.rate) - # 
shape: | named | sample_shape | batch_shape | event_shape | - # TODO: replace with something more efficient - pos_tensor = to_tensor(value, self.indices.keys()) - # shape: | sample_shape | named | batch_shape | event_shape | - pos_tensor_r = torch.permute(pos_tensor, sample_dims + named_dims + base_dims) +@_embed_dist.register(dist.Geometric) +def _embed_geometric(d: dist.Geometric) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Geometric, d.probs) - return pos_tensor_r - def _from_positional(self, value: torch.Tensor) -> torch.Tensor: - # maximal value shape: | sample_shape | named | batch_shape | event_shape | - return self.naming.apply(value) +@_embed_dist.register(dist.Gumbel) +def _embed_gumbel(d: dist.Gumbel) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Gumbel, d.loc, d.scale) - @property - def batch_shape(self): - return torch.Size([s for s in self.indices.values()]) + super().batch_shape - def sample(self, sample_shape=torch.Size()): - return self._to_positional(super().sample(sample_shape)) +@_embed_dist.register(dist.HalfCauchy) +def _embed_half_cauchy(d: dist.HalfCauchy) -> Term[TorchDistribution]: + return _DistributionTerm(dist.HalfCauchy, d.scale) - def rsample(self, sample_shape=torch.Size()): - return self._to_positional(super().rsample(sample_shape)) - def log_prob(self, value): - return self._to_positional(super().log_prob(self._from_positional(value))) +@_embed_dist.register(dist.HalfNormal) +def _embed_half_normal(d: dist.HalfNormal) -> Term[TorchDistribution]: + return _DistributionTerm(dist.HalfNormal, d.scale) - def enumerate_support(self, expand=True): - return self._to_positional(super().enumerate_support(expand)) +@_embed_dist.register(dist.Independent) +def _embed_independent(d: dist.Independent) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Independent, d.base_dist, d.reinterpreted_batch_ndims) -class _NamedDistributionTerm(_TorchDistributionWrapperMixin, TorchDistribution): - """A distribution wrapper that lazily names leftmost dimensions.""" - op: Operation[..., TorchDistribution] = named_distribution - args: tuple - kwargs: Mapping[str, object] = {} +@_embed_dist.register(dist.Kumaraswamy) +def _embed_kumaraswamy(d: dist.Kumaraswamy) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Kumaraswamy, d.concentration1, d.concentration0) - __match_args__ = ("op", "args", "kwargs") - def __init__(self, base_dist: TorchDistribution, *names: Operation[[], int]): - """ - :param base_dist: A distribution with batch dimensions. +@_embed_dist.register(dist.LKJCholesky) +def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[TorchDistribution]: + return _DistributionTerm(dist.LKJCholesky, d.dim, d.concentration) - :param names: A list of names. 
- """ - self.args = (base_dist,) + tuple(names) - self.names = names +@_embed_dist.register(dist.Laplace) +def _embed_laplace(d: dist.Laplace) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Laplace, d.loc, d.scale) - assert 1 <= len(names) <= len(base_dist.batch_shape) - base_indices = sizesof(base_dist) - assert not any(n in base_indices for n in names) - n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) - self.naming = Naming.from_shape(names, n_base - len(names)) +@_embed_dist.register(dist.LogNormal) +def _embed_log_normal(d: dist.LogNormal) -> Term[TorchDistribution]: + return _DistributionTerm(dist.LogNormal, d.loc, d.scale) - @property - def _base_dist(self): - return self.args[0] - def _to_named(self, value: torch.Tensor, offset=0) -> torch.Tensor: - return self.naming.apply(value) +@_embed_dist.register(dist.LogisticNormal) +def _embed_logistic_normal(d: dist.LogisticNormal) -> Term[TorchDistribution]: + return _DistributionTerm(dist.LogisticNormal, d.loc, d.scale) - def _from_named(self, value: torch.Tensor) -> torch.Tensor: - pos_value = to_tensor(value, self.names) - dims = list(range(len(pos_value.shape))) +@_embed_dist.register(dist.Multinomial) +def _embed_multinomial(d: dist.Multinomial) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Multinomial, d.total_count, d.probs) - n_base = len(self.event_shape) + len(self.batch_shape) - n_named = len(self.names) - n_sample = len(pos_value.shape) - n_base - n_named - base_dims = dims[len(dims) - n_base :] - named_dims = dims[:n_named] - sample_dims = dims[n_named : n_named + n_sample] +@_embed_dist.register(dist.MultivariateNormal) +def _embed_multivariate_normal( + d: dist.MultivariateNormal, +) -> Term[TorchDistribution]: + return _DistributionTerm(dist.MultivariateNormal, d.loc, d.scale_tril) - pos_tensor_r = torch.permute(pos_value, sample_dims + named_dims + base_dims) - return pos_tensor_r +@_embed_dist.register(dist.NegativeBinomial) +def _embed_negative_binomial(d: dist.NegativeBinomial) -> Term[TorchDistribution]: + return _DistributionTerm(dist.NegativeBinomial, d.total_count, d.probs) - @property - def batch_shape(self): - return super().batch_shape[len(self.names) :] - def sample(self, sample_shape=torch.Size()): - t = self._to_named(super().sample(sample_shape), offset=len(sample_shape)) - assert set(sizesof(t).keys()) == set(self.names) - assert t.shape == self.shape() + sample_shape - return t +@_embed_dist.register(dist.Normal) +def _embed_normal(d: dist.Normal) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Normal, d.loc, d.scale) - def rsample(self, sample_shape=torch.Size()): - return self._to_named(super().rsample(sample_shape), offset=len(sample_shape)) - def log_prob(self, value): - return self._to_named(super().log_prob(self._from_named(value))) +@_embed_dist.register(dist.OneHotCategorical) +def _embed_one_hot_categorical(d: dist.OneHotCategorical) -> Term[TorchDistribution]: + return _DistributionTerm(dist.OneHotCategorical, d.probs) - def enumerate_support(self, expand=True): - return self._to_named(super().enumerate_support(expand)) + +@_embed_dist.register(dist.OneHotCategoricalStraightThrough) +def _embed_one_hot_categorical_straight_through( + d: dist.OneHotCategoricalStraightThrough, +) -> Term[TorchDistribution]: + return _DistributionTerm(dist.OneHotCategoricalStraightThrough, d.probs) + + +@_embed_dist.register(dist.Pareto) +def _embed_pareto(d: dist.Pareto) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Pareto, d.scale, d.alpha) + + 
+@_embed_dist.register(dist.Poisson) +def _embed_poisson(d: dist.Poisson) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Poisson, d.rate) + + +@_embed_dist.register(dist.RelaxedBernoulli) +def _embed_relaxed_bernoulli(d: dist.RelaxedBernoulli) -> Term[TorchDistribution]: + return _DistributionTerm(dist.RelaxedBernoulli, d.temperature, d.probs) + + +@_embed_dist.register(dist.RelaxedOneHotCategorical) +def _embed_relaxed_one_hot_categorical( + d: dist.RelaxedOneHotCategorical, +) -> Term[TorchDistribution]: + return _DistributionTerm(dist.RelaxedOneHotCategorical, d.temperature, d.probs) + + +@_embed_dist.register(dist.StudentT) +def _embed_student_t(d: dist.StudentT) -> Term[TorchDistribution]: + return _DistributionTerm(dist.StudentT, d.df, d.loc, d.scale) + + +@_embed_dist.register(dist.Uniform) +def _embed_uniform(d: dist.Uniform) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Uniform, d.low, d.high) + + +@_embed_dist.register(dist.VonMises) +def _embed_von_mises(d: dist.VonMises) -> Term[TorchDistribution]: + return _DistributionTerm(dist.VonMises, d.loc, d.concentration) + + +@_embed_dist.register(dist.Weibull) +def _embed_weibull(d: dist.Weibull) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Weibull, d.scale, d.concentration) -@defdata.register(TorchDistribution) -def _(op, *args, **kwargs): - if op is named_distribution: - return _NamedDistributionTerm(*args, **kwargs) - elif op is positional_distribution: - return _PositionalDistributionTerm(*args, **kwargs) - else: - Term(op, *args, **kwargs) +@_embed_dist.register(dist.Wishart) +def _embed_wishart(d: dist.Wishart) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Wishart, d.df, d.scale_tril) def pyro_module_shim( diff --git a/tests/test_handlers_indexed.py b/tests/test_handlers_indexed.py index 51dde021..4d9635e9 100644 --- a/tests/test_handlers_indexed.py +++ b/tests/test_handlers_indexed.py @@ -5,6 +5,8 @@ import pytest import torch +# required to register embedding for distributions +import effectful.handlers.pyro # noqa: F401 from effectful.handlers.indexed import ( IndexSet, cond, diff --git a/tests/test_handlers_pyro.py b/tests/test_handlers_pyro.py index d626e6c6..d39973ae 100644 --- a/tests/test_handlers_pyro.py +++ b/tests/test_handlers_pyro.py @@ -8,9 +8,9 @@ import torch from effectful.handlers.pyro import ( + PyroShim, named_distribution, positional_distribution, - PyroShim, pyro_sample, ) from effectful.handlers.torch import Indexable, sizesof, torch_getitem @@ -203,14 +203,14 @@ def test_positional_dist(): expected_indices = {x: 2, y: 3} - d = positional_distribution(dist.Normal(loc, scale)) + d, naming = positional_distribution(dist.Normal(loc, scale)) assert d.shape() == torch.Size([2, 3]) s1 = d.sample() assert sizesof(s1) == {} assert s1.shape == torch.Size([2, 3]) - assert all(n in sizesof(d._from_positional(s1)) for n in [x, y]) + assert all(n in sizesof(naming.apply(s1)) for n in [x, y]) d_exp = d.expand((4, 5) + d.batch_shape) s2 = d_exp.sample() @@ -220,13 +220,13 @@ def test_positional_dist(): s3 = d.sample((4, 5)) assert sizesof(s3) == {} assert s3.shape == torch.Size([4, 5, 2, 3]) - assert all(n in sizesof(d._from_positional(s3)) for n in [x, y]) + assert all(n in sizesof(naming.apply(s3)) for n in [x, y]) loc = Indexable(torch.tensor(0.0).expand((2, 3, 4, 5)))[x(), y()] scale = Indexable(torch.tensor(1.0).expand((2, 3, 4, 5)))[x(), y()] - d = positional_distribution(dist.Normal(loc, scale)) + d, naming = positional_distribution(dist.Normal(loc, scale)) - 
assert sizesof(d._from_positional(d.sample((6, 7)))) == expected_indices + assert sizesof(naming.apply(d.sample((6, 7)))) == expected_indices assert d.sample().shape == torch.Size([2, 3, 4, 5]) assert d.sample((6, 7)).shape == torch.Size([6, 7, 2, 3, 4, 5]) From 5c63a0096152cd21d985275e616138e2a2cc4e2f Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 29 Jan 2025 11:15:53 -0500 Subject: [PATCH 05/19] format --- docs/source/semi_ring.py | 4 +++- effectful/handlers/pyro.py | 2 +- effectful/handlers/torch.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/source/semi_ring.py b/docs/source/semi_ring.py index b8a91fd2..d8a8bcad 100644 --- a/docs/source/semi_ring.py +++ b/docs/source/semi_ring.py @@ -183,7 +183,9 @@ def vertical_fusion(e1: T, x: Operation[[], T], e2: S) -> S: case ( Term(ops.Sum, (e_sum, k1, v1, Term(ops.Dict, (Term(k1a), e_lhs)))), Term(ops.Sum, (Term(xa), k2, v2, Term(ops.Dict, (Term(k2a), e_rhs)))), - ) if x == xa and k1 == k1a and k2 == k2a: + ) if ( + x == xa and k1 == k1a and k2 == k2a + ): return evaluate( Sum( e_sum, # type: ignore diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index f3fbd276..91e0de0f 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -287,7 +287,7 @@ def named_distribution( @defop def positional_distribution( - dist: Annotated[TorchDistribution, Scoped[A]] + dist: Annotated[TorchDistribution, Scoped[A]], ) -> Tuple[TorchDistribution, Naming]: match defterm(dist): case Term(op=_call, args=(dist_constr, *args)) if _call is call: diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 6f783c08..d8adfbbd 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -469,7 +469,7 @@ def grad_fn(self): def _indexed_func_wrapper( - func: Callable[P, T] + func: Callable[P, T], ) -> Tuple[Callable[P, S], Callable[[S], T]]: # index expressions for the result of the function indexes = None From 1f778423e8ce784d12d37e823b45efb467f92ad6 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 29 Jan 2025 17:59:08 -0500 Subject: [PATCH 06/19] handle additional distributions and consolidate --- effectful/handlers/pyro.py | 225 ++++++++++++++++--------------------- 1 file changed, 94 insertions(+), 131 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 91e0de0f..fe22facc 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -225,7 +225,6 @@ def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: if value is not None: # note: is it safe to assume that msg['fn'] is a distribution? - assert isinstance(msg["fn"], TorchDistribution) dist_shape: tuple[int, ...] 
= msg["fn"].batch_shape + msg["fn"].event_shape if len(value.shape) < len(dist_shape): value = value.broadcast_to( @@ -268,11 +267,12 @@ def __repr__(self): @defop def named_distribution( - dist: Annotated[TorchDistribution, Scoped[A]], + d: Annotated[TorchDistribution, Scoped[A]], *names: Annotated[Operation[[], int], Scoped[B]], ) -> Annotated[TorchDistribution, Scoped[A | B]]: - match defterm(dist): + match defterm(d): case Term(op=_call, args=(dist_constr, *args)) if _call is call: + assert dist_constr is not dist.Independent named_args = [] for a in args: assert isinstance(a, torch.Tensor) @@ -280,46 +280,59 @@ def named_distribution( Indexable(typing.cast(torch.Tensor, a))[tuple(n() for n in names)] ) assert callable(dist_constr) - return defterm(dist_constr(*named_args)) + return dist_constr(*named_args) case _: raise NotImplementedError @defop def positional_distribution( - dist: Annotated[TorchDistribution, Scoped[A]], + d: Annotated[TorchDistribution, Scoped[A]], ) -> Tuple[TorchDistribution, Naming]: - match defterm(dist): + match defterm(d): case Term(op=_call, args=(dist_constr, *args)) if _call is call: assert callable(dist_constr) base_dist = dist_constr(*args) - indices = sizesof(base_dist).keys() + indices = sizesof(d).keys() n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) naming = Naming.from_shape(indices, n_base) pos_args = [to_tensor(a, indices) for a in args] pos_dist = dist_constr(*pos_args) - return defterm(pos_dist), naming + return pos_dist, naming case _: raise NotImplementedError -@Term.register -class _DistributionTerm(TorchDistribution): +class _DistributionTerm(Term[TorchDistribution], TorchDistribution): """A distribution wrapper that satisfies the Term interface. Represented as a term of the form call(D, *args, **kwargs) where D is the distribution constructor. + Note: When we construct instances of this class, we put distribution + parameters that can be expanded in the args list and those that cannot in + the kwargs list. + """ - op: Operation = call - args: tuple - kwargs: Mapping[str, Any] = {} + _args: tuple + _kwargs: dict + + def __init__(self, dist_constr: Type[TorchDistribution], *args, **kwargs): + self._args = (dist_constr,) + tuple(args) + self._kwargs = kwargs + + @property + def op(self): + return call - __match_args__ = ("op", "args", "kwargs") + @property + def args(self): + return self._args - def __init__(self, dist_constr: Type[TorchDistribution], *args): - self.args = (dist_constr,) + tuple(defterm(a) for a in args) + @property + def kwargs(self): + return self._kwargs @property def _base_dist(self): @@ -368,200 +381,150 @@ def __getitem__(self, key: Collection[Operation[[], int]]): @defterm.register(TorchDistribution) @defterm.register(TorchDistributionMixin) @functools.singledispatch -def _embed_dist(dist: TorchDistribution) -> Term[TorchDistribution]: +def _embed_distribution(dist: TorchDistribution) -> Term[TorchDistribution]: raise ValueError( - "No embedding provided for distribution of type {type(dist).__name__}." + f"No embedding provided for distribution of type {type(dist).__name__}." 
) -@_embed_dist.register(dist.Bernoulli) -def _embed_bernoulli(d: dist.Bernoulli) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Bernoulli, d.probs) +@_embed_distribution.register(dist.ExpandedDistribution) +def _embed_expanded(d: dist.ExpandedDistribution) -> Term[TorchDistribution]: + if d._batch_shape == d.base_dist.batch_shape: + return d.base_dist + raise ValueError("Nontrivial ExpandedDistribution not implemented.") -@_embed_dist.register(dist.Beta) -def _embed_beta(d: dist.Beta) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Beta, d.concentration1, d.concentration0) +@_embed_distribution.register(dist.Independent) +def _embed_independent(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.base_dist, d.reinterpreted_batch_ndims) -@_embed_dist.register(dist.Binomial) -def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Binomial, d.total_count, d.probs) +@_embed_distribution.register(dist.Cauchy) +@_embed_distribution.register(dist.Gumbel) +@_embed_distribution.register(dist.Laplace) +@_embed_distribution.register(dist.LogNormal) +@_embed_distribution.register(dist.LogisticNormal) +@_embed_distribution.register(dist.Normal) +@_embed_distribution.register(dist.StudentT) +def _embed_loc_scale(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.loc, d.scale) -@_embed_dist.register(dist.Categorical) -def _embed_categorical(d: dist.Categorical) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Categorical, d.probs) +@_embed_distribution.register(dist.Bernoulli) +@_embed_distribution.register(dist.Categorical) +@_embed_distribution.register(dist.ContinuousBernoulli) +@_embed_distribution.register(dist.Geometric) +@_embed_distribution.register(dist.OneHotCategorical) +@_embed_distribution.register(dist.OneHotCategoricalStraightThrough) +def _embed_probs(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.probs) -@_embed_dist.register(dist.Cauchy) -def _embed_cauchy(d: dist.Cauchy) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Cauchy, d.loc, d.scale) +@_embed_distribution.register(dist.Beta) +@_embed_distribution.register(dist.Kumaraswamy) +def _embed_beta(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.concentration1, d.concentration0) -@_embed_dist.register(dist.Chi2) -def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Chi2, d.df) +@_embed_distribution.register(dist.Binomial) +def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Binomial, d.total_count, d.probs) -@_embed_dist.register(dist.ContinuousBernoulli) -def _embed_continuous_bernoulli( - d: dist.ContinuousBernoulli, -) -> Term[TorchDistribution]: - return _DistributionTerm(dist.ContinuousBernoulli, d.probs) +@_embed_distribution.register(dist.Chi2) +def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Chi2, d.df) -@_embed_dist.register(dist.Dirichlet) +@_embed_distribution.register(dist.Dirichlet) def _embed_dirichlet(d: dist.Dirichlet) -> Term[TorchDistribution]: return _DistributionTerm(dist.Dirichlet, d.concentration) -@_embed_dist.register(dist.Exponential) +@_embed_distribution.register(dist.Exponential) def _embed_exponential(d: dist.Exponential) -> Term[TorchDistribution]: return _DistributionTerm(dist.Exponential, d.rate) -@_embed_dist.register(dist.FisherSnedecor) +@_embed_distribution.register(dist.FisherSnedecor) def 
_embed_fisher_snedecor(d: dist.FisherSnedecor) -> Term[TorchDistribution]: return _DistributionTerm(dist.FisherSnedecor, d.df1, d.df2) -@_embed_dist.register(dist.Gamma) +@_embed_distribution.register(dist.Gamma) def _embed_gamma(d: dist.Gamma) -> Term[TorchDistribution]: return _DistributionTerm(dist.Gamma, d.concentration, d.rate) -@_embed_dist.register(dist.Geometric) -def _embed_geometric(d: dist.Geometric) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Geometric, d.probs) - - -@_embed_dist.register(dist.Gumbel) -def _embed_gumbel(d: dist.Gumbel) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Gumbel, d.loc, d.scale) - - -@_embed_dist.register(dist.HalfCauchy) -def _embed_half_cauchy(d: dist.HalfCauchy) -> Term[TorchDistribution]: - return _DistributionTerm(dist.HalfCauchy, d.scale) - - -@_embed_dist.register(dist.HalfNormal) -def _embed_half_normal(d: dist.HalfNormal) -> Term[TorchDistribution]: - return _DistributionTerm(dist.HalfNormal, d.scale) - - -@_embed_dist.register(dist.Independent) -def _embed_independent(d: dist.Independent) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Independent, d.base_dist, d.reinterpreted_batch_ndims) +@_embed_distribution.register(dist.HalfCauchy) +@_embed_distribution.register(dist.HalfNormal) +def _embed_half_cauchy(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.scale) -@_embed_dist.register(dist.Kumaraswamy) -def _embed_kumaraswamy(d: dist.Kumaraswamy) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Kumaraswamy, d.concentration1, d.concentration0) - - -@_embed_dist.register(dist.LKJCholesky) +@_embed_distribution.register(dist.LKJCholesky) def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[TorchDistribution]: - return _DistributionTerm(dist.LKJCholesky, d.dim, d.concentration) - - -@_embed_dist.register(dist.Laplace) -def _embed_laplace(d: dist.Laplace) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Laplace, d.loc, d.scale) - - -@_embed_dist.register(dist.LogNormal) -def _embed_log_normal(d: dist.LogNormal) -> Term[TorchDistribution]: - return _DistributionTerm(dist.LogNormal, d.loc, d.scale) + return _DistributionTerm(dist.LKJCholesky, d.concentration, dim=d.dim) -@_embed_dist.register(dist.LogisticNormal) -def _embed_logistic_normal(d: dist.LogisticNormal) -> Term[TorchDistribution]: - return _DistributionTerm(dist.LogisticNormal, d.loc, d.scale) - - -@_embed_dist.register(dist.Multinomial) +@_embed_distribution.register(dist.Multinomial) def _embed_multinomial(d: dist.Multinomial) -> Term[TorchDistribution]: return _DistributionTerm(dist.Multinomial, d.total_count, d.probs) -@_embed_dist.register(dist.MultivariateNormal) +@_embed_distribution.register(dist.MultivariateNormal) def _embed_multivariate_normal( d: dist.MultivariateNormal, ) -> Term[TorchDistribution]: return _DistributionTerm(dist.MultivariateNormal, d.loc, d.scale_tril) -@_embed_dist.register(dist.NegativeBinomial) +@_embed_distribution.register(dist.NegativeBinomial) def _embed_negative_binomial(d: dist.NegativeBinomial) -> Term[TorchDistribution]: return _DistributionTerm(dist.NegativeBinomial, d.total_count, d.probs) -@_embed_dist.register(dist.Normal) -def _embed_normal(d: dist.Normal) -> Term[TorchDistribution]: - return _DistributionTerm(dist.Normal, d.loc, d.scale) - - -@_embed_dist.register(dist.OneHotCategorical) -def _embed_one_hot_categorical(d: dist.OneHotCategorical) -> Term[TorchDistribution]: - return _DistributionTerm(dist.OneHotCategorical, d.probs) - - 
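A minimal, hypothetical sketch of what this embedding machinery provides; the
import locations are assumptions inferred from the files touched in this
series (defop/defterm in effectful.ops.syntax, sizesof in
effectful.handlers.torch, named_distribution/positional_distribution in
effectful.handlers.pyro):

    import torch
    import pyro.distributions as dist

    from effectful.handlers.pyro import named_distribution, positional_distribution
    from effectful.handlers.torch import sizesof
    from effectful.ops.syntax import defop, defterm

    # defterm embeds a concrete distribution as a call term whose first
    # argument is the constructor and whose remaining arguments are the
    # distribution's parameters.
    d = dist.Normal(torch.zeros(3), torch.ones(3))
    t = defterm(d)
    assert t.args[0] is dist.Normal

    # named_distribution and positional_distribution pattern-match on that
    # term: the former names the leftmost batch dimension, the latter
    # flattens named dimensions back into the batch shape.
    batch = defop(int, name="batch")
    nd = named_distribution(d, batch)
    assert sizesof(nd) == {batch: 3}

    pd, naming = positional_distribution(nd)
    assert pd.batch_shape == (3,)

The grouped registrations in this hunk let one helper cover every
distribution that shares a parameterization, e.g. _embed_loc_scale for the
loc/scale family and _embed_probs for the probs family.
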
-@_embed_dist.register(dist.OneHotCategoricalStraightThrough) -def _embed_one_hot_categorical_straight_through( - d: dist.OneHotCategoricalStraightThrough, -) -> Term[TorchDistribution]: - return _DistributionTerm(dist.OneHotCategoricalStraightThrough, d.probs) - - -@_embed_dist.register(dist.Pareto) +@_embed_distribution.register(dist.Pareto) def _embed_pareto(d: dist.Pareto) -> Term[TorchDistribution]: return _DistributionTerm(dist.Pareto, d.scale, d.alpha) -@_embed_dist.register(dist.Poisson) +@_embed_distribution.register(dist.Poisson) def _embed_poisson(d: dist.Poisson) -> Term[TorchDistribution]: return _DistributionTerm(dist.Poisson, d.rate) -@_embed_dist.register(dist.RelaxedBernoulli) -def _embed_relaxed_bernoulli(d: dist.RelaxedBernoulli) -> Term[TorchDistribution]: - return _DistributionTerm(dist.RelaxedBernoulli, d.temperature, d.probs) - +@_embed_distribution.register(dist.RelaxedBernoulli) +@_embed_distribution.register(dist.RelaxedOneHotCategorical) +def _embed_relaxed(d) -> Term[TorchDistribution]: + return _DistributionTerm(type(d), d.temperature, d.probs) -@_embed_dist.register(dist.RelaxedOneHotCategorical) -def _embed_relaxed_one_hot_categorical( - d: dist.RelaxedOneHotCategorical, -) -> Term[TorchDistribution]: - return _DistributionTerm(dist.RelaxedOneHotCategorical, d.temperature, d.probs) - -@_embed_dist.register(dist.StudentT) -def _embed_student_t(d: dist.StudentT) -> Term[TorchDistribution]: - return _DistributionTerm(dist.StudentT, d.df, d.loc, d.scale) - - -@_embed_dist.register(dist.Uniform) +@_embed_distribution.register(dist.Uniform) def _embed_uniform(d: dist.Uniform) -> Term[TorchDistribution]: return _DistributionTerm(dist.Uniform, d.low, d.high) -@_embed_dist.register(dist.VonMises) +@_embed_distribution.register(dist.VonMises) def _embed_von_mises(d: dist.VonMises) -> Term[TorchDistribution]: return _DistributionTerm(dist.VonMises, d.loc, d.concentration) -@_embed_dist.register(dist.Weibull) +@_embed_distribution.register(dist.Weibull) def _embed_weibull(d: dist.Weibull) -> Term[TorchDistribution]: return _DistributionTerm(dist.Weibull, d.scale, d.concentration) -@_embed_dist.register(dist.Wishart) +@_embed_distribution.register(dist.Wishart) def _embed_wishart(d: dist.Wishart) -> Term[TorchDistribution]: return _DistributionTerm(dist.Wishart, d.df, d.scale_tril) +@_embed_distribution.register(dist.Delta) +def _embed_delta(d: dist.Delta) -> Term[TorchDistribution]: + return _DistributionTerm(dist.Delta, d.v, d.log_density, event_dim=d.event_dim) + + def pyro_module_shim( module: type[pyro.nn.module.PyroModule], ) -> type[pyro.nn.module.PyroModule]: From 08f06c14f4a7f0bc1d8bd7ea8c12133bfdff4382 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Wed, 29 Jan 2025 18:00:17 -0500 Subject: [PATCH 07/19] lint --- effectful/handlers/pyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index fe22facc..20eb63d7 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -225,7 +225,7 @@ def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: if value is not None: # note: is it safe to assume that msg['fn'] is a distribution? - dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape + dist_shape: tuple[int, ...] 
= msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore if len(value.shape) < len(dist_shape): value = value.broadcast_to( torch.broadcast_shapes(value.shape, dist_shape) From 728d63f388342493cad89047d473fc6c6e879974 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 30 Jan 2025 10:43:23 -0500 Subject: [PATCH 08/19] add missing defterm --- effectful/handlers/pyro.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 20eb63d7..0d2317c0 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -319,7 +319,7 @@ class _DistributionTerm(Term[TorchDistribution], TorchDistribution): _kwargs: dict def __init__(self, dist_constr: Type[TorchDistribution], *args, **kwargs): - self._args = (dist_constr,) + tuple(args) + self._args = (dist_constr,) + tuple(defterm(a) for a in args) self._kwargs = kwargs @property From ac758b7a8d18f8739fa12a9f8124be3b374b0833 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 30 Jan 2025 15:48:19 -0500 Subject: [PATCH 09/19] reenable check --- effectful/handlers/pyro.py | 2 +- effectful/ops/syntax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 0d2317c0..4ce20fb4 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -267,7 +267,7 @@ def __repr__(self): @defop def named_distribution( - d: Annotated[TorchDistribution, Scoped[A]], + d: Annotated[TorchDistribution, Scoped[A | B]], *names: Annotated[Operation[[], int], Scoped[B]], ) -> Annotated[TorchDistribution, Scoped[A | B]]: match defterm(d): diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 69d8a746..98ab015a 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -280,7 +280,7 @@ def infer_annotations(cls, sig: inspect.Signature) -> inspect.Signature: # pre-conditions assert cls._check_has_single_scope(sig) assert cls._check_no_typevar_overlap(sig) - # assert cls._check_no_boundvars_in_result(sig) + assert cls._check_no_boundvars_in_result(sig) root_ordinal = cls._get_root_ordinal(sig) if not root_ordinal: From a266adc430bc2f5e14b7caaa95f2379a16ee0207 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 30 Jan 2025 15:49:06 -0500 Subject: [PATCH 10/19] remove getitem --- effectful/handlers/pyro.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 4ce20fb4..7b86ba74 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -374,9 +374,6 @@ def log_prob(self, value): def enumerate_support(self, expand=True): return self._base_dist.enumerate_support(expand) - def __getitem__(self, key: Collection[Operation[[], int]]): - return named_distribution(self, *key) - @defterm.register(TorchDistribution) @defterm.register(TorchDistributionMixin) From 803ba6115773875fb0f21aed7d53c57bf6a9a28d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Thu, 30 Jan 2025 15:52:36 -0500 Subject: [PATCH 11/19] flatten singledispatch --- effectful/handlers/pyro.py | 77 +++++++++++++++++++------------------- 1 file changed, 38 insertions(+), 39 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 7b86ba74..f9eec4b7 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -377,147 +377,146 @@ def enumerate_support(self, expand=True): @defterm.register(TorchDistribution) @defterm.register(TorchDistributionMixin) -@functools.singledispatch def _embed_distribution(dist: 
TorchDistribution) -> Term[TorchDistribution]: raise ValueError( f"No embedding provided for distribution of type {type(dist).__name__}." ) -@_embed_distribution.register(dist.ExpandedDistribution) +@defterm.register(dist.ExpandedDistribution) def _embed_expanded(d: dist.ExpandedDistribution) -> Term[TorchDistribution]: if d._batch_shape == d.base_dist.batch_shape: return d.base_dist raise ValueError("Nontrivial ExpandedDistribution not implemented.") -@_embed_distribution.register(dist.Independent) +@defterm.register(dist.Independent) def _embed_independent(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.base_dist, d.reinterpreted_batch_ndims) -@_embed_distribution.register(dist.Cauchy) -@_embed_distribution.register(dist.Gumbel) -@_embed_distribution.register(dist.Laplace) -@_embed_distribution.register(dist.LogNormal) -@_embed_distribution.register(dist.LogisticNormal) -@_embed_distribution.register(dist.Normal) -@_embed_distribution.register(dist.StudentT) +@defterm.register(dist.Cauchy) +@defterm.register(dist.Gumbel) +@defterm.register(dist.Laplace) +@defterm.register(dist.LogNormal) +@defterm.register(dist.LogisticNormal) +@defterm.register(dist.Normal) +@defterm.register(dist.StudentT) def _embed_loc_scale(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.loc, d.scale) -@_embed_distribution.register(dist.Bernoulli) -@_embed_distribution.register(dist.Categorical) -@_embed_distribution.register(dist.ContinuousBernoulli) -@_embed_distribution.register(dist.Geometric) -@_embed_distribution.register(dist.OneHotCategorical) -@_embed_distribution.register(dist.OneHotCategoricalStraightThrough) +@defterm.register(dist.Bernoulli) +@defterm.register(dist.Categorical) +@defterm.register(dist.ContinuousBernoulli) +@defterm.register(dist.Geometric) +@defterm.register(dist.OneHotCategorical) +@defterm.register(dist.OneHotCategoricalStraightThrough) def _embed_probs(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.probs) -@_embed_distribution.register(dist.Beta) -@_embed_distribution.register(dist.Kumaraswamy) +@defterm.register(dist.Beta) +@defterm.register(dist.Kumaraswamy) def _embed_beta(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.concentration1, d.concentration0) -@_embed_distribution.register(dist.Binomial) +@defterm.register(dist.Binomial) def _embed_binomial(d: dist.Binomial) -> Term[TorchDistribution]: return _DistributionTerm(dist.Binomial, d.total_count, d.probs) -@_embed_distribution.register(dist.Chi2) +@defterm.register(dist.Chi2) def _embed_chi2(d: dist.Chi2) -> Term[TorchDistribution]: return _DistributionTerm(dist.Chi2, d.df) -@_embed_distribution.register(dist.Dirichlet) +@defterm.register(dist.Dirichlet) def _embed_dirichlet(d: dist.Dirichlet) -> Term[TorchDistribution]: return _DistributionTerm(dist.Dirichlet, d.concentration) -@_embed_distribution.register(dist.Exponential) +@defterm.register(dist.Exponential) def _embed_exponential(d: dist.Exponential) -> Term[TorchDistribution]: return _DistributionTerm(dist.Exponential, d.rate) -@_embed_distribution.register(dist.FisherSnedecor) +@defterm.register(dist.FisherSnedecor) def _embed_fisher_snedecor(d: dist.FisherSnedecor) -> Term[TorchDistribution]: return _DistributionTerm(dist.FisherSnedecor, d.df1, d.df2) -@_embed_distribution.register(dist.Gamma) +@defterm.register(dist.Gamma) def _embed_gamma(d: dist.Gamma) -> Term[TorchDistribution]: return _DistributionTerm(dist.Gamma, d.concentration, d.rate) -@_embed_distribution.register(dist.HalfCauchy) 
-@_embed_distribution.register(dist.HalfNormal) +@defterm.register(dist.HalfCauchy) +@defterm.register(dist.HalfNormal) def _embed_half_cauchy(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.scale) -@_embed_distribution.register(dist.LKJCholesky) +@defterm.register(dist.LKJCholesky) def _embed_lkj_cholesky(d: dist.LKJCholesky) -> Term[TorchDistribution]: return _DistributionTerm(dist.LKJCholesky, d.concentration, dim=d.dim) -@_embed_distribution.register(dist.Multinomial) +@defterm.register(dist.Multinomial) def _embed_multinomial(d: dist.Multinomial) -> Term[TorchDistribution]: return _DistributionTerm(dist.Multinomial, d.total_count, d.probs) -@_embed_distribution.register(dist.MultivariateNormal) +@defterm.register(dist.MultivariateNormal) def _embed_multivariate_normal( d: dist.MultivariateNormal, ) -> Term[TorchDistribution]: return _DistributionTerm(dist.MultivariateNormal, d.loc, d.scale_tril) -@_embed_distribution.register(dist.NegativeBinomial) +@defterm.register(dist.NegativeBinomial) def _embed_negative_binomial(d: dist.NegativeBinomial) -> Term[TorchDistribution]: return _DistributionTerm(dist.NegativeBinomial, d.total_count, d.probs) -@_embed_distribution.register(dist.Pareto) +@defterm.register(dist.Pareto) def _embed_pareto(d: dist.Pareto) -> Term[TorchDistribution]: return _DistributionTerm(dist.Pareto, d.scale, d.alpha) -@_embed_distribution.register(dist.Poisson) +@defterm.register(dist.Poisson) def _embed_poisson(d: dist.Poisson) -> Term[TorchDistribution]: return _DistributionTerm(dist.Poisson, d.rate) -@_embed_distribution.register(dist.RelaxedBernoulli) -@_embed_distribution.register(dist.RelaxedOneHotCategorical) +@defterm.register(dist.RelaxedBernoulli) +@defterm.register(dist.RelaxedOneHotCategorical) def _embed_relaxed(d) -> Term[TorchDistribution]: return _DistributionTerm(type(d), d.temperature, d.probs) -@_embed_distribution.register(dist.Uniform) +@defterm.register(dist.Uniform) def _embed_uniform(d: dist.Uniform) -> Term[TorchDistribution]: return _DistributionTerm(dist.Uniform, d.low, d.high) -@_embed_distribution.register(dist.VonMises) +@defterm.register(dist.VonMises) def _embed_von_mises(d: dist.VonMises) -> Term[TorchDistribution]: return _DistributionTerm(dist.VonMises, d.loc, d.concentration) -@_embed_distribution.register(dist.Weibull) +@defterm.register(dist.Weibull) def _embed_weibull(d: dist.Weibull) -> Term[TorchDistribution]: return _DistributionTerm(dist.Weibull, d.scale, d.concentration) -@_embed_distribution.register(dist.Wishart) +@defterm.register(dist.Wishart) def _embed_wishart(d: dist.Wishart) -> Term[TorchDistribution]: return _DistributionTerm(dist.Wishart, d.df, d.scale_tril) -@_embed_distribution.register(dist.Delta) +@defterm.register(dist.Delta) def _embed_delta(d: dist.Delta) -> Term[TorchDistribution]: return _DistributionTerm(dist.Delta, d.v, d.log_density, event_dim=d.event_dim) From 876ebc50fff5baff0626fed987c509b36f4e09fb Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:24:05 -0500 Subject: [PATCH 12/19] handle Independent distributions --- effectful/handlers/pyro.py | 46 +++++++++++++++++++++++++------------- effectful/ops/syntax.py | 6 ++--- 2 files changed, 33 insertions(+), 19 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index f9eec4b7..a82a07b1 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -272,15 +272,22 @@ def named_distribution( ) -> Annotated[TorchDistribution, Scoped[A | B]]: match defterm(d): case Term(op=_call, 
args=(dist_constr, *args)) if _call is call: - assert dist_constr is not dist.Independent - named_args = [] - for a in args: - assert isinstance(a, torch.Tensor) - named_args.append( - Indexable(typing.cast(torch.Tensor, a))[tuple(n() for n in names)] + if dist_constr is dist.Independent: + base_dist, reinterpreted_batch_ndims = args + return dist.Independent( + named_distribution(base_dist, *names), reinterpreted_batch_ndims ) - assert callable(dist_constr) - return dist_constr(*named_args) + else: + named_args = [] + for a in args: + assert isinstance(a, torch.Tensor) + named_args.append( + Indexable(typing.cast(torch.Tensor, a))[ + tuple(n() for n in names) + ] + ) + assert callable(dist_constr) + return dist_constr(*named_args) case _: raise NotImplementedError @@ -291,14 +298,21 @@ def positional_distribution( ) -> Tuple[TorchDistribution, Naming]: match defterm(d): case Term(op=_call, args=(dist_constr, *args)) if _call is call: - assert callable(dist_constr) - base_dist = dist_constr(*args) - indices = sizesof(d).keys() - n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) - naming = Naming.from_shape(indices, n_base) - pos_args = [to_tensor(a, indices) for a in args] - pos_dist = dist_constr(*pos_args) - return pos_dist, naming + if dist_constr is dist.Independent: + base_dist, reinterpreted_batch_ndims = args + pos_base_dist, naming = positional_distribution(base_dist) + return dist.Independent( + pos_base_dist, reinterpreted_batch_ndims + ), naming + else: + assert callable(dist_constr) + base_dist = dist_constr(*args) + indices = sizesof(d).keys() + n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) + naming = Naming.from_shape(indices, n_base) + pos_args = [to_tensor(a, indices) for a in args] + pos_dist = dist_constr(*pos_args) + return pos_dist, naming case _: raise NotImplementedError diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index 98ab015a..e03a3eee 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -542,7 +542,9 @@ def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": Callable[Concatenate[Operation[Q, V], Q], Expr[V]], defdata )(self, *args, **kwargs) - def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> tuple[ + def __fvs_rule__( + self, *args: Q.args, **kwargs: Q.kwargs + ) -> tuple[ tuple[collections.abc.Set[Operation], ...], dict[str, collections.abc.Set[Operation]], ]: @@ -613,7 +615,6 @@ def __str__(self): @defop.register(Operation) def _(t: Operation[P, T], *, name: Optional[str] = None) -> Operation[P, T]: - @functools.wraps(t) def func(*args, **kwargs): raise NotImplementedError @@ -643,7 +644,6 @@ def func() -> t: # type: ignore @defop.register(types.BuiltinFunctionType) def _(t: Callable[P, T], *, name: Optional[str] = None) -> Operation[P, T]: - @functools.wraps(t) def func(*args, **kwargs): if not any(isinstance(a, Term) for a in tree.flatten((args, kwargs))): From 506a1f51a0254e52ba66cd5911007d3e10ac4299 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:26:00 -0500 Subject: [PATCH 13/19] remove singledispatch --- effectful/handlers/torch.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index d8adfbbd..10f90623 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -53,9 +53,9 @@ def extra_dims(key): new_shape.append(1) new_key.append(slice(None)) elif k is Ellipsis: - assert not any( - k is Ellipsis for k in key[i + 1 :] - ), "only one 
Ellipsis allowed" + assert not any(k is Ellipsis for k in key[i + 1 :]), ( + "only one Ellipsis allowed" + ) # determine which of the original dimensions this ellipsis refers to pre_dims = i - extra_dims(key[:i]) # dimensions that precede the ellipsis @@ -84,7 +84,6 @@ def _getitem_ellipsis_and_none( return torch.reshape(x, new_shape), new_key -@functools.singledispatch def sizesof(value) -> Mapping[Operation[[], int], int]: """Return the sizes of named dimensions in a tensor expression. @@ -209,10 +208,8 @@ def to_tensor(t: T, order: Optional[Collection[Operation[[], int]]] = None) -> T @functools.cache def _register_torch_op(torch_fn: Callable[P, T]): - @defop def _torch_op(*args, **kwargs) -> torch.Tensor: - tm = defdata(_torch_op, *args, **kwargs) sized_fvs = sizesof(tm) @@ -374,7 +371,6 @@ def __torch_function__( @Term.register class _EagerTensorTerm(torch.Tensor): - op: Operation[..., torch.Tensor] = torch_getitem args: Tuple[torch.Tensor, Tuple[IndexElement, ...]] kwargs: Mapping[str, object] = {} From cd49b556594965a3333f1d00dd3bd4bac691ed52 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:26:23 -0500 Subject: [PATCH 14/19] format --- effectful/handlers/torch.py | 6 +++--- effectful/ops/syntax.py | 4 +--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/effectful/handlers/torch.py b/effectful/handlers/torch.py index 10f90623..cee1edac 100644 --- a/effectful/handlers/torch.py +++ b/effectful/handlers/torch.py @@ -53,9 +53,9 @@ def extra_dims(key): new_shape.append(1) new_key.append(slice(None)) elif k is Ellipsis: - assert not any(k is Ellipsis for k in key[i + 1 :]), ( - "only one Ellipsis allowed" - ) + assert not any( + k is Ellipsis for k in key[i + 1 :] + ), "only one Ellipsis allowed" # determine which of the original dimensions this ellipsis refers to pre_dims = i - extra_dims(key[:i]) # dimensions that precede the ellipsis diff --git a/effectful/ops/syntax.py b/effectful/ops/syntax.py index e03a3eee..575a0167 100644 --- a/effectful/ops/syntax.py +++ b/effectful/ops/syntax.py @@ -542,9 +542,7 @@ def __default_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> "Expr[V]": Callable[Concatenate[Operation[Q, V], Q], Expr[V]], defdata )(self, *args, **kwargs) - def __fvs_rule__( - self, *args: Q.args, **kwargs: Q.kwargs - ) -> tuple[ + def __fvs_rule__(self, *args: Q.args, **kwargs: Q.kwargs) -> tuple[ tuple[collections.abc.Set[Operation], ...], dict[str, collections.abc.Set[Operation]], ]: From cd2d99322150063c117a52e75aac47781591fd17 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:29:07 -0500 Subject: [PATCH 15/19] fix bugs in pyro shim --- effectful/handlers/pyro.py | 100 +++++++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 33 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index a82a07b1..7b851c1e 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -151,7 +151,23 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: ): return + # PyroShim turns each call to pyro.sample into two calls. The first + # dispatches to pyro_sample and the effectful stack. The effectful stack + # eventually calls pyro.sample again. We use state in PyroShim to + # recognize that we've been called twice, and we dispatch to the pyro + # stack. + # + # This branch handles the second call, so it massages the message to be + # compatible with Pyro. In particular, it removes all named dimensions + # and stores naming information in the message. 
Names are replaced by + # _pyro_post_sample. if getattr(self, "_current_site", None) == msg["name"]: + if "_index_naming" in msg: + return + + # We need to identify this pyro shim during post-sample. + msg["_pyro_shim_id"] = id(self) + if "_markov_scope" in msg["infer"] and self._current_site: msg["infer"]["_markov_scope"].pop(self._current_site, None) @@ -170,6 +186,9 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: else: mask = msg["mask"] + assert set(sizesof(mask).keys()) <= ( + set(indices.keys()) | set(sizesof(obs).keys()) + ) pos_mask, _ = PyroShim._broadcast_to_named(mask, dist.batch_shape, indices) pos_obs: Optional[torch.Tensor] = None @@ -178,52 +197,66 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: obs, dist.shape(), indices ) + # Each of the batch dimensions on the distribution gets a + # cond_indep_stack frame. for var, dim in naming.name_to_dim.items(): - frame = pyro.poutine.indep_messenger.CondIndepStackFrame( - name=str(var), dim=dim, size=-1, counter=0 - ) - msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] + # There can be additional batch dimensions on the observation + # that do not get frames, so only consider dimensions on the + # distribution. + if var in indices: + frame = pyro.poutine.indep_messenger.CondIndepStackFrame( + name=str(var), + # dims are indexed from the right of the batch shape + dim=dim + len(pdist.event_shape), + size=indices[var], + counter=0, + ) + msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"] msg["fn"] = pdist msg["value"] = pos_obs msg["mask"] = pos_mask - msg["infer"]["_index_naming"] = naming # type: ignore + msg["_index_naming"] = naming # type: ignore assert sizesof(msg["value"]) == {} assert sizesof(msg["mask"]) == {} - return - - try: - self._current_site = msg["name"] - msg["value"] = pyro_sample( - msg["name"], - msg["fn"], - obs=msg["value"] if msg["is_observed"] else None, - infer=msg["infer"].copy(), - ) - finally: - self._current_site = None + # This branch handles the first call to pyro.sample by calling pyro_sample. + else: + try: + self._current_site = msg["name"] + msg["value"] = pyro_sample( + msg["name"], + msg["fn"], + obs=msg["value"] if msg["is_observed"] else None, + infer=msg["infer"].copy(), + ) + finally: + self._current_site = None - # flags to guarantee commutativity of condition, intervene, trace - msg["stop"] = True - msg["done"] = True - msg["mask"] = False - msg["is_observed"] = True - msg["infer"]["is_auxiliary"] = True - msg["infer"]["_do_not_trace"] = True + # flags to guarantee commutativity of condition, intervene, trace + msg["stop"] = True + msg["done"] = True + msg["mask"] = False + msg["is_observed"] = True + msg["infer"]["is_auxiliary"] = True + msg["infer"]["_do_not_trace"] = True def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: - infer = msg.get("infer") - if infer is None or "_index_naming" not in infer: + assert msg["value"] is not None + + # If this message has been handled already by a different pyro shim, ignore. 
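+        # For example (hypothetical usage), with two nested shims:
+        #
+        #     with PyroShim(), PyroShim():
+        #         pyro.sample("x", dist.Normal(0.0, 1.0))
+        #
+        # both shims observe this message in _pyro_post_sample, but only the
+        # shim whose id was stamped into the message during _pyro_sample
+        # should restore the named dimensions; any other shim returns early.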
+ if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): return - # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key - naming = infer["_index_naming"] # type: ignore + if getattr(self, "_current_site", None) == msg["name"]: + assert "_index_naming" in msg - value = msg["value"] + # note: Pyro uses a TypedDict for infer, so it doesn't know we've stored this key + naming = msg["_index_naming"] # type: ignore + + value = msg["value"] - if value is not None: # note: is it safe to assume that msg['fn'] is a distribution? dist_shape: tuple[int, ...] = msg["fn"].batch_shape + msg["fn"].event_shape # type: ignore if len(value.shape) < len(dist_shape): @@ -301,9 +334,10 @@ def positional_distribution( if dist_constr is dist.Independent: base_dist, reinterpreted_batch_ndims = args pos_base_dist, naming = positional_distribution(base_dist) - return dist.Independent( - pos_base_dist, reinterpreted_batch_ndims - ), naming + return ( + dist.Independent(pos_base_dist, reinterpreted_batch_ndims), + naming, + ) else: assert callable(dist_constr) base_dist = dist_constr(*args) From 600666fa0c9a3e6dd5faa5668a981a940f44367d Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:35:18 -0500 Subject: [PATCH 16/19] lint --- effectful/handlers/pyro.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index a82a07b1..599ed30a 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -275,7 +275,10 @@ def named_distribution( if dist_constr is dist.Independent: base_dist, reinterpreted_batch_ndims = args return dist.Independent( - named_distribution(base_dist, *names), reinterpreted_batch_ndims + named_distribution( + typing.cast(TorchDistribution, base_dist), *names + ), + reinterpreted_batch_ndims, ) else: named_args = [] @@ -300,10 +303,13 @@ def positional_distribution( case Term(op=_call, args=(dist_constr, *args)) if _call is call: if dist_constr is dist.Independent: base_dist, reinterpreted_batch_ndims = args - pos_base_dist, naming = positional_distribution(base_dist) - return dist.Independent( - pos_base_dist, reinterpreted_batch_ndims - ), naming + pos_base_dist, naming = positional_distribution( + typing.cast(TorchDistribution, base_dist) + ) + return ( + dist.Independent(pos_base_dist, reinterpreted_batch_ndims), + naming, + ) else: assert callable(dist_constr) base_dist = dist_constr(*args) From badfd4663192e414824b8230db1b4b2b0835466c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:38:13 -0500 Subject: [PATCH 17/19] lint --- effectful/handlers/pyro.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 861ac0f9..4066c787 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -1,4 +1,3 @@ -import functools import typing import warnings from typing import ( @@ -166,7 +165,7 @@ def _pyro_sample(self, msg: pyro.poutine.runtime.Message) -> None: return # We need to identify this pyro shim during post-sample. - msg["_pyro_shim_id"] = id(self) + msg["_pyro_shim_id"] = id(self) # type: ignore[typeddict-unknown-key] if "_markov_scope" in msg["infer"] and self._current_site: msg["infer"]["_markov_scope"].pop(self._current_site, None) @@ -246,7 +245,7 @@ def _pyro_post_sample(self, msg: pyro.poutine.runtime.Message) -> None: assert msg["value"] is not None # If this message has been handled already by a different pyro shim, ignore. 
- if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): + if "_pyro_shim_id" in msg and msg["_pyro_shim_id"] != id(self): # type: ignore[typeddict-item] return if getattr(self, "_current_site", None) == msg["name"]: From 4c5bb9f398b8a585c3d5c583bf0407ccb1f2339c Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 10:39:00 -0500 Subject: [PATCH 18/19] lint --- effectful/handlers/pyro.py | 1 - 1 file changed, 1 deletion(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index 599ed30a..f0317d6f 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -1,4 +1,3 @@ -import functools import typing import warnings from typing import ( From 6d5db50d90be2b12214c943d4b8f48e56234c905 Mon Sep 17 00:00:00 2001 From: Jack Feser Date: Fri, 31 Jan 2025 12:05:18 -0500 Subject: [PATCH 19/19] simplify named and positional_distribution --- effectful/handlers/pyro.py | 91 +++++++++++++++++++------------------- 1 file changed, 46 insertions(+), 45 deletions(-) diff --git a/effectful/handlers/pyro.py b/effectful/handlers/pyro.py index f0317d6f..5c2a5774 100644 --- a/effectful/handlers/pyro.py +++ b/effectful/handlers/pyro.py @@ -269,57 +269,58 @@ def named_distribution( d: Annotated[TorchDistribution, Scoped[A | B]], *names: Annotated[Operation[[], int], Scoped[B]], ) -> Annotated[TorchDistribution, Scoped[A | B]]: - match defterm(d): - case Term(op=_call, args=(dist_constr, *args)) if _call is call: - if dist_constr is dist.Independent: - base_dist, reinterpreted_batch_ndims = args - return dist.Independent( - named_distribution( - typing.cast(TorchDistribution, base_dist), *names - ), - reinterpreted_batch_ndims, - ) - else: - named_args = [] - for a in args: - assert isinstance(a, torch.Tensor) - named_args.append( - Indexable(typing.cast(torch.Tensor, a))[ - tuple(n() for n in names) - ] - ) - assert callable(dist_constr) - return dist_constr(*named_args) - case _: - raise NotImplementedError + d = defterm(d) + dist_constr, args = d.args[0], d.args[1:] + + if not ( + d.op is call + and ( + issubclass(dist_constr, TorchDistribution) + or issubclass(dist_constr, dist.torch_distribution.TorchDistributionMixin) + ) + ): + raise NotImplementedError + + def _to_named(a): + if isinstance(a, torch.Tensor): + return Indexable(typing.cast(torch.Tensor, a))[tuple(n() for n in names)] + elif isinstance(a, TorchDistribution): + return named_distribution(a, *names) + else: + return a + + return dist_constr(*[_to_named(a) for a in args], **d.kwargs) @defop def positional_distribution( d: Annotated[TorchDistribution, Scoped[A]], ) -> Tuple[TorchDistribution, Naming]: - match defterm(d): - case Term(op=_call, args=(dist_constr, *args)) if _call is call: - if dist_constr is dist.Independent: - base_dist, reinterpreted_batch_ndims = args - pos_base_dist, naming = positional_distribution( - typing.cast(TorchDistribution, base_dist) - ) - return ( - dist.Independent(pos_base_dist, reinterpreted_batch_ndims), - naming, - ) - else: - assert callable(dist_constr) - base_dist = dist_constr(*args) - indices = sizesof(d).keys() - n_base = len(base_dist.batch_shape) + len(base_dist.event_shape) - naming = Naming.from_shape(indices, n_base) - pos_args = [to_tensor(a, indices) for a in args] - pos_dist = dist_constr(*pos_args) - return pos_dist, naming - case _: - raise NotImplementedError + shape = d.shape() + d = defterm(d) + dist_constr, args = d.args[0], d.args[1:] + + if not ( + d.op is call + and ( + issubclass(dist_constr, TorchDistribution) + or 
issubclass(dist_constr, dist.torch_distribution.TorchDistributionMixin) + ) + ): + raise NotImplementedError + + indices = sizesof(d).keys() + naming = Naming.from_shape(indices, len(shape)) + + def _to_positional(a): + if isinstance(a, torch.Tensor): + return to_tensor(a, indices) + elif isinstance(a, TorchDistribution): + return positional_distribution(a)[0] + else: + return a + + return dist_constr(*[_to_positional(a) for a in args], **d.kwargs), naming class _DistributionTerm(Term[TorchDistribution], TorchDistribution):