diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 8b264d4e5..ac9916888 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -47,6 +47,21 @@ namespace topi { using namespace tvm::te; using namespace topi::detail; +/*! + * \brief Creates an operation to form windows over the input x. + * + * \param x The input tensor. + * \param axis What axis the windows begin forming over. Windows will be formed + * over this axis and all following axes. The axis value determines the window + * shape (and thus, the number of strides): window shape and strides must both + * be of length `data.ndim-axis`. + * \param window_shape The window shape to form over the input. Window shape + * must be of length `data.ndim-axis`. + * \param strides How to stride the window along each dimension. Strides must be + * of length `data.ndim-axis`. + * + * \return A Tensor whose op member is the dim expansion operation + */ inline Tensor windows(const Tensor& x, int axis, Array window_shape, Array strides, std::string name = "T_windows", // TODO(@gussmith23) what to tag it? diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py index c80b3e031..feb685865 100644 --- a/python/tvm/relay/op/_transform.py +++ b/python/tvm/relay/op/_transform.py @@ -73,6 +73,7 @@ # windows @_reg.register_compute("windows") def compute_windows(attrs, inputs, output_type): + """Compute definition of windows""" return [topi.windows(inputs[0], attrs.axis, attrs.window_shape, attrs.strides)] _reg.register_strategy("windows", strategy.windows_strategy) diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py index de4619a6d..a0654c3ca 100644 --- a/python/tvm/relay/op/transform.py +++ b/python/tvm/relay/op/transform.py @@ -26,6 +26,32 @@ def windows(data, axis, window_shape, strides): + """Form windows over the data tensor. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axis : int + What axis the windows begin forming over. Windows will be formed over + this axis and all following axes. The axis value determines the window + shape (and thus, the number of strides): window shape and strides must + both be of length `data.ndim-axis`. + + window_shape : List[int] + The window shape to form over the input. Window shape must be of length + `data.ndim-axis`. + + strides : List[int] + How to stride the window along each dimension. Strides must be of length + `data.ndim-axis`. + + Returns + ------- + result : relay.Expr + The resulting tensor. + """ from .. import _ffi_api as _relay_make return _relay_make.windows(data, axis, window_shape, strides) diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index d7c303efa..3ade90c09 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -973,4 +973,30 @@ def invert_permutation(data): return result def windows(data, axis, window_shape, strides): + """Form windows over the data tensor. + + Parameters + ---------- + data : relay.Expr + The input data to the operator. + + axis : int + What axis the windows begin forming over. Windows will be formed over + this axis and all following axes. The axis value determines the window + shape (and thus, the number of strides): window shape and strides must + both be of length `data.ndim-axis`. + + window_shape : List[int] + The window shape to form over the input. Window shape must be of length + `data.ndim-axis`. + + strides : List[int] + How to stride the window along each dimension. Strides must be of length + `data.ndim-axis`. + + Returns + ------- + result : relay.Expr + The resulting tensor. + """ return cpp.windows(data, axis, window_shape, strides) diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 0a8bca18f..08eea07cb 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -114,21 +114,12 @@ Expr MakeWindows(Expr data, int axis, Array window_shape, Array() .add_argument("data", "Tensor", "The input tensor.") - // TODO(@gussmith23) - //.set_support_level(3) .add_type_rel("Windows", WindowsRel) - // Not needed if we register in python? - //.set_attr("FTVMCompute", WindowsCompute) -// TODO(@gussmith23) .set_attr("TOpPattern", kOpaque); -// TODO(@gussmith23) -//.set_attr("FInferCorrectLayout", ElemwiseArbitraryLayout); // relay.cast TVM_REGISTER_NODE_TYPE(CastAttrs);