Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed API compatibility on maxpool2d (with indices, empty args) functions #70

Merged
merged 5 commits into from
Jul 26, 2023

Conversation

brucekimrokcmu
Copy link
Contributor

The purpose of this PR is to address several test cases for max_pool2d and max_pool2d_with_indices that require accepting different combinations of a list versus an int for parameters such as kernel_size, stride, pooling, and dilation.

1 : MaxPool2dWithIndicesFullSizeKernelModule_basic
max_pool2d_with_indices(): incompatible function arguments. The following argument types are supported:
    1. (self: pi.mlir._mlir_libs._pi_mlir.Tensor, kernel_size: pi.mlir._mlir_libs._pi_mlir.AnyTorchListOfTorchIntValue, stride: pi.mlir._mlir_libs._pi_mlir.AnyTorchListOfTorchIntValue = [], padding: pi.mlir._mlir_libs._pi_mlir.AnyTorchListOfTorchIntValue = [0, 0], dilation: pi.mlir._mlir_libs._pi_mlir.AnyTorchListOfTorchIntValue = [1, 1], ceil_mode: pi.mlir._mlir_libs._pi_mlir.Torch_BoolValue = False, *, loc: mlir.ir.Location = None, ip: mlir.ir.InsertionPoint = None) -> Tuple[pi.mlir._mlir_libs._pi_mlir.Tensor, pi.mlir._mlir_libs._pi_mlir.Tensor]

Invoked with: Tensor(<block argument> of type '!torch.tensor' at index: 0); kwargs: kernel_size=[Torch_IntValue(%0 = "torch.constant.int"() {value = 4 : i64} : () -> !torch.int), Torch_IntValue(%1 = "torch.constant.int"() {value = 4 : i64} : () -> !torch.int)], stride=1, padding=0, dilation=1 

Since there are a lot of combinations (16 for 4 parameters) we generate the appropriate bindings via recursive template instantiation. Essentially, a recursive template function generates 16 different combinations ofPyAnyTorchListOfTorchIntValue and PyTorch_IntValue for four arguments (kernel_size, stride, padding, dilation) to be accepted for max_pool2d_with_indices_ Ops.

In addition, by setting default argument type and values for the arguments to the Op, we allowed max_pool2d and max_pool2d_with_indices to omit some of the kwargs in use.

Note that to avoid segfault, max_pool2d_with_indices() redefines default loc and ip before all casted arguments are passed into max_pool2d_with_indices_() ops function.

This passes an additional 6 tests.

@123epsilon
Copy link
Contributor

Note: the default template types are in place to allow for default argument support when certain args are omitted

Also, the generate bindings struct implemented here can be reused for other function signatures with the same behavior (i.e. either supporting list or int for a subset of arguments)

Copy link
Contributor

@123epsilon 123epsilon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, only thing is that we should probably add some unit tests for this - particularly in order to test that default arguments work.

@brucekimrokcmu brucekimrokcmu merged commit 26a11a6 into main Jul 26, 2023
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants