-
Notifications
You must be signed in to change notification settings - Fork 2
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement sizesof for named and positional distributions #214
Conversation
I think it would be better and not that much more involved to address the root cause #203 instead of special-casing |
Ok, I'll rework this into something that addresses #203 |
Handling |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This mostly makes sense, I just have a few comments
|
||
super().__init__() | ||
@defop | ||
def positional_distribution( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hmm, why does positional_distribution
needs to be a separate operation, as opposed to just using to_tensor
on defterm(d)
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We could move this logic into to_tensor
, sure. The default behavior of to_tensor
doesn't work though, because vmap is restricted to only return tensors.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As discussed, I'd like to leave this alone for now.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks for taking a pass at this! I think there are still a couple of easy simplifications of positional_distribution
and named_distribution
as noted in comments, but you can address them in a followup PR if this is blocking other work.
CI failure is just a randomized distribution test. |
): | ||
self.base_dist = base_dist | ||
self.indices = sizesof(base_dist) | ||
class _DistributionTerm(Term[TorchDistribution], TorchDistribution): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should be registered with and accessed through defdata
, but we can also address that in a followup PR.
Closes #213