Add function that goes from transformed space to untransformed space #6721
Comments
Results should be saved in the `unconstrained_posterior` group (https://python.arviz.org/en/latest/schema/schema.html#unconstrained-posterior). We should make sure there's an option from
Hi, any updates on this? Or do you have any suggestions for vectorizing the transformation of parameters between the constrained and unconstrained space? I am currently doing an inefficient for-loop, which also feels a bit hacky:
```python
import pymc as pm

model: pm.Model  # an existing PyMC model, defined elsewhere

# Constrained -> unconstrained: apply each free RV's forward transform
transformed_rvs = []
for free_rv in model.free_RVs:
    transform = model.rvs_to_transforms.get(free_rv)
    if transform is None:
        # No transform: constrained and unconstrained values coincide
        transformed_rvs.append(free_rv)
    else:
        transformed_rv = transform.forward(free_rv, *free_rv.owner.inputs)
        transformed_rvs.append(transformed_rv)
fn = model.compile_fn(inputs=model.free_RVs, outs=transformed_rvs)

# N parameter values to transform
for i in range(N_samples):
    # the value_dict is e.g., {"sigma": 0.1, "a": [0.1, 0.2]}
    value_unconstrained_list = fn(value_dict)

# Unconstrained -> constrained: evaluate the untransformed variables
# as functions of the model's value variables
outputs = model.unobserved_value_vars
fn_inv = model.compile_fn(outs=outputs)
for i in range(N_samples):
    # value_dict = {"sigma_log__": np.log(0.1), "a": [0.1, 0.2]}
    value_constrained_list = fn_inv(value_dict)
```

Thanks!
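The comment above asks how to vectorize this. A minimal sketch follows for the case where every transform is a simple elementwise map. In that case the per-sample loop can be replaced by array operations over all N draws at once. The sketch rests on several hypothetical assumptions:

- `sigma` uses a log transform, matching the `sigma_log__` name above.
- A variable `p` uses an interval transform on (0, 1) with value name `p_interval__`.
- `a` is untransformed.

The functions `to_unconstrained` and `to_constrained` are illustrative helpers, not PyMC API.

```python
import numpy as np

# Assumed (hypothetical) setup: "sigma" has a log transform, "p" an interval
# transform on (LOWER, UPPER), and "a" has no transform.
LOWER, UPPER = 0.0, 1.0


def to_unconstrained(samples):
    """Map constrained draws (leading axis = sample index) to value space."""
    return {
        "sigma_log__": np.log(samples["sigma"]),
        # interval forward: log(x - lower) - log(upper - x)
        "p_interval__": np.log(samples["p"] - LOWER) - np.log(UPPER - samples["p"]),
        "a": samples["a"],
    }


def to_constrained(values):
    """Inverse map: value-space draws back to the constrained space."""
    return {
        "sigma": np.exp(values["sigma_log__"]),
        # interval backward: sigmoid(y) rescaled to (lower, upper)
        "p": LOWER + (UPPER - LOWER) / (1.0 + np.exp(-values["p_interval__"])),
        "a": values["a"],
    }


rng = np.random.default_rng(0)
N = 1000
draws = {
    "sigma": rng.gamma(2.0, 0.5, size=N),
    "p": rng.uniform(0.01, 0.99, size=N),
    "a": rng.normal(size=(N, 2)),
}
unconstrained = to_unconstrained(draws)  # one vectorized call, no Python loop
roundtrip = to_constrained(unconstrained)
```

Because each transform is elementwise, the leading sample axis broadcasts through unchanged. The round trip recovers the original draws up to floating-point error.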
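```python
from pymc.sampling.jax import _postprocess_samples, get_jaxified_graph
from pymc.util import get_default_varnames

# Unconstrained (value space) -> constrained: build a JAX function of the
# model's value variables that returns the untransformed variables
filtered_var_names = model.unobserved_value_vars
vars_to_sample = list(
    get_default_varnames(filtered_var_names, include_transformed=False)
)
jax_fn_inv = get_jaxified_graph(
    inputs=model.value_vars, outputs=vars_to_sample
)
# params_unconstrained: draws of the value variables (e.g. from a JAX sampler)
params_constrained = _postprocess_samples(
    jax_fn_inv, params_unconstrained
)
```

The above seems to work for transforming from the unconstrained to the constrained space. For the reverse direction, I would like something like:

```python
jax_fn = get_jaxified_graph(
    inputs=vars_to_sample, outputs=model.value_vars
)
```

Is it possible to have such a jaxified function?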
Description
Because we don't save transformed variables in the returned InferenceData (why not?), it's not easy to evaluate the model logp once we have a trace.
One could rewrite the model without transforms (and we can do this automatically for the user). This is possible with `remove_value_transforms` (https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html). But someone might still want to evaluate it in the original model (with Jacobians and all that).
One dirty implementation is given here: https://discourse.pymc.io/t/logp-questions-synthetic-dataset-to-evaluate-modeling/12129/6?u=ricardov94
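The Jacobian term mentioned above can be illustrated with a minimal sketch. It assumes a `HalfNormal` prior on `sigma` with a log transform, which is the `sigma_log__` case from the comments. If y = log(sigma), then the log density of y equals the constrained log density at exp(y) plus log|d sigma / d y| = y. The helper names are hypothetical. They are written in plain NumPy rather than using PyMC's logp machinery.

```python
import numpy as np


def halfnormal_logp(x, scale=1.0):
    """Log density of HalfNormal(scale) on the constrained space x > 0."""
    return 0.5 * np.log(2.0 / np.pi) - np.log(scale) - 0.5 * (x / scale) ** 2


def halfnormal_logp_log_space(y, scale=1.0):
    """Log density of y = log(x): constrained logp plus the log-Jacobian.

    x = exp(y), so |dx/dy| = exp(y) and log|dx/dy| = y.
    """
    return halfnormal_logp(np.exp(y), scale) + y


# A trace stores constrained draws of sigma; to evaluate the logp the way the
# transformed model does, map them to value space first and add the Jacobian.
sigma_draws = np.array([0.5, 1.0, 2.0])
y = np.log(sigma_draws)
lp_value_space = halfnormal_logp_log_space(y)
```

The transformed density still integrates to one over the real line. This is the property the Jacobian correction preserves, and the constrained logp alone would not have it.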