Make `to_tensor` an operation #247

Conversation
```python
    t: Annotated[torch.Tensor, Scoped[A | B]],
    *args: Annotated[Operation[[], torch.Tensor], Scoped[A]],
) -> Annotated[torch.Tensor, Scoped[B]]:
    def _evaluate(expr):
```
Why is there a duplicate of `evaluate` here? This seems to suggest we should be changing the interface of `_partial_eval` to be an `apply` rule. Also, should `_partial_eval` just be absorbed into the body of `to_tensor_`? When would it ever be correct to use `_partial_eval` on its own?
I had a bad time trying to overload `apply` without breaking something else, so I gave up. `_partial_eval` has other callers, so I left it as its own function.
I only see one other caller (`_register_torch_op`). Should that be made to use `to_tensor_` instead, or should `_partial_eval` be absorbed into its body? Right now the logic seems dispersed across a few different locations in a way that makes it hard to unwind.
I looked at this a bit over the weekend, will put up a separate PR.
```python
dim_ops = [a.op if isinstance(a, Term) else None for a in dims]
perm = [dim_ops.index(o) for o in args] + reindex_dims
tensor = tensor.permute(perm)
return tensor[(slice(None),) * len(args) + tuple(dims[i] for i in reindex_dims)]
```
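The permute-and-index pattern in this block can be illustrated in isolation. The sketch below is hypothetical (plain strings stand in for the `Operation` dim terms, and the index values are made up), but it shows the same mechanics: requested dims are moved to the front via `permute`, then the remaining axes are indexed positionally while the leading axes are kept whole with `slice(None)`.

```python
import torch

tensor = torch.arange(24).reshape(2, 3, 4)
dims = ["a", "b", "c"]   # stand-ins for the dim operations attached to each axis
args = ["b"]             # dims the caller wants kept as leading axes
reindex_dims = [i for i, d in enumerate(dims) if d not in args]

# Move the requested dims to the front; the rest follow in original order.
perm = [dims.index(d) for d in args] + reindex_dims
permuted = tensor.permute(perm)  # shape (3, 2, 4)

# Keep the leading `args` axes whole; index the remaining axes positionally
# (here with dummy index 0 for each).
result = permuted[(slice(None),) * len(args) + (0, 0)]
print(result.shape)  # torch.Size([3])
```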
I'm not seeing the equivalent code in the previous version for this block at the end of `to_tensor_`. What is this logic doing that is not accomplished by `_partial_eval`'s `reindex_flat_tensor`? Why is the `permute` call above necessary? Is this meant to fix a bug?
The distinction between the two functions is that `_partial_eval` is only concerned with the partial evaluation rule, while `to_tensor` applies partial evaluation and then reorders dimensions. In the previous implementation, `_partial_eval` did both things, and the permute was implicit in `ordered_sized_fvs`.
Looks good - seems like `handlers.torch` is converging on the idealized design mentioned in #203. Will merge as soon as conflicts from #254 are addressed.
Allows partial evaluation to be part of a term.