Skip to content

Commit

Permalink
Make do interventions shared variables by default
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 27, 2025
1 parent 62335ac commit 63cf0e8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 4 deletions.
17 changes: 13 additions & 4 deletions pymc/model/transform/conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
from collections.abc import Mapping, Sequence
from typing import Any, Union

from pytensor.graph import ancestors
import pytensor

from pytensor.graph import Constant, ancestors
from pytensor.tensor import TensorVariable

from pymc.logprob.transforms import Transform
Expand Down Expand Up @@ -126,7 +128,9 @@ def observe(
def do(
model: Model,
vars_to_interventions: Mapping[Union["str", TensorVariable], Any],
prune_vars=False,
*,
make_interventions_shared: bool = True,
prune_vars: bool = False,
) -> Model:
"""Replace model variables by intervention variables.
Expand All @@ -140,6 +144,8 @@ def do(
Dictionary that maps model variables (or names) to intervention expressions.
Intervention expressions must have a shape and data type that is compatible
with the original model variable.
make_interventions_shared: bool, defaults to True,
Whether to make constant interventions shared variables.
prune_vars: bool, defaults to False
Whether to prune model variables that are not connected to any observed variables,
after the interventions.
Expand Down Expand Up @@ -170,11 +176,14 @@ def do(
"""
do_mapping = {}
for var, obs in vars_to_interventions.items():
for var, intervention in vars_to_interventions.items():
if isinstance(var, str):
var = model[var]
try:
do_mapping[var] = var.type.filter_variable(obs)
intervention = var.type.filter_variable(intervention)
if make_interventions_shared and isinstance(intervention, Constant):
intervention = pytensor.shared(intervention.data, name=var.name)
do_mapping[var] = intervention
except TypeError as err:
raise TypeError(
"Incompatible replacement type. Make sure the shape and datatype of the interventions match the original variables"
Expand Down
43 changes: 43 additions & 0 deletions tests/model/transform/test_conditioning.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,12 @@
import pytest

from pytensor import config
from pytensor.compile import SharedVariable
from pytensor.graph import Constant

import pymc as pm

from pymc import sample_posterior_predictive, set_data
from pymc.distributions.transforms import logodds
from pymc.model.transform.conditioning import (
change_value_transforms,
Expand Down Expand Up @@ -253,6 +256,46 @@ def test_do_self_reference():
np.testing.assert_allclose(draw_x + 100, draw_do_x)


def test_do_make_intervenstions_shared():
with pm.Model(coords={"obs": [1]}) as m:
x = pm.Normal("x", dims="obs")
y = pm.Normal("y", dims="obs")

constant_m = do(m, {x: [0.5]}, make_interventions_shared=False)
constant_x = constant_m["x"]
assert isinstance(constant_x, Constant)
np.testing.assert_array_equal(constant_x.data, [0.5])

shared_m = do(m, {x: [0.5]}, make_interventions_shared=True)
shared_x = shared_m["x"]
assert isinstance(shared_x, SharedVariable)
np.testing.assert_array_equal(shared_x.get_value(borrow=True), [0.5])

with shared_m:
set_data({"x": [0.6, 0.9]}, coords={"obs": [2, 3]})
pp_y = pm.sample_prior_predictive(draws=3).prior["y"]
assert pp_y.sizes == {"chain": 1, "draw": 3, "obs": 2}
assert pp_y.shape == (1, 3, 2)


@pytest.mark.parametrize(
"make_interventions_shared",
[True, pytest.param(False, marks=pytest.mark.xfail(reason="#6876"))],
)
def test_do_sample_posterior_predictive(make_interventions_shared):
# Regression test for https://github.com/pymc-devs/pymc/issues/6977
with pm.Model() as model:
a = pm.Normal("a")
b = pm.Deterministic("b", a * 2)
c = pm.Normal("c", b / 2)

idata = az.from_dict({"a": [[1.0]], "b": [[2.0]], "c": [[1.0]]})

with do(model, {a: 1000}, make_interventions_shared=make_interventions_shared):
pp = sample_posterior_predictive(idata, var_names=["c"], predictions=True).predictions
assert (pp["c"] > 500).all()


def test_change_value_transforms():
with pm.Model() as base_m:
p = pm.Uniform("p", 0, 1, default_transform=None)
Expand Down

0 comments on commit 63cf0e8

Please sign in to comment.