Skip to content

Commit

Permalink
Ignore named variables that are not traceable in `get_vars_in_point_l…
Browse files Browse the repository at this point in the history
…ist`
  • Loading branch information
ricardoV94 committed May 30, 2023
1 parent fbc62d5 commit b17a60d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ def get_vars_in_point_list(trace, model):
names_in_trace = list(trace[0])
else:
names_in_trace = trace.varnames
vars_in_trace = [model[v] for v in names_in_trace if v in model]
traceable_varnames = {var.name for var in (model.free_RVs + model.deterministics)}
vars_in_trace = [model[v] for v in names_in_trace if v in traceable_varnames]
return vars_in_trace


Expand Down
4 changes: 3 additions & 1 deletion tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1634,11 +1634,13 @@ def test_get_vars_in_point_list():
with pm.Model() as modelA:
pm.Normal("a", 0, 1)
pm.Normal("b", 0, 1)
pm.Normal("d", 0, 1)
with pm.Model() as modelB:
a = pm.Normal("a", 0, 1)
pm.Normal("c", 0, 1)
pm.ConstantData("d", 0)

point_list = [{"a": 0, "b": 0}]
point_list = [{"a": 0, "b": 0, "d": 0}]
vars_in_trace = get_vars_in_point_list(point_list, modelB)
assert set(vars_in_trace) == {a}

Expand Down

0 comments on commit b17a60d

Please sign in to comment.