Skip to content
This repository has been archived by the owner on Jan 24, 2024. It is now read-only.

Commit

Permalink
Revert "remove sum decomposer and revert sum ir realize (#1520)" (#1536)
Browse files Browse the repository at this point in the history
This reverts commit 68cbd14.
  • Loading branch information
thisjiang authored Jun 27, 2023
1 parent dd96678 commit d7395bc
Show file tree
Hide file tree
Showing 7 changed files with 96 additions and 58 deletions.
2 changes: 2 additions & 0 deletions cinn/frontend/decomposer/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ core_gather_headers()

gather_srcs(cinnapi_src SRCS
activation.cc
elementwise.cc
broadcast.cc
batch_norm.cc
top_k.cc
Expand All @@ -11,6 +12,7 @@ cc_library(decomposer_test_helper SRCS test_helper.cc DEPS cinncore)

if (WITH_CUDA)
cc_test(test_activation_decomposer SRCS activation_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_elementwise_decomposer SRCS elementwise_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_broadcast_decomposer SRCS broadcast_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_batch_norm_decomposer SRCS batch_norm_test.cc DEPS cinncore decomposer_test_helper)
cc_test(test_top_k_decomposer SRCS top_k_test.cc DEPS cinncore decomposer_test_helper)
Expand Down
46 changes: 46 additions & 0 deletions cinn/frontend/decomposer/elementwise.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "cinn/frontend/decomposer_registry.h"
#include "cinn/frontend/syntax.h"

namespace cinn {
namespace frontend {
namespace decomposer {

void sum(const Instruction& instr, const DecomposerContext& context) {
CHECK_GT(instr->inputs.size(), 0UL) << "At least 1 input tensor for " << instr->op_type;
CHECK_EQ(instr->outputs.size(), 1UL) << "1 output tensor for " << instr->op_type;
auto inputs = instr->inputs;
auto output = instr->outputs[0];
auto* builder = context.builder();

auto sum = builder->Identity(inputs[0]);
for (size_t i = 1; i < inputs.size(); ++i) {
sum = builder->Add(sum, inputs[i]);
}

// map the the output of decomposed operator to the original.
context.MapOutToOrigin(sum, output);
}

} // namespace decomposer
} // namespace frontend
} // namespace cinn

// Register the `sum` decomposer so the Decomposer program pass can replace
// `sum` instructions with the primitive ops emitted by decomposer::sum above.
CINN_REGISTER_HELPER(sum_decomposers) {
  CINN_DECOMPOSER_REGISTER(sum, cinn::frontend::decomposer::sum);

  return true;
}
45 changes: 45 additions & 0 deletions cinn/frontend/decomposer/elementwise_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <gtest/gtest.h>

#include "cinn/frontend/decomposer/test_helper.h"

namespace cinn::frontend {

TEST(Decomposer, sum) {
  // Build a program that sums three float32 tensors of shape [32, 16];
  // the decomposer should rewrite the Sum into a chain of binary Adds.
  NetBuilder builder("sum");
  auto x   = builder.CreateInput(Float(32), {32, 16});
  auto y   = builder.CreateInput(Float(32), {32, 16});
  auto z   = builder.CreateInput(Float(32), {32, 16});
  auto out = builder.Sum({x, y, z});

  // Host-side reference implementation: elementwise x + y + z.
  auto reference_sum = [](const std::vector<size_t>& lengths, const std::vector<void*>& ptrs) {
    const size_t count = lengths[0];
    const float* in0   = static_cast<float*>(ptrs[0]);
    const float* in1   = static_cast<float*>(ptrs[1]);
    const float* in2   = static_cast<float*>(ptrs[2]);
    float* result      = static_cast<float*>(ptrs[3]);
    for (size_t idx = 0; idx < count; ++idx) {
      result[idx] = in0[idx] + in1[idx] + in2[idx];
    }
  };

  // Run the decomposed program and compare against the reference output.
  std::vector<std::string> input_names           = {x.id().data(), y.id().data(), z.id().data()};
  std::vector<std::string> output_names          = {out->id};
  std::vector<std::vector<int>> output_shapes    = {{32, 16}};
  RunAndCheck<float>(builder, input_names, output_names, output_shapes, reference_sum);
}

} // namespace cinn::frontend
1 change: 1 addition & 0 deletions cinn/frontend/decomposer/use_decomposer.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ CINN_USE_REGISTER(relu_decomposers)
CINN_USE_REGISTER(relu_grad_decomposers)
CINN_USE_REGISTER(gelu_decomposers)
CINN_USE_REGISTER(softmax_decomposers)
CINN_USE_REGISTER(sum_decomposers)
CINN_USE_REGISTER(broadcast_decomposers)
CINN_USE_REGISTER(broadcast_grad_decomposers)
CINN_USE_REGISTER(batch_norm_decomposer)
Expand Down
1 change: 1 addition & 0 deletions cinn/frontend/pass/auto_cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ static std::unordered_map<std::string, CastImplFunc> need_cast_list = {
{"reduce_prod", CommonCastImpl},
// composite function
{"sigmoid", CommonCastImpl},
{"sum", CommonCastImpl},
{"softmax", CommonCastImpl},
{"gelu", CommonCastImpl},
{"batch_norm",
Expand Down
41 changes: 1 addition & 40 deletions cinn/hlir/op/elementwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
#include "cinn/hlir/pe/ir_schedule_pe.h"
#include "cinn/hlir/pe/nn.h"
#include "cinn/hlir/pe/schedule.h"
#include "cinn/ir/ir.h"
#include "cinn/ir/ir_operators.h"
#include "cinn/utils/functional.h"

Expand Down Expand Up @@ -262,45 +261,7 @@ std::shared_ptr<OpStrategy> StrategyForSum(const framework::NodeAttr &attrs,
const std::vector<Type> &out_type,
const std::vector<std::vector<int>> &output_shapes,
const Target &target) {
framework::CINNCompute sum_compute([=](lang::Args args, lang::RetValue *ret) {
CHECK(!args.empty()) << "The input argument of const_float compute is empty! Please check.";

CINNValuePack pack_args = args[0];
CHECK_GE(pack_args.size(), 2U) << "The sum op should has at least 1 inputs.";

// the front args are tensor
std::vector<ir::Tensor> inputs;
for (int i = 0; i < pack_args.size() - 1; ++i) {
Expr arg = pack_args[i];
CHECK(arg.as_tensor());
inputs.emplace_back(arg.as_tensor_ref());
}
// the last arg is tensor name
CHECK(pack_args.back().is_string()) << "Cannot run at FLAGS_cinn_ir_schedule=false! Please check.";
auto tensor_name = pack_args.back().operator std::string();

auto out = lang::Compute(
{ToCinnExprs(output_shapes.at(0))},
[=](const std::vector<Expr> &indice) {
std::vector<Expr> nums;
for (auto &in : inputs) {
nums.emplace_back(in(indice));
}
return ir::Sum::Make(nums);
},
tensor_name);
CHECK(out.defined()) << "can't create sum op";

auto tensors = inputs;
tensors.emplace_back(out);
auto stages = CreateStages(tensors);
*ret = CINNValuePack{{CINNValue(out), CINNValue(stages)}};
});

auto strategy = std::make_shared<framework::OpStrategy>();
strategy->AddImpl(sum_compute, GetElementwiseScheduleFunc(output_shapes, target), "strategy.const_scalar.x86", 1);

return strategy;
LOG(FATAL) << "The operator will be decomposed into several primitive operators. Please Use Decomposer Program Pass.";
}

std::vector<shape_t> InferShapeForSum(const std::vector<shape_t> &inputs_shape, const framework::AttrMapType &attrs) {
Expand Down
18 changes: 0 additions & 18 deletions python/tests/ops/test_sum_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,24 +158,6 @@ def init_attrs(self):
self.attrs = []


class TestSumOpLargeInputTest(TestCaseHelper):
    """Stress cases for the sum op: many operands and/or large tensors."""

    def init_attrs(self):
        # Configure the test helper grid: each input entry sums `count`
        # identically-shaped tensors of `shape`.
        self.class_name = "TestSumOpLargeInputTest"
        self.cls = TestSumOp
        shape_and_count = [
            ([64], 100),
            ([64, 32, 16, 1, 128], 20),
            ([1, 1, 1, 1, 1], 100),
            ([1048576], 20),
        ]
        self.inputs = [{"shapes": [shape] * count} for shape, count in shape_and_count]
        self.dtypes = [{"dtype": "float32"}]
        self.attrs = []


if __name__ == "__main__":
    # Execute every sum-op test suite defined in this module.
    for suite in (TestSumOpShapeTest, TestSumOpDtypeTest, TestSumOpLargeInputTest):
        suite().run()

0 comments on commit d7395bc

Please sign in to comment.