Implement sizesof for named and positional distributions #214

Merged
eb8680 merged 19 commits into master from jf-sizesof-named-positional
Jan 31, 2025

jfeser
Contributor

@jfeser jfeser commented Jan 27, 2025

Closes #213

@jfeser jfeser requested a review from eb8680 January 27, 2025 21:29
@jfeser jfeser added the "status:awaiting review" and "bug" (Something isn't working) labels Jan 27, 2025
@eb8680
Contributor

eb8680 commented Jan 27, 2025

I think it would be better, and not that much more involved, to address the root cause (#203) instead of special-casing handlers.torch.sizesof.

@jfeser
Contributor Author

jfeser commented Jan 27, 2025

OK, I'll rework this into something that addresses #203.

@jfeser
Contributor Author

jfeser commented Jan 28, 2025

Handling NamedDistribution is tricky, because the named dimensions don't appear as arguments to torch_getitem. We would need to switch from wrapping the distribution to reconstructing it with named arguments. For most (all?) distributions, naming the batch dimensions of the parameters is the same as naming the batch dimensions of the sample results, so this approach should still work. It's certainly more involved than the original version of this PR, because we need to introspect the distributions and reconstruct them, which might need per-distribution logic.
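
To illustrate the claim that naming the batch dimensions of the parameters amounts to naming the batch dimensions of the samples, here is a toy sketch in plain Python. It does not use effectful, torch, or Pyro: NamedBatch, ToyNormal, and toy_sizes are hypothetical stand-ins invented for this example, and real distributions may need per-distribution handling as noted above.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class NamedBatch:
    """Hypothetical stand-in for a tensor with one named batch dimension."""
    name: str
    values: tuple


@dataclass(frozen=True)
class ToyNormal:
    """Hypothetical stand-in for a distribution built from named parameters."""
    loc: NamedBatch
    scale: NamedBatch

    def sample(self, rng):
        # Sampling is elementwise over the batch dimension, so the draw
        # carries the same dimension name as the parameters.
        draws = tuple(
            rng.gauss(m, s) for m, s in zip(self.loc.values, self.scale.values)
        )
        return NamedBatch(self.loc.name, draws)


def toy_sizes(x):
    """Hypothetical analogue of sizesof: map named dimensions to sizes."""
    if isinstance(x, NamedBatch):
        return {x.name: len(x.values)}
    # For a distribution, collect the named sizes from its parameters.
    sizes = {}
    for param in (x.loc, x.scale):
        sizes.update(toy_sizes(param))
    return sizes


dist = ToyNormal(
    NamedBatch("i", (0.0, 1.0, 2.0)),
    NamedBatch("i", (1.0, 1.0, 1.0)),
)
sample = dist.sample(random.Random(0))

# The sizes read off the parameters match the sizes of a sample.
assert toy_sizes(dist) == toy_sizes(sample) == {"i": 3}
```

In this toy setting the named sizes can be recovered from the parameters alone, which is the property the reconstruction approach relies on.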

Contributor

@eb8680 eb8680 left a comment

This mostly makes sense; I just have a few comments.

effectful/handlers/torch.py Outdated

super().__init__()
@defop
def positional_distribution(
Contributor

Hmm, why does positional_distribution need to be a separate operation, as opposed to just using to_tensor on defterm(d)?

Contributor Author

We could move this logic into to_tensor, sure. The default behavior of to_tensor doesn't work here, though, because vmap can only return tensors.

Contributor Author

As discussed, I'd like to leave this alone for now.

effectful/ops/syntax.py Outdated
effectful/handlers/pyro.py Outdated
effectful/handlers/pyro.py Outdated
effectful/handlers/pyro.py Outdated
@jfeser jfeser mentioned this pull request Jan 31, 2025
Contributor

@eb8680 eb8680 left a comment

Looks good, thanks for taking a pass at this! I think there are still a couple of easy simplifications of positional_distribution and named_distribution as noted in comments, but you can address them in a followup PR if this is blocking other work.

effectful/handlers/pyro.py Outdated
effectful/handlers/pyro.py Outdated
@eb8680
Contributor

eb8680 commented Jan 31, 2025

CI failure is just a randomized distribution test.

):
self.base_dist = base_dist
self.indices = sizesof(base_dist)
class _DistributionTerm(Term[TorchDistribution], TorchDistribution):
Contributor

This should be registered with and accessed through defdata, but we can also address that in a followup PR.

@eb8680 eb8680 merged commit c2a53ba into master Jan 31, 2025
3 checks passed
@eb8680 eb8680 deleted the jf-sizesof-named-positional branch January 31, 2025 19:18
Labels
bug Something isn't working status:awaiting review
Development

Successfully merging this pull request may close these issues.

sizesof doesn't work on NamedDistributions
2 participants