Add function that goes from transformed space to untransformed space #6721

Open
ricardoV94 opened this issue May 18, 2023 · 3 comments

@ricardoV94
Member

ricardoV94 commented May 18, 2023

Description

Because we don't save transformed variables in the returned InferenceData (why not?), it's not easy to evaluate the model logp once we have a trace.

One could rewrite the model without transforms (and we can do this automatically for the user). This is possible with https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html

But someone might still want to evaluate it in the original model (with Jacobians and all that).

One dirty implementation is given here: https://discourse.pymc.io/t/logp-questions-synthetic-dataset-to-evaluate-modeling/12129/6?u=ricardov94
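
To make "with Jacobians and all that" concrete, here is a minimal NumPy sketch that is independent of PyMC. It assumes a positive parameter sigma ~ HalfNormal(1) that is sampled on the log scale, which is why its value variable gets a name like sigma_log__. If z = log(sigma), the log-density of z is the constrained logp evaluated at exp(z) plus the log-Jacobian log|d sigma / d z| = z. The function names are illustrative only.

```python
import numpy as np

def halfnormal_logp(sigma, scale=1.0):
    """Log-density of HalfNormal(scale) for sigma > 0 (constrained space)."""
    return 0.5 * np.log(2.0 / np.pi) - np.log(scale) - 0.5 * (sigma / scale) ** 2

def halfnormal_logp_unconstrained(z, scale=1.0):
    """Log-density of z = log(sigma): constrained logp plus the log-Jacobian.

    sigma = exp(z), so d sigma / d z = exp(z) and log|d sigma / d z| = z.
    """
    return halfnormal_logp(np.exp(z), scale) + z

# Riemann sums over a wide grid of the unconstrained variable z.
z_grid = np.linspace(-20.0, 5.0, 200_001)
dz = z_grid[1] - z_grid[0]

# With the Jacobian term, the density of z integrates to ~1, as it must.
mass_with_jacobian = np.sum(np.exp(halfnormal_logp_unconstrained(z_grid))) * dz
# Dropping the term (just plugging exp(z) into the constrained logp) does not
# give a normalized density.
mass_without_jacobian = np.sum(np.exp(halfnormal_logp(np.exp(z_grid)))) * dz

print(mass_with_jacobian, mass_without_jacobian)
```

This is the term that separates evaluating the original (transformed) model from evaluating a model rewritten without transforms.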

@ricardoV94
Member Author

Results should be saved in the unconstrained_posterior group of the ArviZ InferenceData schema: https://python.arviz.org/en/latest/schema/schema.html#unconstrained-posterior

We should make sure there's an option in pm.sample to store those, in addition to letting users populate them afterwards with a helper, as originally suggested in this issue.

@pipme
Contributor

pipme commented Feb 26, 2025

Hi, any updates on this? Or do you have any suggestions for vectorizing the transformation of parameters between the constrained and unconstrained space?

I am currently using an inefficient for-loop, which also feels a bit hacky:

  • constrained to unconstrained space:
    model: pm.Model
    transformed_rvs = []
    for free_rv in model.free_RVs:
        transform = model.rvs_to_transforms.get(free_rv)
        if transform is None:
            transformed_rvs.append(free_rv)
        else:
            transformed_rv = transform.forward(free_rv, *free_rv.owner.inputs)
            transformed_rvs.append(transformed_rv)

    fn = model.compile_fn(inputs=model.free_RVs, outs=transformed_rvs)
    # N_samples parameter values to transform
    for i in range(N_samples):
        # the value_dict is e.g., {"sigma": 0.1, "a": [0.1, 0.2]}
        value_unconstrained_list = fn(value_dict)
  • unconstrained to constrained space:
    outputs = model.unobserved_value_vars
    fn_inv = model.compile_fn(outs=outputs)
    for i in range(N_samples):
        # value_dict = {"sigma_log__": np.log(0.1), "a": [0.1, 0.2]}
        value_constrained_list = fn_inv(value_dict)

Thanks!
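
The per-sample loops above can be avoided when every transform broadcasts over a leading sample axis, as elementwise transforms such as log do. In that case, you can stack the N_samples draws of each variable into one array and transform each variable once. The following is a minimal NumPy sketch that is independent of PyMC. The FORWARD/BACKWARD table and the helper names are hypothetical, and in a real model the table would presumably be derived from model.rvs_to_transforms.

```python
import numpy as np

# Hypothetical transform table for the example above:
# "sigma" is log-transformed (value variable "sigma_log__"), "a" has no transform.
FORWARD = {"sigma": ("sigma_log__", np.log)}
BACKWARD = {"sigma_log__": ("sigma", np.exp)}

def _identity(x):
    return x

def _apply(table, draws):
    """Apply each variable's transform to all of its draws in a single call."""
    out = {}
    for name, values in draws.items():
        new_name, fn = table.get(name, (name, _identity))
        out[new_name] = fn(np.asarray(values))
    return out

def to_unconstrained(draws):
    return _apply(FORWARD, draws)

def to_constrained(draws):
    return _apply(BACKWARD, draws)

rng = np.random.default_rng(0)
N_samples = 1000
constrained = {
    "sigma": np.abs(rng.normal(size=N_samples)),  # shape (N_samples,)
    "a": rng.normal(size=(N_samples, 2)),         # shape (N_samples, 2)
}
unconstrained = to_unconstrained(constrained)  # keys: "sigma_log__", "a"
roundtrip = to_constrained(unconstrained)      # keys: "sigma", "a"
```

The Python loop runs over variables rather than draws, so all per-draw work happens inside vectorized NumPy calls.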

@pipme
Contributor

pipme commented Feb 26, 2025

from pymc.sampling.jax import _postprocess_samples, get_jaxified_graph
from pymc.util import get_default_varnames

filtered_var_names = model.unobserved_value_vars
vars_to_sample = list(
    get_default_varnames(filtered_var_names, include_transformed=False)
)

jax_fn_inv = get_jaxified_graph(
    inputs=model.value_vars, outputs=vars_to_sample
)
_postprocess_samples(
    jax_fn_inv, params_unconstrained
)

The above seems to work for transforming from the unconstrained to the constrained space; inside _postprocess_samples, jax.vmap is used for vectorization. But the following doesn't work for going from the constrained to the unconstrained space:

jax_fn = get_jaxified_graph(
    inputs=vars_to_sample, outputs=model.value_vars
)

Is it possible to have such a jaxified function?
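
The second get_jaxified_graph call likely fails because the value variables (e.g. sigma_log__) are the root inputs of the model graph. They are not computed from the constrained variables, so there is no graph to compile in that direction. The forward transform has to be written out explicitly, as the first loop in the earlier comment does. Below is a minimal NumPy sketch of that idea for a variable bounded in (lower, upper), applied to draws with (chain, draw) leading dimensions. It makes several assumptions: the log-odds formulas are written by hand rather than taken from PyMC, and the bounds and draws are made up.

```python
import numpy as np

def interval_forward(x, lower, upper):
    """Constrained (lower, upper) -> unconstrained real line (log-odds of position)."""
    return np.log(x - lower) - np.log(upper - x)

def interval_backward(y, lower, upper):
    """Unconstrained real line -> constrained (lower, upper); inverse of interval_forward."""
    sigmoid = 0.5 * (1.0 + np.tanh(0.5 * y))  # numerically stable logistic function
    return lower + (upper - lower) * sigmoid

# Made-up posterior draws with (chain, draw) leading dims for a parameter in (2, 5).
rng = np.random.default_rng(1)
lower, upper = 2.0, 5.0
n_chains, n_draws = 4, 500
x = rng.uniform(lower, upper, size=(n_chains, n_draws))

# One call per variable: the transform broadcasts over both leading dims,
# so no Python loop over draws is needed (the role jax.vmap plays in JAX).
y = interval_forward(x, lower, upper)
x_back = interval_backward(y, lower, upper)
```

The same pattern applies to any transform whose forward map is available in closed form. Only the per-variable formula changes.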
