diff --git a/CMakeLists.txt b/CMakeLists.txt
index deb46775e..9983eae10 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -207,6 +207,7 @@ option(MNN_OPENCL "Enable OpenCL" OFF)
 option(MNN_OPENGL "Enable OpenGL" OFF)
 option(MNN_VULKAN "Enable Vulkan" OFF)
 option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
+option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF)
 option(MNN_ONEDNN "Enable oneDNN" OFF)
 option(MNN_AVX512 "Enable AVX512" OFF)
 option(MNN_CUDA "Enable CUDA" OFF)
@@ -253,6 +254,7 @@ message(STATUS "\tOpenCL: ${MNN_OPENCL}")
 message(STATUS "\tOpenGL: ${MNN_OPENGL}")
 message(STATUS "\tVulkan: ${MNN_VULKAN}")
 message(STATUS "\tARM82: ${MNN_ARM82}")
+message(STATUS "\tKleidiAI: ${MNN_KLEIDIAI}")
 message(STATUS "\toneDNN: ${MNN_ONEDNN}")
 message(STATUS "\tTensorRT: ${MNN_TENSORRT}")
 message(STATUS "\tCoreML: ${MNN_COREML}")
diff --git a/source/backend/cpu/CMakeLists.txt b/source/backend/cpu/CMakeLists.txt
index 82287d69f..e8e465610 100644
--- a/source/backend/cpu/CMakeLists.txt
+++ b/source/backend/cpu/CMakeLists.txt
@@ -50,3 +50,10 @@ IF(MNN_ARM82)
     ENDIF()
 ENDIF()
 
+# Kleidi AI
+IF(MNN_KLEIDIAI)
+    add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
+    include(${CMAKE_CURRENT_LIST_DIR}/arm/kleidiAI/CMakeLists.txt)
+    list(APPEND MNN_TARGETS MNN_KleidiAI)
+    list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_KleidiAI>)
+ENDIF()
\ No newline at end of file
diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp
index 00e39fc30..9c11bd12b 100644
--- a/source/backend/cpu/CPUBackend.hpp
+++ b/source/backend/cpu/CPUBackend.hpp
@@ -17,6 +17,10 @@
 #include "core/BufferAllocator.hpp"
 #include "MNN_generated.h"
 
+#ifdef MNN_KLEIDIAI_ENABLED
+#include "arm/kleidiAI/mnn_kleidiai.h"
+#endif
+
 namespace MNN {
 class CPURuntime : public Runtime {
 public:
diff --git a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
new file mode 100644
index 000000000..f12e27c19
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
@@ -0,0 +1,63 @@
+#
+# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+project(MNN_KleidiAI
+    LANGUAGES C CXX ASM
+)
+
+set(KLEIDIAI_MIN_CLANG_VERSION 11)
+set(KLEIDIAI_MIN_GNU_VERSION 11)
+
+if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_CLANG_VERSION})
+    message(WARNING "KleidiAI: Using non-supported Clang version. Expected ${KLEIDIAI_MIN_CLANG_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
+endif()
+
+if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_GNU_VERSION})
+    message(WARNING "KleidiAI: Using non-supported GCC version. Expected ${KLEIDIAI_MIN_GNU_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
+endif()
+
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp)
+list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h)
+
+add_library(
+    MNN_KleidiAI
+    SHARED
+    ${MNN_KleidiAI_SOURCES} ${MNN_KleidiAI_HEADERS}
+)
+
+set(KLEIDIAI_SRC ${CMAKE_CURRENT_LIST_DIR})
+
+include_directories(
+    ${KLEIDIAI_SRC}/
+    ${KLEIDIAI_SRC}/kai/
+    ${KLEIDIAI_SRC}/kai/ukernels/
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
+
+set(KLEIDIAI_FILES_SCALAR
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
+)
+
+set(KLEIDIAI_FILES_NEON_DOTPROD
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
+)
+
+set(KLEIDIAI_FILES_NEON_I8MM
+    ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
+)
+
+# Selectively enable architecture features.
+target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR})
+if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC)
+    target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD})
+    target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_I8MM})
+
+    set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a)
+    set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod)
+    set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm)
+endif()
\ No newline at end of file
diff --git a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h
new file mode 100644
index 000000000..9569e5468
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h
@@ -0,0 +1,194 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#pragma once
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
+//
+// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros.
+// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected.
+// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting.
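+//
+// For illustration: a failing KAI_ASSERT((k_internal % 2) == 0) expands to KAI_ASSERT_MSG with the
+// stringized condition, so it prints the file name, line number and "(k_internal % 2) == 0" to stderr
+// and then terminates the process via exit(EXIT_FAILURE).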
+#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \ + exit(EXIT_FAILURE); \ + } while (0) + +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) + +// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + +#define KAI_UNUSED(x) (void)(x) +#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b)) + +/// KleidiAI data types +/// Format: (reserved)|(num-bytes)|(type)|(variant-type) +enum kai_datatype { + kai_dt_unknown = 0x0000, + kai_dt_f32 = 0x0411, + kai_dt_f16 = 0x0212, + kai_dt_bf16 = 0x0213, + kai_dt_int32 = 0x0421, + kai_dt_int16 = 0x0222, + kai_dt_int8 = 0x0124, + kai_dt_uint32 = 0x0431, + kai_dt_uint16 = 0x0232, + kai_dt_uint8 = 0x0134, + kai_dt_bool = 0x0441 +}; + +/// Gets number of bytes for a given data type +/// @param[in] dt KleidiAI data type +/// +/// @return the numbers of bytes for the data type +inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { + return (size_t)(dt >> 8); +} + +/// Converts a scalar f16 value to f32 +/// @param[in] f16 The f16 value +/// +/// @return the f32 value +inline static float kai_cast_f32_f16(uint16_t f16) { +#if defined(__ARM_NEON) + __fp16 f32 = 0; + memcpy(&f32, &f16, sizeof(uint16_t)); + return (float)f32; +#endif +} + +/// Converts a scalar bf16 value to f32 +/// @param[in] bf16 The f16 value +/// +/// @return the f32 value +inline static float kai_cast_f32_bf16(uint16_t bf16) { + const uint32_t i32 = (bf16 << 16); + float f32; + memcpy(&f32, &i32, sizeof(i32)); + return f32; +} + +/// Converts a f32 value to bf16 +/// @param[in] f32 The f32 value +/// +/// @return the bf16 value +inline static uint16_t kai_cast_bf16_f32(float f32) { + uint16_t bf16 = 0; +#ifdef __ARM_FEATURE_BF16 + __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32)); +#else + const uint32_t* i32 = (uint32_t*)(&f32); + bf16 = (*i32 >> 16); +#endif + return bf16; +} + +/// Converts a scalar f32 value to f16 +/// @param[in] f32 The f32 value +/// +/// @return the f16 value +inline static uint16_t kai_cast_f16_f32(float f32) { +#if defined(__ARM_NEON) + uint16_t f16 = 0; + __fp16 tmp = f32; + memcpy(&f16, &tmp, sizeof(uint16_t)); + return f16; +#endif +} + +inline static size_t kai_roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +#ifdef __ARM_FEATURE_SVE + +/// Gets the SME vector length for 8-bit elements. +inline static uint64_t kai_get_sme_vector_length_u8(void) { + uint64_t res = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cntb %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +/// Gets the SME vector length for 16-bit elements. 
+inline static uint64_t kai_get_sme_vector_length_u16(void) {
+    uint64_t res = 0;
+
+    __asm__ __volatile__(
+        ".inst 0xd503477f // SMSTART ZA\n"
+        "cnth %0\n"
+        ".inst 0xd503467f // SMSTOP\n"
+        : "=r"(res)
+        :
+        : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
+          "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+    return res;
+}
+
+/// Gets the SME vector length for 32-bit elements.
+inline static uint64_t kai_get_sme_vector_length_u32(void) {
+    uint64_t res = 0;
+
+    __asm__ __volatile__(
+        ".inst 0xd503477f // SMSTART ZA\n"
+        "cntw %0\n"
+        ".inst 0xd503467f // SMSTOP\n"
+        : "=r"(res)
+        :
+        : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
+          "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");
+
+    return res;
+}
+
+#endif // __ARM_FEATURE_SVE
+
+/// Extends the sign bit of int 4-bit value (stored in int8_t variable)
+/// @param[in] value The 4-bit int value
+///
+/// @return the int8_t value with sign extended
+inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
+    return (value ^ 0x8) - 8;
+}
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
new file mode 100644
index 000000000..cd24f7313
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
@@ -0,0 +1,229 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+#if !defined(__ARM_FEATURE_DOTPROD)
+#error "Dotprod extension required to compile this micro-kernel"
+#else
+#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h"
+
+#include <arm_neon.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "kai/kai_common.h"
+
+static const size_t kai_m_step = 1;
+static const size_t kai_n_step = 4;
+static const size_t kai_mr = 1;
+static const size_t kai_nr = 4;
+static const size_t kai_kr = 16;
+static const size_t kai_sr = 2;
+static const size_t kai_num_bytes_multiplier_lhs = sizeof(float);
+static const size_t kai_num_bytes_multiplier_rhs = sizeof(float);
+static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t);
+static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t);
+static const size_t kai_num_bytes_bias = sizeof(float);
+
+inline static size_t kai_k_roundedup(size_t k) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for alignment
+    size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4);
+    return kai_roundup(k, kr_sr_roundedup4);
+}
+
+inline static size_t kai_lhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs);
+}
+
+inline static size_t kai_rhs_packed_stride(size_t k) {
+    const size_t k_internal = kai_k_roundedup(k);
+
+    KAI_ASSERT((k_internal % 2) == 0);
+
+    return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias);
+}
+
+size_t
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t kai_k0 = kai_kr * kai_sr; + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_packed; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_packed; + for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + + // Main f32 accumulator + int32x4_t iacc0011 = vdupq_n_s32(0); + int32x4_t iacc2233 = vdupq_n_s32(0); + + for (size_t b = 0; b < k_internal; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); + const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); + const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32)); + const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); + + // Low nibble + const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); + const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); + const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); + const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); + + // High nibble + const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); + const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t 
lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + + lhs_ptr += 32; + rhs_ptr += 64; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_0, t); + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); + } + + int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // RHS sum values + const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Load the bias + const float32x4_t bias0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the reduction sum + iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); + + float32x4_t main_acc = vmulq_f32(vcvtq_f32_s32(iacc), rhs_scale); + + main_acc = vmulq_f32(main_acc, lhs_scale); + + // Add the bias + main_acc = vaddq_f32(main_acc, bias0); + + // clamp (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc = vmaxq_f32(main_acc, vmin_f32); + main_acc = vminq_f32(main_acc, vmax_f32); + + if (col_idx + kai_nr <= n) { + vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 1); + } + if (leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 2); + } + } + } + lhs_ptr_start += lhs_packed_stride; + } +} +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h new file mode 100644 index 000000000..fefca19a9 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to 
pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix with +/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the destination offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination matrix. 
+/// +/// @param[in] m Number of rows in the destination (DST) matrix +/// @param[in] n Number of columns in the destination (DST) matrix +/// +/// @return the DST size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. +/// Output tile: (rows x cols) = 1 x 4 +/// Accumulation performed in a single for loop: 64 +/// Instruction used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. +/// @param[in] lhs_packed The LHS matrix packed. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 +/// @param[out] dst Result of the vector-by-matrix +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c new file mode 100644 index 000000000..7e40839e6 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -0,0 +1,508 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make 
sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v11.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 10f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "3:" // Sub block loop + "ldr q2, [x10, #0x0]\n" + "ldr q1, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x22, #0x0]\n" + "ldr q19, [x22, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x10, #0x20]\n" + "ldr q30, [x10, #0x30]\n" + "shl v17.16b, v2.16b, 
#0x4\n" + "shl v16.16b, v1.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x22, x22, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q25, [x10, #0x0]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v23.2d, v10.2d, v9.2d\n" + "uzp2 v22.2d, v10.2d, v9.2d\n" + "ldr q24, [x10, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "add x10, x10, #0x20\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" + "mla v21.4s, v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v23.4s, v19.4s\n" + "fmul v9.4s, v22.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ld1 { v17.4s }, [x20]\n" + "uzp1 v23.2d, v6.2d, v5.2d\n" + "uzp2 v22.2d, v6.2d, v5.2d\n" + "add x20, x20, #0x10\n" + "ldr q16, [x20, #0x0]\n" + "uzp1 v21.2d, v4.2d, v3.2d\n" + "uzp2 v20.2d, v4.2d, v3.2d\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" 
+ "mla v21.4s, v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v6.4s, v23.4s, v19.4s\n" + "fmul v5.4s, v22.4s, v18.4s\n" + "fmul v4.4s, v21.4s, v17.4s\n" + "fmul v3.4s, v20.4s, v16.4s\n" + "ldr q18, [x10, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x10, x10, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fadd v6.4s, v6.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v4.4s, v4.4s, v18.4s\n" + "fadd v3.4s, v3.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v4.4s, v4.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v3.4s, v3.4s, v16.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q3, [x20, #0x0]\n" + "b 9f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 7f\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "st1 { v4.d }[0], [x25], #0x8\n" + "st1 { v5.d }[0], [x24], #0x8\n" + "st1 { v6.d }[0], [x26], #0x8\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x21], #0x8\n" + "st1 { v10.d }[0], [x27], #0x8\n" + "tbz x9, #0, 8f\n" + "st1 { v3.s }[2], [x23]\n" + "st1 { v4.s }[2], [x25]\n" + "st1 { v5.s }[2], [x24]\n" + "st1 { v6.s }[2], [x26]\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v9.s }[2], [x21]\n" + "st1 { v10.s }[2], [x27]\n" + "b 8f\n" + "7:" // Output block 0: partial_1_0 + "st1 { v3.s }[0], [x23]\n" + "st1 { v4.s }[0], [x25]\n" + "st1 { v5.s }[0], [x24]\n" + "st1 { v6.s }[0], [x26]\n" + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x22]\n" + "st1 { v9.s }[0], [x21]\n" + "st1 { v10.s }[0], [x27]\n" + "8:" // Output block 0: Done + "9:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "10:" // Row loop skip + "cbz x12, 19f\n" + "11:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], 
%x[dst_stride_row], LSL #2\n" + "12:" // Row tail: Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "13:" // Row tail: Sub block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x22, #0x0]\n" + "ldr q28, [x22, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x22, #0x20]\n" + "ldr q24, [x22, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x22, #0x40]\n" + "ldr q20, [x22, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v11.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 13b\n" + "ldr q18, [x26, #0x0]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v24.2d, v10.2d, v9.2d\n" + "uzp2 v23.2d, v10.2d, v9.2d\n" + "ldr q22, [x26, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "add x26, x26, #0x20\n" + "mla v24.4s, v18.4s, v17.s[0]\n" + "mla v23.4s, v18.4s, v17.s[1]\n" + "mla v21.4s, v18.4s, v17.s[2]\n" + "mla v20.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v22.4s, v16.s[0]\n" + "fmul v18.4s, v22.4s, v16.s[1]\n" + "fmul v17.4s, v22.4s, v16.s[2]\n" + "fmul v16.4s, v22.4s, v16.s[3]\n" + "scvtf v24.4s, v24.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v24.4s, v19.4s\n" + "fmul v9.4s, v23.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "blt 15f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x2\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x3\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, 
%x[dst_stride_row]\n" + "ble 18f\n" + "str q7, [x20, #0x0]\n" + "b 18f\n" + "15:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 16f\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v10.d }[0], [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v10.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v10.s }[0], [x23]\n" + "17:" // Row tail: Output block 0: Done + "18:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 12b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 11b\n" + "19:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", + "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h new file mode 100644 index 000000000..04df4a825 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix with +/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Function to get the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. +/// Output tile: (rows x cols) = 8 x 4 +/// Accumulation performed in a single for loop: 32 +/// Instruction used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. 
+/// @param[in] k The number of channels. The common dimension of LHS & RHS. +/// @param[in] lhs_packed The LHS matrix packed. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 +/// @param[out] dst Result of the vector-by-matrix +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c new file mode 100644 index 000000000..20deb99d1 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c @@ -0,0 +1,179 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qai8dxp_f32.h" + +#if defined(__aarch64__) +#include +#endif +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. 
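+    // For example, with the kr = 16 and sr = 2 used by the matmul micro-kernels in this patch,
+    // kr * sr = 32 is already a multiple of 4, so k is rounded up to the next multiple of 32
+    // (e.g. k = 70 becomes 96).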
+ size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); +} + +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + // It always points to the beginning of the row + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); +} + +void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, + size_t lhs_stride, void* restrict lhs_packed) { + KAI_ASSERT((kr % sr) == 0); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + + const float* src_ptr = lhs; + + const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + +#if defined(__aarch64__) + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); + + // Calculate the max + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the min + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); +#endif + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = *(src_ptr + (size_t)k_idx); + max0 = KAI_MAX(src0_0, max0); + min0 = KAI_MIN(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = KAI_MIN(0.0F, min0); + const float rmax0 = KAI_MAX(0.0F, max0); + + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? 
qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = KAI_MAX(zero_point0, qmin); + zero_point0 = KAI_MIN(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t); + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h new file mode 100644 index 000000000..acba70cd6 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). 
+/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Run the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS of the vector-by-matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix. 
+void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c new file mode 100644 index 000000000..d3ec86067 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c @@ -0,0 +1,203 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" + +#include +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { + KAI_ASSERT((n_idx % nr) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); +} + +void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->lhs_zero_point == 1); + KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx 
< dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + if (params->rhs_zero_point == 8) { + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } else { + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The logic behind the following operations where we extract the + // values from the bytes is same as unsigned + + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + int8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + int8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *(int8_t*)dst_row = dst_qs0; + dst_row += sizeof(int8_t); + + src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo); + src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi); + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi; + } + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for 
(size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h new file mode 100644 index 000000000..dc7c1bd02 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h @@ -0,0 +1,111 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; + +/// Get the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed). +/// +/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride); + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k In the RHS matrix (not packed), K is the number of columns. +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); + +/// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel +/// (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k In the RHS matrix (not packed), K is the number of columns. +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); + +/// @brief Gets the size in bytes for the packed RHS matrix +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). 
+/// @param[in] nr The number of columns written by the matmul micro-kernel
+/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+///
+/// @return the packed RHS matrix size in bytes
+size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr);
+
+/// Run the micro-kernel to pack the RHS matrix.
+///
+/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte.
+///       The lower order part of the byte (low) holds the first nibble (K-index + 0).
+///       The higher order of the byte holds the second nibble (K-index + 1).
+///
+/// @param[in] num_groups The number of groups. It must be 1.
+/// @param[in] n The number of rows.
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value.
+/// @param[in] nr The number of N rows to interleave on the same output row.
+/// @param[in] kr The number of K values loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+/// However, kr must be a multiple of sr.
+/// @param[in] rhs The RHS matrix containing the 4-bit values.
+/// Size in bytes is expected to be greater than or equal to n * k * (sizeof(uint8_t) / 2).
+/// @param[in] bias The biases.
+/// @param[in] scale The scale for each output channel.
+/// @param[out] rhs_packed The packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(
+    size_t num_groups, //
+    size_t n, //
+    size_t k, //
+    size_t nr, //
+    size_t kr, //
+    size_t sr, //
+    const uint8_t* rhs, //
+    const float* bias, //
+    const float* scale, //
+    void* rhs_packed, //
+    size_t extra_bytes, //
+    const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp
new file mode 100644
index 000000000..dc1f9169f
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp
@@ -0,0 +1,359 @@
+//
+// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates
+//
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#if defined(__aarch64__)
+
+#include "mnn_kleidiai.h"
+
+using namespace MNN;
+
+KleidiAI *KleidiAI::instance = NULL;
+
+inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for memory alignment.
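+    // For example, with the kr = 16 and sr = 2 used by this integration (see mnn_kleidiai.h),
+    // kr * sr = 32 is already a multiple of 4, so k is rounded up to the next multiple of 32:
+    // k = 70 becomes 96 (assuming kai_roundup(a, b) rounds a up to a multiple of b).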
+ size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +static void packQsi4cxps16s0Qs4cxs0s1( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; + const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = 
sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +static void packQs4cxs16s0Qsi8cx(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2); + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; + + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; + + const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 
1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +void KleidiAI::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmp_size = rowNum * rowSize * sizeof(float); + uint8_t *tmpBuffer = new uint8_t[tmp_size]; + memcpy(tmpBuffer, data, tmp_size); + + const float *src = (const float *)tmpBuffer; + float *dst = (float *)data; + + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const float *rowSrc = src + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(dst, rowSrc, blockSize); + dst += 4; + rowSrc += rowSize; + } + } + + delete[] tmpBuffer; +} + +void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const size_t tmp_size = rowNum * rowSize * sizeof(float); + uint8_t *tmpBuffer = new uint8_t[tmp_size]; + memcpy(tmpBuffer, data, tmp_size); + + const float *src = (const float *)tmpBuffer; + float *dst = (float *)data; + + size_t blockNum = rowSize / 4; + size_t blockSize = 4 * sizeof(float); + + for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) { + const float *rowSrc = src + blockIndex * 4 * rowNum; + float *block_dst = dst + blockIndex * 4; + for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) { + memcpy(block_dst, rowSrc, blockSize); + block_dst += rowSize; + rowSrc += 4; + } + } + + delete[] tmpBuffer; +} + +//Set info +void KleidiAI::setEnable(bool enable) { + mKaiInfo.kaiEnable = enable; + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } +} + +void KleidiAI::setModelAsymmetric(bool bAsymmetric) { + mKaiInfo.asymmetric = bAsymmetric; + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } +} + +//Lhs +size_t KleidiAI::getLhsQuantedPackedSize(size_t m, size_t k) { + return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr()); +} + +size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) { + return mIdx == 0 ? 
0 : kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr()); +} + +void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked) { + kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); +} + +//Rhs +size_t KleidiAI::getRhsPackedSize(size_t n, size_t k) { + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(), getKr(), getSr()); +} + +size_t KleidiAI::getRhsPackedOffset(size_t nIdx, size_t k) { + return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); +} + +void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4) { + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; + params.lhs_zero_point = 1; + params.rhs_zero_point = 8; + if(!packedInt4) { + packQs4cxs16s0Qsi8cx(1, n, k, getNr(), getKr(), getSr(), + (const uint8_t *)rhs, + (const float *)bias, (const float *)scale, + rhsPacked, + 0, ¶ms); + } else { + packQsi4cxps16s0Qs4cxs0s1(1, n, k, getNr(), getKr(), getSr(), + (const uint8_t *)rhs, + (const float *)bias, (const float *)scale, + rhsPacked, + 0, ¶ms); + } +} + +//Matmul +void KleidiAI::runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst) { + if(m == 1) { //dotprod + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, + (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } else { //i8mm + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k, + (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, + dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + } +} + +#endif // defined(__aarch64__) \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h new file mode 100644 index 000000000..38cdce230 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include "kai_common.h" + +namespace MNN { + class KleidiAI { + public: + static KleidiAI &getInstance(bool bAsymmetric, bool acthalf, bool blockwise) { + if(!instance) { + instance = new KleidiAI(bAsymmetric, acthalf, blockwise); + } + return *instance; + } + + static KleidiAI &getInstance() { + if(!instance) { + instance = new KleidiAI; + } + return *instance; + } + + ~KleidiAI() {} + + typedef struct KaiInfo { + bool kaiEnable = false; + bool asymmetric = false; //Asymmetric quantized model. + bool acthalf = false; // activation half precision. + bool blockwise = false; // weight quant using block wise. + bool dot = false; //CPU support sdot. + bool i8mm = false; //CPU support i8mm. 
+ } KaiInfo; + + //Kai util + void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize); + void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize); + + //Set info + void setEnable(bool enable); + void setModelAsymmetric(bool bAsymmetric); + + //Check + bool canAccelerate() { + return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && + !mKaiInfo.asymmetric && !mKaiInfo.acthalf && !mKaiInfo.blockwise); + } + + //Get info + size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } + size_t getNr() { return mKaiNr; } + size_t getKr() { return mKaiKr; } + size_t getSr() { return mKaiSr; } + size_t getMStep(size_t m = 1) { return (m == 1) ? mKaiMstepDotprod : mKaiMstepI8mm; } + size_t getNStep() { return mKaiNStep; } + size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep); } + + //Lhs + size_t getLhsQuantedPackedSize(size_t m, size_t k); + size_t getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k); + void runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked); + + //Rhs + size_t getRhsPackedSize(size_t n, size_t k); + size_t getRhsPackedOffset(size_t nIdx, size_t k); + void runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4 = false); + + //Dst + size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n) { return (nIdx * sizeof(float)) + mIdx * (n * sizeof(float)); } + + //Matmul + void runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst); + + private: + KleidiAI(bool bAsymmetric = false, bool acthalf = false, bool blockwise = false) { + const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); + mKaiInfo.dot = gCPUInfo.dot; + mKaiInfo.i8mm = gCPUInfo.i8mm; + mKaiInfo.kaiEnable = true; + mKaiInfo.asymmetric = bAsymmetric; + mKaiInfo.acthalf = acthalf; + mKaiInfo.blockwise = blockwise; + + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } + } + + static KleidiAI *instance; + KaiInfo mKaiInfo; + + const size_t mKaiMstepDotprod = 1; + const size_t mKaiMstepI8mm = 8; + const size_t mKaiNStep = 4; + + const size_t mKaiMrDotprod = 1; + const size_t mKaiMrI8mm = 4; + const size_t mKaiNr = 4; + const size_t mKaiKr = 16; + const size_t mKaiSr = 2; + }; +} \ No newline at end of file diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index bcba4eedb..4788d88c3 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -83,7 +83,7 @@ void ConvInt8TiledExecutor::reorderWeight(Tensor* weight, const uint8_t* weightS for (int y = 0; y < ic; ++y) { const int yOutSide = y / SRC_UNIT; const int yInSide = y % SRC_UNIT; - + int blockId = (yOutSide + k * icDivU) / blockL; int blockInsideId = (yOutSide + k * icDivU) % blockL; @@ -268,6 +268,39 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int oc = convOp->common()->outputCount(); int ic = convOp->common()->inputCount(); bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic); + +#ifdef MNN_KLEIDIAI_ENABLED + bool half_act = gcore->bytes == 2; + int biasSize = mResourceInt8->mOriginBias->size(); + int alphaSize = mResourceInt8->mOriginScale->size(); + bool blockwise = (biasSize * 2) != alphaSize; + KleidiAI kai = 
KleidiAI::getInstance(quanCommon->asymmetric, half_act, blockwise); + if(quanCommon->canUseInt4 && kai.canAccelerate()) { + int n = oc; + int k = ic; + int packedWeightSize = kai.getRhsPackedSize(n, k); + + //Alloc packed weight tensor. + mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); + bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); + + if (!success) { + MNN_ERROR("Out of static memory!\n"); + return; + } + + //Run rhs pack. + kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(), + mResourceInt8->mOriginScale->host(), + mResourceInt8->mOriginBias->host(), + mResourceInt8->mWeightInt8->host(), + directReadInt4weight); + + return; + } + +#endif + if (quanCommon->canUseInt4 && directReadInt4weight) { // int4 weight reorder mResourceInt8->mWeightAsymmetricQuant = true; @@ -276,7 +309,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int lU = UP_DIV(ic, SRC_UNIT); int hP = UNIT; int lP = SRC_UNIT; - + // weight shape. std::vector shape; if (SRC_UNIT > pack) { @@ -308,7 +341,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int blockkInsideId = j % blockL; for (int k = 0; k < cnt; ++k) { int dstIndx0 = (blockId * stride0 + i * stride1 + blockkInsideId * lP * hP) / 2 + (2 * k); - + int hpId0 = (2 * k + 1) / lP; int lpId0 = (2 * k) % lP; int hpId1 = (2 * (k + cnt) + 1) / lP; @@ -321,7 +354,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int s3 = (srcPtr[srcIndx1] & 15); int d0 = s0 * 16 + s2; int d1 = s1 * 16 + s3; - + dstPtr[dstIndx0] = d0; dstPtr[dstIndx0 + 1] = d1; } @@ -329,7 +362,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } } else { // std::shared_ptr srcWeight; - + if (quanCommon->canUseInt4) { mResourceInt8->mWeightAsymmetricQuant = true; auto srcPtr = reinterpret_cast(quanCommon->weight.get()); @@ -363,7 +396,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O dst0[j] = d; } } - + // Update int4 weight to mWeightInt8. 
mResourceInt8->mWeightInt8 = weightLow; } else { @@ -405,7 +438,7 @@ static void _computeAlphaScale(Backend* backend, const Convolution2D* conv2d, st auto alphaPtr = scaleBias->host(); auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); - + // Load quant scale and bias weightOrigin = resourceInt8->mWeightInt8->host(); auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 @@ -425,7 +458,7 @@ static void _computeAlphaScale(Backend* backend, const Convolution2D* conv2d, st } } resourceInt8->mOriginScale = scaleBias; - + // Compute float weightKernelSum resourceInt8->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); success = backend->onAcquireBuffer(resourceInt8->mWeightKernelSum.get(), Backend::STATIC); @@ -506,6 +539,25 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); +#ifdef MNN_KLEIDIAI_ENABLED + KleidiAI& kai = KleidiAI::getInstance(); + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { + int batch = inputs[0]->batch(); + int channel = inputs[0]->channel(); + + int packedSize = kai.getLhsQuantedPackedSize(batch, channel); + mTempIm2ColBuffer.reset(Tensor::createDevice({packedSize})); + bool success = backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + if (!success) { + MNN_ERROR("Out of dynamic memory!\n"); + return OUT_OF_MEMORY; + } + + backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + return NO_ERROR; + } +#endif + if (mResourceInt8->mDynamicQuant == false) { mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); CPUConvolution::onResize(inputs, outputs); @@ -660,6 +712,70 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto core = static_cast(backend())->int8Functions(); auto gcore = static_cast(backend())->functions(); +#ifdef MNN_KLEIDIAI_ENABLED + KleidiAI& kai = KleidiAI::getInstance(); + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { + const size_t m = input->batch(); //lhs vector number. + const size_t n = output->channel(); //rhs vector number. + const size_t k = input->channel(); //vector size. + + auto lhs = input->host(); + auto lhsPacked = mTempIm2ColBuffer->host(); + auto rhsPacked = mResourceInt8->mWeightInt8->host(); + auto dst = output->host(); + + int threadNum = static_cast(backend())->threadNumber(); + int threadNeed, vecPerThread; + +#if !KAI_CONV_NCHW_IN_OUT + kai.packNC4HW4ToNCHW((float *)lhs, m, k); +#endif + + //Dynamic quant pack lhs. + if(m == 1) { + kai.runLhsQuantPack(1, k, 1, lhs, lhsPacked); + } else { + vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m)); + threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1); + size_t srcStride = vecPerThread * k * sizeof(float); + + auto BatchDynamicQuant = [=, &kai](int tId) { + auto threadSrc = lhs + tId * srcStride; + auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(m, tId * vecPerThread, k); + int vecNum = (tId == threadNeed - 1) ? (m - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runLhsQuantPack(vecNum, k, kai.getMr(m), threadSrc, threadDst); + }; + + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + BatchDynamicQuant((int)tId); + } + MNN_CONCURRENCY_END(); + } + + //Run matmul. 
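+        //Split the n output channels across threads: each thread gets vecPerThread channels
+        //(rounded up to the RHS packing step NStep = 4), reads its own slice of the packed RHS
+        //and writes a disjoint column range of dst, so the threads need no synchronization.
+        //The last thread may get fewer channels when n is not an exact multiple of vecPerThread.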
+ vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep()); + threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1); + + auto ThreadFunction = [=, &kai](int tId) { + auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k); + auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n); + int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + }; + + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + ThreadFunction((int)tId); + } + MNN_CONCURRENCY_END(); + +#if !KAI_CONV_NCHW_IN_OUT + kai.packNCHWToNC4HW4((float *)dst, m, n); +#endif + + return NO_ERROR; + } +#endif + int UNIT__, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT__, &SRC_UNIT, &DST_XUNIT); auto blitProc = core->MNNPackC4Int8ForMatMul_A; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 2237e065b..442b3184a 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -19,6 +19,16 @@ #undef CONSTANT #endif // CONSTANT +#ifdef MNN_KLEIDIAI_ENABLED +#include "../backend/cpu/arm/kleidiAI/mnn_kleidiai.h" +/** + * Set DenseConvInt8TiledExecutor's input/output tensor format: + * KAI_CONV_NCHW_IN_OUT = 1: format will be NCHW, skip pack/unpack functions. + * KAI_CONV_NCHW_IN_OUT = 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. + **/ +#define KAI_CONV_NCHW_IN_OUT 1 +#endif + namespace MNN { struct TensorArrayAttr { // array size is dynamic or not diff --git a/source/geometry/GeometryConvUtils.cpp b/source/geometry/GeometryConvUtils.cpp index dd3b53cfa..21670bd24 100644 --- a/source/geometry/GeometryConvUtils.cpp +++ b/source/geometry/GeometryConvUtils.cpp @@ -247,12 +247,23 @@ std::shared_ptr GeometryConvUtils::im2Col(Tensor* im2Col, Tensor* input, return tempTensor; } bool GeometryConvUtils::computeSingle(const Op* op, const std::vector& inputs, const std::vector& outputs, GeometryComputer::Context& context, CommandBuffer& res) { +#if KAI_CONV_NCHW_IN_OUT + if(KleidiAI::getInstance().canAccelerate()) { + std::shared_ptr cmd(new Command); + cmd->op = op; + cmd->inputs = std::move(inputs); + cmd->outputs = std::move(outputs); + res.command.emplace_back(std::move(cmd)); + return true; + } +#endif auto newOutputs = outputs; auto newInputs = inputs; auto originOutput = outputs[0]; auto output = originOutput; auto inputDes = TensorUtils::getDescribe(newInputs[0]); auto format = inputDes->dimensionFormat; + if (MNN_DATA_FORMAT_NC4HW4 != format) { std::shared_ptr newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false)); ConvertUtils::compute(newInputs[0], newInput.get(), res); diff --git a/source/shape/ShapeTensorConvert.cpp b/source/shape/ShapeTensorConvert.cpp index a3d2035e0..899b9410b 100644 --- a/source/shape/ShapeTensorConvert.cpp +++ b/source/shape/ShapeTensorConvert.cpp @@ -23,6 +23,11 @@ class TensorConvertSizeComputer : public SizeComputer { sourceFmt = MNN_DATA_FORMAT_NCHW; } auto destFmt = info->dest(); +#if KAI_CONV_NCHW_IN_OUT + if(KleidiAI::getInstance().canAccelerate()) { + destFmt = MNN_DATA_FORMAT_NCHW; + } +#endif TensorUtils::getDescribe(outputs[0])->dimensionFormat = destFmt; if (destFmt == MNN_DATA_FORMAT_NC4HW4) { destFmt = MNN_DATA_FORMAT_NCHW;