Design help: 6D SplitSplineEstimator using AutoDiffCost #342

Answered by luisenp
urbste asked this question in Q&A

@urbste Following our conversation, I thought a bit more about this and came up with something like the mock code below. Instead of passing a function to err_fn, I pass a callable object that stores the indices internally, so they don't have to be passed as aux vars. The code below works for me.

BTW, I suspect the part that vmap doesn't like is the tensor reshaping, which I think you should also be able to avoid by using torch.stack with the appropriate dimension. Let me know if the example below makes sense.

```python
class MockReprErr:
    def __init__(self, so3_knot_idx, r3_knot_idx):
        self.so3_knot_idx = so3_knot_idx
        self.r3_knot_idx = r3_knot_idx

    def __call__(self, o…
```
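For anyone who wants something self-contained to experiment with, here is a minimal sketch of the same two ideas in plain PyTorch. It is not the mock code above (which is truncated here): the class IndexedResidual, its arguments, and the residual it computes are hypothetical, and it assumes PyTorch >= 2.0 for torch.func.vmap. The indices live on the object instead of being passed in, and the output vector is assembled with torch.stack rather than by filling and reshaping a buffer.

```python
import torch
from torch.func import vmap  # assumes PyTorch >= 2.0 (older versions: functorch.vmap)


class IndexedResidual:
    """Hypothetical error object: the knot indices are stored on the
    instance, so they never need to be passed as extra inputs."""

    def __init__(self, knot_idx_a: int, knot_idx_b: int):
        self.knot_idx_a = knot_idx_a
        self.knot_idx_b = knot_idx_b

    def __call__(self, knots: torch.Tensor) -> torch.Tensor:
        # knots: (num_knots, 3) for a single sample; vmap handles the batch dim.
        diff = knots[self.knot_idx_b] - knots[self.knot_idx_a]  # shape (3,)
        # Assemble the residual with torch.stack along the last dim
        # instead of writing into a buffer and reshaping it.
        return torch.stack(
            [diff[0], diff[1], diff[2], torch.linalg.vector_norm(diff)], dim=-1
        )


err_fn = IndexedResidual(knot_idx_a=0, knot_idx_b=2)

# Toy batch: 2 samples, 4 knots each, 3 coordinates per knot.
batch = torch.arange(2 * 4 * 3, dtype=torch.float32).reshape(2, 4, 3)
residuals = vmap(err_fn)(batch)
print(residuals.shape)  # torch.Size([2, 4])
```

Since vmap accepts any callable, the object can be passed to it directly. Comparing the batched result against a plain Python loop over the samples is a quick sanity check that the function batches the way you expect.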

Replies: 7 comments, 21 replies

luisenp Nov 11, 2022
Collaborator

luisenp Nov 11, 2022
Collaborator

Answer selected by urbste
Category: Q&A. Labels: none. 3 participants.
This discussion was converted from issue #339 on November 01, 2022 20:21.