diff --git a/CMakeLists.txt b/CMakeLists.txt index 9b276863..5cfb2a91 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -91,7 +91,7 @@ target_link_libraries(GoogleBenchmark INTERFACE Threads::Threads) # Find OpenCV #------------------------------------------------------------------------------- -if(DEFINED IMAGE_PROCESSING_BENCHMARKS OR OP_OPTIMIZATION_BENCHMARKS) +if(DEFINED IMAGE_PROCESSING_BENCHMARKS) find_package(OpenCV REQUIRED CONFIG) include_directories(${OpenCV_INCLUDE_DIRS}) endif() diff --git a/benchmarks/OpOptimization/CMakeLists.txt b/benchmarks/OpOptimization/CMakeLists.txt index f96942a1..3688f89e 100644 --- a/benchmarks/OpOptimization/CMakeLists.txt +++ b/benchmarks/OpOptimization/CMakeLists.txt @@ -1,2 +1,3 @@ add_subdirectory(Conv2dNchwFchw) add_subdirectory(MatMul) +add_subdirectory(Conv2dOp) diff --git a/benchmarks/OpOptimization/Conv2dOp/CMakeLists.txt b/benchmarks/OpOptimization/Conv2dOp/CMakeLists.txt new file mode 100644 index 00000000..bbf9d3eb --- /dev/null +++ b/benchmarks/OpOptimization/Conv2dOp/CMakeLists.txt @@ -0,0 +1,66 @@ +if (CROSS_COMPILE_RVV) + set(RISCV_GNU_TOOLCHAIN ${BUDDY_MLIR_BUILD_DIR}/thirdparty/riscv-gnu-toolchain) + set(RISCV_GNU_TOOLCHAIN_SYSROOT ${RISCV_GNU_TOOLCHAIN}/sysroot) + set(BUDDY_OPT_TRIPLE riscv64) + set(BUDDY_OPT_ATTR +v,+m) +endif() + +add_custom_command(OUTPUT conv2d_scalar.o + COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/Conv2dOp/Conv2D.mlir | + sed 's/@conv_2d/@conv_2d_scalar/' | + ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt + -convert-linalg-to-loops + -lower-affine + -arith-bufferize + -convert-scf-to-cf + -convert-vector-to-llvm + -convert-arith-to-llvm + -finalize-memref-to-llvm + -llvm-request-c-wrappers + -convert-func-to-llvm + -reconcile-unrealized-casts | + ${LLVM_MLIR_BINARY_DIR}/mlir-translate --mlir-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} --filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/Conv2dOp/conv2d_scalar.o +) +add_library(Conv2DScalar STATIC conv2d_scalar.o) +set_target_properties(Conv2DScalar PROPERTIES LINKER_LANGUAGE CXX) + +add_custom_command(OUTPUT conv2d_rvv.o + COMMAND cat ${BUDDY_SOURCE_DIR}/benchmarks/OpOptimization/Conv2dOp/Conv2DRVV.mlir | + sed 's/@conv_2d/@conv_2d_rvv/' | + ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-opt + -lower-affine + -convert-scf-to-cf + -convert-math-to-llvm + -lower-vector-exp + -lower-rvv + -convert-vector-to-llvm + -finalize-memref-to-llvm + -llvm-request-c-wrappers + -convert-func-to-llvm + -reconcile-unrealized-casts | + ${BUDDY_MLIR_BUILD_DIR}/bin/buddy-translate --buddy-to-llvmir | + ${LLVM_MLIR_BINARY_DIR}/llc -O3 -mtriple=${BUDDY_OPT_TRIPLE} + -mattr=${BUDDY_OPT_ATTR} --filetype=obj + -o ${BUDDY_BINARY_DIR}/../benchmarks/OpOptimization/Conv2dOp/conv2d_rvv.o +) +add_library(Conv2DRVV STATIC conv2d_rvv.o) +set_target_properties(Conv2DRVV PROPERTIES LINKER_LANGUAGE CXX) + +add_executable(conv2d-benchmark + Conv2DBenchmark.cpp + ) + +set_target_properties(conv2d-benchmark PROPERTIES + LINK_FLAGS "-static" +) + +set(BenchmarkTool GoogleBenchmark) + +target_link_libraries(conv2d-benchmark + ${BenchmarkTool} + Conv2DScalar + Conv2DRVV + ) diff --git a/benchmarks/OpOptimization/Conv2dOp/Conv2D.mlir b/benchmarks/OpOptimization/Conv2dOp/Conv2D.mlir new file mode 100644 index 00000000..3b81a237 --- /dev/null +++ b/benchmarks/OpOptimization/Conv2dOp/Conv2D.mlir @@ -0,0 +1,50 @@ +#map = affine_map<(d0) -> (d0)> +#map1 = affine_map<(d0) -> (d0 ceildiv 32)> +module{ +func.func @conv_2d(%arg0: memref, %arg1: memref, %arg2: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c32 = arith.constant 32 : index + %cst = arith.constant 0 : i32 + %0 = vector.splat %cst : vector<32xi32> + %dim = memref.dim %arg1, %c0 : memref + %dim_0 = memref.dim %arg1, %c1 : memref + %dim_1 = memref.dim %arg2, %c0 : memref + %dim_2 = memref.dim %arg2, %c1 : memref + affine.for %arg3 = #map(%c0) to #map(%dim_1) { + affine.for %arg4 = #map(%c0) to #map(%dim) { + affine.for %arg5 = #map(%c0) to #map(%dim_0) { + affine.for %arg6 = #map(%c0) to #map1(%dim_2) { + %1 = memref.load %arg1[%arg4, %arg5] : memref + %2 = arith.index_cast %c0 : index to i32 + %4 = arith.cmpi sge, %1, %2 : i32 + scf.if %4 { + %5 = vector.broadcast %1 : i32 to vector<32xi32> + %6 = arith.muli %arg6, %c32 : index + %7 = arith.subi %dim_2, %6 : index + %8 = arith.cmpi sge, %7, %c32 : index + scf.if %8 { + %9 = affine.vector_load %arg0[%arg3 + %arg4, %arg5 + %arg6 * 32] : memref, vector<32xi32> + %10 = affine.vector_load %arg2[%arg3, %arg6 * 32] : memref, vector<32xi32> + %11 = arith.muli %9, %5 : vector<32xi32> + %12 = arith.addi %10, %11 : vector<32xi32> + affine.vector_store %12, %arg2[%arg3, %arg6 * 32] : memref, vector<32xi32> + } else { + %9 = vector.create_mask %7 : vector<32xi1> + %10 = arith.addi %arg3, %arg4 : index + %11 = arith.muli %arg6, %c32 : index + %12 = arith.addi %arg5, %11 : index + %13 = vector.maskedload %arg0[%10, %12], %9, %0 : memref, vector<32xi1>, vector<32xi32> into vector<32xi32> + %14 = vector.maskedload %arg2[%arg3, %11], %9, %0 : memref, vector<32xi1>, vector<32xi32> into vector<32xi32> + %15 = arith.muli %13, %5 : vector<32xi32> + %16 = arith.addi %14, %15 : vector<32xi32> + vector.maskedstore %arg2[%arg3, %11], %9, %16 : memref, vector<32xi1>, vector<32xi32> + } + } + } + } + } + } + return + } +} diff --git a/benchmarks/OpOptimization/Conv2dOp/Conv2DBenchmark.cpp b/benchmarks/OpOptimization/Conv2dOp/Conv2DBenchmark.cpp new file mode 100644 index 00000000..994f2523 --- /dev/null +++ b/benchmarks/OpOptimization/Conv2dOp/Conv2DBenchmark.cpp @@ -0,0 +1,133 @@ +//===- Conv2DBenchmark.cpp ------------------------------------------------===// +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +//===----------------------------------------------------------------------===// +// +// This file implements the benchmark for GEMM operation. +// +//===----------------------------------------------------------------------===// + +#include +#include +#include +#include + +// Define target layout. +#define INPUT_R 16 +#define INPUT_C 16 +#define KERNEL_R 4 +#define KERNEL_C 4 +#define OUTPUT_R (INPUT_R - KERNEL_R + 1) +#define OUTPUT_C (INPUT_C - KERNEL_C + 1) + +// Helper functions and variables. +namespace { +const std::string PASS = "\033[32mPASS\033[0m"; +const std::string FAIL = "\033[31mFAIL\033[0m"; + +bool areArraysEqual(int array1[], int array2[], int size) { + for (int i = 0; i < size; ++i) { + if (array1[i] != array2[i]) { + return false; + } + } + return true; +} +} // namespace + +namespace { +// Declare the C interface. +extern "C" { +void _mlir_ciface_conv_2d_scalar(MemRef *input, MemRef *filter, + MemRef *output); +void _mlir_ciface_conv_2d_rvv(MemRef *input, MemRef *filter, + MemRef *output); +} + +#define DEFINE_BENCHMARK(name, func) \ + void BM_CONV2D_##name(benchmark::State &state) { \ + intptr_t sizesInput[2] = {INPUT_R, INPUT_C}; \ + intptr_t sizesKernel[2] = {KERNEL_R, KERNEL_C}; \ + intptr_t sizesOutput[2] = {OUTPUT_R, OUTPUT_C}; \ + MemRef input(sizesInput, 1); \ + MemRef filter(sizesKernel, 1); \ + MemRef output(sizesOutput, 0); \ + for (auto _ : state) { \ + func(&input, &filter, &output); \ + } \ + } + +DEFINE_BENCHMARK(SCALAR, _mlir_ciface_conv_2d_scalar) +DEFINE_BENCHMARK(RVV, _mlir_ciface_conv_2d_rvv) +} // namespace + +BENCHMARK(BM_CONV2D_SCALAR)->Unit(benchmark::kMillisecond); +BENCHMARK(BM_CONV2D_RVV)->Unit(benchmark::kMillisecond); + +void verification() { + // Set the random number generator. + std::random_device rd; + std::mt19937 generator(rd()); + std::uniform_int_distribution distribution(1, 100); + + // Set the layout sizes of input and output memref container. + intptr_t sizesInput[2] = {INPUT_R, INPUT_C}; + intptr_t sizesKernel[2] = {KERNEL_R, KERNEL_C}; + intptr_t sizesOutput[2] = {OUTPUT_R, OUTPUT_C}; + + // Generate input memref container with random numbers. + const int inputSize = INPUT_R * INPUT_C; + int inputRand[inputSize]; + for (int i = 0; i < inputSize; ++i) { + inputRand[i] = distribution(generator); + } + MemRef inputMemRef(inputRand, sizesInput); + + // Generate kernel memref container with random numbers. + const int kernelSize = KERNEL_R * KERNEL_C; + int kernelRand[kernelSize]; + for (int i = 0; i < kernelSize; ++i) { + kernelRand[i] = distribution(generator); + } + MemRef kernelMemRef(kernelRand, sizesKernel); + + // Generate a result using a scalar method for comparison during verification. + const int outputSize = OUTPUT_R * OUTPUT_C; + MemRef outputScalar(sizesOutput, 0); + MemRef outputRVV(sizesOutput, 0); + _mlir_ciface_conv_2d_scalar(&inputMemRef, &kernelMemRef, &outputScalar); + _mlir_ciface_conv_2d_rvv(&inputMemRef, &kernelMemRef, &outputRVV); + auto resultScalar = outputScalar.getData(); + auto resultRVV = outputRVV.getData(); + + // Print the verfication result. + std::cout << "-----------------------------------------------------------" + << std::endl; + std::cout << "Correctness Verification:" << std::endl; + std::cout << "Transform case: " + << (areArraysEqual(resultScalar, resultRVV, outputSize) ? PASS + : FAIL) + << std::endl; + std::cout << "-----------------------------------------------------------" + << std::endl; +} + +int main(int argc, char **argv) { + // Run benchmark. + ::benchmark::Initialize(&argc, argv); + ::benchmark::RunSpecifiedBenchmarks(); + // Run correctness verification. + verification(); + return 0; +} diff --git a/benchmarks/OpOptimization/Conv2dOp/Conv2DRVV.mlir b/benchmarks/OpOptimization/Conv2dOp/Conv2DRVV.mlir new file mode 100644 index 00000000..ce98ec5e --- /dev/null +++ b/benchmarks/OpOptimization/Conv2dOp/Conv2DRVV.mlir @@ -0,0 +1,56 @@ +module{ + func.func @conv_2d(%arg0: memref, %arg1: memref, %arg2: memref) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %sew = arith.constant 2 : index + %dim = memref.dim %arg1, %c0 : memref + %dim_0 = memref.dim %arg1, %c1 : memref + %dim_1 = memref.dim %arg2, %c0 : memref + %dim_2 = memref.dim %arg2, %c1 : memref + + affine.for %tmp0 = %c0 to %dim_1 { + %tmpAVL, %tmpIdx = scf.while (%avl = %dim_2, %idx = %c0) : (index, index) -> (index, index) { + // If avl greater than zero. + %cond = arith.cmpi sgt, %avl, %c0 : index + // Pass avl, idx to the after region. + scf.condition(%cond) %avl, %idx : index, index + } do { + ^bb0(%avl : index, %idx : index): + %vl = rvv.setvl %avl, %sew, %c1 : index + %vl_i32 = arith.index_cast %vl : index to i32 + %mask = vector.create_mask %vl : vector<[8]xi1> + %c_vector = vector_exp.predication %mask, %vl_i32 : vector<[8]xi1>, i32 { + %ele = vector.load %arg2[%tmp0, %idx] : memref, vector<[8]xi32> + vector.yield %ele : vector<[8]xi32> + } : vector<[8]xi32> + %tmpvector = affine.for %tmp1 = %c0 to %dim iter_args(%vector_iter0 = %c_vector) -> (vector<[8]xi32>) { + %vector_next = affine.for %tmp2 = %c0 to %dim_0 iter_args(%vector_iter1 = %vector_iter0) -> (vector<[8]xi32>) { + %0 = affine.load %arg1[%tmp1, %tmp2] : memref + %1 = arith.addi %tmp0, %tmp1 : index + %2 = arith.addi %idx, %tmp2 : index + %input_vector = vector_exp.predication %mask, %vl_i32 : vector<[8]xi1>, i32 { + %ele = vector.load %arg0[%1, %2] : memref, vector<[8]xi32> + vector.yield %ele : vector<[8]xi32> + } : vector<[8]xi32> + + %3 = rvv.mul %input_vector, %0, %vl : vector<[8]xi32>, i32, index + %output = rvv.add %3, %vector_iter1, %vl : vector<[8]xi32>, vector<[8]xi32>, index + + affine.yield %output: vector<[8]xi32> + } + affine.yield %vector_next : vector<[8]xi32> + } + vector_exp.predication %mask, %vl_i32 : vector<[8]xi1>, i32 { + vector.store %tmpvector, %arg2[%tmp0, %idx] : memref, vector<[8]xi32> + vector.yield + } : () -> () + + // Update idx and avl. + %new_idx = arith.addi %idx, %vl : index + %new_avl = arith.subi %avl, %vl : index + scf.yield %new_avl, %new_idx : index, index + } + } + return + } +}