Skip to content

Commit

Permalink
Fixed API compatibility on maxpool2d (with indices, empty args) functions (#70)
Browse files Browse the repository at this point in the history

* Solved maxpool2d (with indices, empty args)

* added unit test for all cases of inputting/omitting args

* python-black reformatted

* added comma for reformatting

* added an empty line for reformatting
  • Loading branch information
brucekimrokcmu authored Jul 26, 2023
1 parent 2896e74 commit 26a11a6
Show file tree
Hide file tree
Showing 8 changed files with 272 additions and 47 deletions.
169 changes: 167 additions & 2 deletions cpp_ext/TorchOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,165 @@ PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self,
return {opRef, mlirOperationGetResult(operation, 0)};
}

// Normalizes a pooling argument to a torch int-list value.
//
// If `arg` is already a PyAnyTorchListOfTorchIntValue it is returned
// unchanged; otherwise (e.g. a single PyTorch_IntValue) it is duplicated into
// a two-element list, mirroring PyTorch's `int -> (int, int)` convention for
// kernel_size/stride/padding/dilation.
//
// Fixes vs. previous version: take `arg` by const reference instead of by
// value (avoids copying list values), and drop the meaningless `const`
// qualifier on the by-value return type (it inhibits move construction).
template <typename T>
PyAnyTorchListOfTorchIntValue castTypeToListInt(const T &arg) {
  if constexpr (std::is_same_v<T, PyAnyTorchListOfTorchIntValue>) {
    return arg;
  } else {
    return PyAnyTorchListOfTorchIntValue(py::make_tuple(arg, arg));
  }
}

// aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) ->
// (Tensor, Tensor)
std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue>
max_pool2d_with_indices_(const PyAnyTorchTensorValue &self,
const PyAnyTorchListOfTorchIntValue &kernel_size,
const PyAnyTorchListOfTorchIntValue &stride,
const PyAnyTorchListOfTorchIntValue &padding,
const PyAnyTorchListOfTorchIntValue &dilation,
const PyTorch_BoolValue &ceil_mode, PyLocation *loc,
PyInsertionPoint *ip) {
std::string operationName = "torch.aten.max_pool2d_with_indices";
std::vector<PyType> _returnTypes = {
PyAnyTorchTensorType::getWithLeastStaticInformation(
loc->getContext().get()),
PyAnyTorchTensorType::getWithLeastStaticInformation(
loc->getContext().get())};
std::vector<std::reference_wrapper<const PyType>> returnTypes;
for (const auto &returnType : _returnTypes)
returnTypes.push_back(returnType);
PyOperationRef opRef =
createOperation(operationName, returnTypes,
{self, kernel_size, stride, padding, dilation, ceil_mode},
/*attributes=*/{}, loc, ip);
MlirOperation operation = opRef->get();
return std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue>(
{opRef, mlirOperationGetResult(operation, 0)},
{opRef, mlirOperationGetResult(operation, 1)});
}

// aten::max_pool2d_with_indices : (Tensor, Union[int[], int], Union[int[],
// int], Union[int[], int], Union[int[], int], bool) -> (Tensor, Tensor)
//
// Convenience wrapper: kernel_size/stride/padding/dilation may each be either
// an int list or a single int, which castTypeToListInt expands to `[v, v]`.
template <typename T1, typename T2, typename T3, typename T4>
std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue>
max_pool2d_with_indices(PyAnyTorchTensorValue &self, T1 &kernel_size,
                        T2 &stride, T3 &padding, T4 &dilation,
                        PyTorch_BoolValue &ceil_mode, PyLocation *loc,
                        PyInsertionPoint *ip) {

  PyAnyTorchListOfTorchIntValue kernel_size_ = castTypeToListInt(kernel_size);
  PyAnyTorchListOfTorchIntValue stride_ = castTypeToListInt(stride);
  PyAnyTorchListOfTorchIntValue padding_ = castTypeToListInt(padding);
  PyAnyTorchListOfTorchIntValue dilation_ = castTypeToListInt(dilation);
  // BUG FIX: the previous code unconditionally discarded the caller-supplied
  // `loc`/`ip` and always used the thread-local defaults. Honor them when
  // provided (pybind11 converts the Python-side `None` default to nullptr)
  // and fall back to the defaults only otherwise.
  PyLocation *loc_ = loc ? loc : &DefaultingPyLocation::resolve();
  PyInsertionPoint *ip_ = ip ? ip : &DefaultingPyInsertionPoint::resolve();

  return max_pool2d_with_indices_(self, kernel_size_, stride_, padding_,
                                  dilation_, ceil_mode, loc_, ip_);
}

// Binder callback for generateListIntCompatibleBindings: registers one Python
// overload of `max_pool2d_with_indices` per combination of int / int-list
// argument types. T1..T4 correspond to kernel_size, stride, padding, dilation.
struct bind_max_pool2d_with_indices {
template <typename T1, typename T2 = PyAnyTorchListOfTorchIntValue,
typename T3 = PyAnyTorchListOfTorchIntValue,
typename T4 = PyAnyTorchListOfTorchIntValue>
static void bind(py::module &m) {
m.def(
"max_pool2d_with_indices",
[](PyAnyTorchTensorValue &self, T1 &kernel_size, T2 &stride,
T3 &padding, T4 &dilation, PyTorch_BoolValue &ceil_mode,
PyLocation *loc, PyInsertionPoint *ip)
-> std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue> {
return max_pool2d_with_indices(self, kernel_size, stride, padding,
dilation, ceil_mode, loc, ip);
},
// Defaults follow torch.nn.functional.max_pool2d: padding 0, dilation 1.
// NOTE(review): the empty default stride presumably means "use
// kernel_size" (PyTorch convention) — confirm the op handles [].
"self"_a, "kernel_size"_a, "stride"_a = std::vector<int>{},
"padding"_a = std::vector<int>{0, 0},
"dilation"_a = std::vector<int>{1, 1}, "ceil_mode"_a = false,
py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());
}
};

// aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)
//
// Emits a `torch.aten.max_pool2d` op whose single tensor result carries the
// least static type information; downstream refinement can tighten it.
PyAnyTorchTensorValue
max_pool2d_(const PyAnyTorchTensorValue &self,
            const PyAnyTorchListOfTorchIntValue &kernel_size,
            const PyAnyTorchListOfTorchIntValue &stride,
            const PyAnyTorchListOfTorchIntValue &padding,
            const PyAnyTorchListOfTorchIntValue &dilation,
            const PyTorch_BoolValue &ceil_mode, PyLocation *loc,
            PyInsertionPoint *ip) {
  const std::string opName = "torch.aten.max_pool2d";
  std::vector<PyType> resultTypes{
      PyAnyTorchTensorType::getWithLeastStaticInformation(
          loc->getContext().get())};
  // createOperation takes reference wrappers; `resultTypes` owns the storage
  // and must stay alive for the duration of the call.
  std::vector<std::reference_wrapper<const PyType>> resultTypeRefs(
      resultTypes.begin(), resultTypes.end());
  PyOperationRef opRef =
      createOperation(opName, resultTypeRefs,
                      {self, kernel_size, stride, padding, dilation, ceil_mode},
                      /*attributes=*/{}, loc, ip);
  return {opRef, mlirOperationGetResult(opRef->get(), 0)};
}

// aten::max_pool2d : (Tensor, Union[int[], int], Union[int[], int],
// Union[int[], int], Union[int[], int], bool) -> (Tensor)
//
// Convenience wrapper: kernel_size/stride/padding/dilation may each be either
// an int list or a single int, which castTypeToListInt expands to `[v, v]`.
template <typename T1, typename T2, typename T3, typename T4>
PyAnyTorchTensorValue max_pool2d(PyAnyTorchTensorValue &self, T1 &kernel_size,
                                 T2 &stride, T3 &padding, T4 &dilation,
                                 PyTorch_BoolValue &ceil_mode, PyLocation *loc,
                                 PyInsertionPoint *ip) {

  PyAnyTorchListOfTorchIntValue kernel_size_ = castTypeToListInt(kernel_size);
  PyAnyTorchListOfTorchIntValue stride_ = castTypeToListInt(stride);
  PyAnyTorchListOfTorchIntValue padding_ = castTypeToListInt(padding);
  PyAnyTorchListOfTorchIntValue dilation_ = castTypeToListInt(dilation);
  // BUG FIX: the previous code unconditionally discarded the caller-supplied
  // `loc`/`ip` and always used the thread-local defaults. Honor them when
  // provided (pybind11 converts the Python-side `None` default to nullptr)
  // and fall back to the defaults only otherwise.
  PyLocation *loc_ = loc ? loc : &DefaultingPyLocation::resolve();
  PyInsertionPoint *ip_ = ip ? ip : &DefaultingPyInsertionPoint::resolve();

  return max_pool2d_(self, kernel_size_, stride_, padding_, dilation_,
                     ceil_mode, loc_, ip_);
}

// Binder callback for generateListIntCompatibleBindings: registers one Python
// overload of `max_pool2d` per combination of int / int-list argument types.
// T1..T4 correspond to kernel_size, stride, padding, dilation.
struct bind_max_pool2d {
template <typename T1, typename T2 = PyAnyTorchListOfTorchIntValue,
typename T3 = PyAnyTorchListOfTorchIntValue,
typename T4 = PyAnyTorchListOfTorchIntValue>
static void bind(py::module &m) {
m.def(
"max_pool2d",
[](PyAnyTorchTensorValue &self, T1 &kernel_size, T2 &stride,
T3 &padding, T4 &dilation, PyTorch_BoolValue &ceil_mode,
PyLocation *loc, PyInsertionPoint *ip) -> PyAnyTorchTensorValue {
return max_pool2d(self, kernel_size, stride, padding, dilation,
ceil_mode, loc, ip);
},
// Defaults follow torch.nn.functional.max_pool2d: padding 0, dilation 1.
// NOTE(review): the empty default stride presumably means "use
// kernel_size" (PyTorch convention) — confirm the op handles [].
"self"_a, "kernel_size"_a, "stride"_a = std::vector<int>{},
"padding"_a = std::vector<int>{0, 0},
"dilation"_a = std::vector<int>{1, 1}, "ceil_mode"_a = false,
py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());
}
};

// Recursive function to generate bindings for all combinations of
// Torch_IntValue and AnyTorchListOfTorchIntValue in N slots
//
// Each recursion level appends either PyTorch_IntValue or
// PyAnyTorchListOfTorchIntValue to the accumulated parameter pack `Args...`,
// so the base case (N == 0) is reached 2^N times — once per combination —
// and each visit registers one overload via `Callback`.
template <unsigned int N, class Callback, typename... Args>
struct generateListIntCompatibleBindings {
static void generate(py::module &m) {
generateListIntCompatibleBindings<N - 1, Callback, Args...,
PyTorch_IntValue>::generate(m);
generateListIntCompatibleBindings<
N - 1, Callback, Args..., PyAnyTorchListOfTorchIntValue>::generate(m);
}
};

// Base case: all N slots have been assigned a concrete type; hand the
// accumulated pack to the binder (e.g. bind_max_pool2d::bind<Args...>).
template <class Callback, typename... Args>
struct generateListIntCompatibleBindings<0, Callback, Args...> {
static void generate(py::module &m) { Callback::template bind<Args...>(m); }
};

void populateTorchMLIROps(py::module &m) {
py::register_exception_translator([](std::exception_ptr p) {
try {
Expand Down Expand Up @@ -184,6 +343,7 @@ void populateTorchMLIROps(py::module &m) {
},
"lhs"_a, "rhs"_a, py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

m.def("avg_pool1d",
[](PyAnyTorchTensorValue &self,
PyAnyTorchListOfTorchIntType &kernel_size,
Expand Down Expand Up @@ -229,7 +389,8 @@ void populateTorchMLIROps(py::module &m) {
"dtype"_a = py::none(), py::kw_only(), "loc"_a = py::none(),
"ip"_a = py::none());

// aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)
// aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) ->
// (Tensor)
m.def(
"vector_norm",
[](const PyAnyTorchTensorValue &self, const PyAnyTorchScalarValue &ord,
Expand Down Expand Up @@ -307,6 +468,10 @@ void populateTorchMLIROps(py::module &m) {
},
"self"_a, "beta"_a = 1, "threshold__"_a = 20, py::kw_only(),
"loc"_a = py::none(), "ip"_a = py::none());

generateListIntCompatibleBindings<4, bind_max_pool2d>::generate(m);
generateListIntCompatibleBindings<4, bind_max_pool2d_with_indices>::generate(
m);
}

} // namespace mlir::torch
}; // namespace mlir::torch
20 changes: 20 additions & 0 deletions cpp_ext/TorchOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,26 @@ PyAnyTorchTensorValue softplus(const PyAnyTorchTensorValue &self,
const PyAnyTorchScalarValue &threshold__,
PyLocation *loc, PyInsertionPoint *ip);

// aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) ->
// (Tensor, Tensor)
std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue>
max_pool2d_with_indices_(const PyAnyTorchTensorValue &self,
const PyAnyTorchListOfTorchIntValue &kernel_size,
const PyAnyTorchListOfTorchIntValue &stride,
const PyAnyTorchListOfTorchIntValue &padding,
const PyAnyTorchListOfTorchIntValue &dilation,
const PyTorch_BoolValue &ceil_mode, PyLocation *loc,
PyInsertionPoint *ip);

PyAnyTorchTensorValue
max_pool2d_(const PyAnyTorchTensorValue &self,
const PyAnyTorchListOfTorchIntValue &kernel_size,
const PyAnyTorchListOfTorchIntValue &stride,
const PyAnyTorchListOfTorchIntValue &padding,
const PyAnyTorchListOfTorchIntValue &dilation,
const PyTorch_BoolValue &ceil_mode, PyLocation *loc,
PyInsertionPoint *ip);

void populateTorchMLIROps(py::module &m);

} // namespace mlir::torch
Expand Down
32 changes: 0 additions & 32 deletions cpp_ext/TorchOps.impls.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3822,22 +3822,6 @@ PyAnyTorchTensorValue max(const PyAnyTorchTensorValue &self, PyLocation *loc, Py
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}
// aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)
// NOTE(review): this is the generated list-only implementation that this
// commit removes from TorchOps.impls.cpp, superseded by the
// Union[int[], int]-compatible template in TorchOps.cpp. Emits a
// `torch.aten.max_pool2d` op with a least-static-information result type.
PyAnyTorchTensorValue max_pool2d(const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.max_pool2d";
std::vector<PyType> _returnTypes = {PyAnyTorchTensorType::getWithLeastStaticInformation(loc->getContext().get())};
// createOperation takes reference wrappers backed by _returnTypes.
std::vector<std::reference_wrapper<const PyType>> returnTypes;
for (const auto& returnType : _returnTypes)
returnTypes.push_back(returnType);
PyOperationRef opRef = createOperation(operationName,
returnTypes,
{self, kernel_size, stride, padding, dilation, ceil_mode},
/*attributes=*/{},
loc,
ip);
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}
// aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)
PyAnyTorchTensorValue max_pool2d_with_indices_backward(const PyAnyTorchTensorValue &grad_output, const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, const PyAnyTorchTensorValue &indices, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.max_pool2d_with_indices_backward";
Expand All @@ -3854,22 +3838,6 @@ PyAnyTorchTensorValue max_pool2d_with_indices_backward(const PyAnyTorchTensorVal
MlirOperation operation = opRef->get();
return {opRef, mlirOperationGetResult(operation, 0)};
}
// aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)
// NOTE(review): this is the generated list-only implementation that this
// commit removes from TorchOps.impls.cpp, superseded by the
// Union[int[], int]-compatible template in TorchOps.cpp. Emits a
// `torch.aten.max_pool2d_with_indices` op returning (values, indices).
std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue> max_pool2d_with_indices(const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.max_pool2d_with_indices";
std::vector<PyType> _returnTypes = {PyAnyTorchTensorType::getWithLeastStaticInformation(loc->getContext().get()), PyAnyTorchTensorType::getWithLeastStaticInformation(loc->getContext().get())};
// createOperation takes reference wrappers backed by _returnTypes.
std::vector<std::reference_wrapper<const PyType>> returnTypes;
for (const auto& returnType : _returnTypes)
returnTypes.push_back(returnType);
PyOperationRef opRef = createOperation(operationName,
returnTypes,
{self, kernel_size, stride, padding, dilation, ceil_mode},
/*attributes=*/{},
loc,
ip);
MlirOperation operation = opRef->get();
// Result 0: pooled values; result 1: argmax indices.
return std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue>({opRef, mlirOperationGetResult(operation, 0)}, {opRef, mlirOperationGetResult(operation, 1)});
}
// aten::maximum : (Tensor, Tensor) -> (Tensor)
PyAnyTorchTensorValue maximum(const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other, PyLocation *loc, PyInsertionPoint *ip) {
std::string operationName = "torch.aten.maximum";
Expand Down
6 changes: 0 additions & 6 deletions cpp_ext/TorchOps.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -716,15 +716,9 @@ std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue> max(const PyAnyTorchTen
// aten::max : (Tensor) -> (Tensor)
PyAnyTorchTensorValue max(const PyAnyTorchTensorValue &self, PyLocation *loc, PyInsertionPoint *ip);

// aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)
PyAnyTorchTensorValue max_pool2d(const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, PyLocation *loc, PyInsertionPoint *ip);

// aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)
PyAnyTorchTensorValue max_pool2d_with_indices_backward(const PyAnyTorchTensorValue &grad_output, const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, const PyAnyTorchTensorValue &indices, PyLocation *loc, PyInsertionPoint *ip);

// aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)
std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue> max_pool2d_with_indices(const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, PyLocation *loc, PyInsertionPoint *ip);

// aten::maximum : (Tensor, Tensor) -> (Tensor)
PyAnyTorchTensorValue maximum(const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other, PyLocation *loc, PyInsertionPoint *ip);

Expand Down
6 changes: 0 additions & 6 deletions cpp_ext/TorchOps.pybinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -716,15 +716,9 @@ m.def("max", [](const PyAnyTorchTensorValue &self, const PyTorch_IntValue &dim,
// aten::max : (Tensor) -> (Tensor)
m.def("max", [](const PyAnyTorchTensorValue &self, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return max(self, loc.get(), ip.get()); }, "self"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::max_pool2d : (Tensor, int[], int[], int[], int[], bool) -> (Tensor)
m.def("max_pool2d", [](const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode, loc.get(), ip.get()); }, "self"_a, "kernel_size"_a, "stride"_a = std::vector<int>{}, "padding"_a = std::vector<int>{0, 0}, "dilation"_a = std::vector<int>{1, 1}, "ceil_mode"_a = false, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::max_pool2d_with_indices_backward : (Tensor, Tensor, int[], int[], int[], int[], bool, Tensor) -> (Tensor)
m.def("max_pool2d_with_indices_backward", [](const PyAnyTorchTensorValue &grad_output, const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, const PyAnyTorchTensorValue &indices, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, loc.get(), ip.get()); }, "grad_output"_a, "self"_a, "kernel_size"_a, "stride"_a, "padding"_a, "dilation"_a, "ceil_mode"_a, "indices"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::max_pool2d_with_indices : (Tensor, int[], int[], int[], int[], bool) -> (Tensor, Tensor)
m.def("max_pool2d_with_indices", [](const PyAnyTorchTensorValue &self, const PyAnyTorchListOfTorchIntValue &kernel_size, const PyAnyTorchListOfTorchIntValue &stride, const PyAnyTorchListOfTorchIntValue &padding, const PyAnyTorchListOfTorchIntValue &dilation, const PyTorch_BoolValue &ceil_mode, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> std::tuple<PyAnyTorchTensorValue, PyAnyTorchTensorValue> { return max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode, loc.get(), ip.get()); }, "self"_a, "kernel_size"_a, "stride"_a = std::vector<int>{}, "padding"_a = std::vector<int>{0, 0}, "dilation"_a = std::vector<int>{1, 1}, "ceil_mode"_a = false, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

// aten::maximum : (Tensor, Tensor) -> (Tensor)
m.def("maximum", [](const PyAnyTorchTensorValue &self, const PyAnyTorchTensorValue &other, DefaultingPyLocation &loc, const DefaultingPyInsertionPoint &ip) -> PyAnyTorchTensorValue { return maximum(self, other, loc.get(), ip.get()); }, "self"_a, "other"_a, py::kw_only(), "loc"_a = py::none(), "ip"_a = py::none());

Expand Down
3 changes: 3 additions & 0 deletions cpp_ext/TorchTensor.pybinds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,9 @@ c.def("_nested_tensor_strides", [](PyAnyTorchTensorValue& self, py::args args, p
// _nnz(self) -> _int
c.def("_nnz", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _nnz with signature _nnz(self) -> _int"); });

// _sparse_mask_projection(self, mask: Tensor) -> Tensor
c.def("_sparse_mask_projection", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _sparse_mask_projection with signature _sparse_mask_projection(self, mask: Tensor) -> Tensor"); });

// _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor
c.def("_to_dense", [](PyAnyTorchTensorValue& self, py::args args, py::kwargs kwargs) { throw NotImplementedError("NotImplementedError: _to_dense with signature _to_dense(self, dtype: Optional[_dtype]=None, masked_grad: Optional[_bool]=None) -> Tensor"); });

Expand Down
Loading

0 comments on commit 26a11a6

Please sign in to comment.