Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate kleidiAI release v0.3.0 into MNN 2.9.6 #2995

Merged
merged 8 commits into from
Oct 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ option(MNN_OPENCL "Enable OpenCL" OFF)
option(MNN_OPENGL "Enable OpenGL" OFF)
option(MNN_VULKAN "Enable Vulkan" OFF)
option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON)
option(MNN_KLEIDIAI "Enable KLEIDIAI" OFF)
option(MNN_ONEDNN "Enable oneDNN" OFF)
option(MNN_AVX512 "Enable AVX512" OFF)
option(MNN_CUDA "Enable CUDA" OFF)
Expand Down Expand Up @@ -253,6 +254,7 @@ message(STATUS "\tOpenCL: ${MNN_OPENCL}")
message(STATUS "\tOpenGL: ${MNN_OPENGL}")
message(STATUS "\tVulkan: ${MNN_VULKAN}")
message(STATUS "\tARM82: ${MNN_ARM82}")
message(STATUS "\tKleidiAI: ${MNN_KLEIDIAI}")
message(STATUS "\toneDNN: ${MNN_ONEDNN}")
message(STATUS "\tTensorRT: ${MNN_TENSORRT}")
message(STATUS "\tCoreML: ${MNN_COREML}")
Expand Down
7 changes: 7 additions & 0 deletions source/backend/cpu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,10 @@ IF(MNN_ARM82)
ENDIF()
ENDIF()

# Kleidi AI
IF(MNN_KLEIDIAI)
add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
include(${CMAKE_CURRENT_LIST_DIR}/arm/kleidiAI/CMakeLists.txt)
list(APPEND MNN_TARGETS MNN_KleidiAI)
list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_KleidiAI>)
ENDIF()
4 changes: 4 additions & 0 deletions source/backend/cpu/CPUBackend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
#include "core/BufferAllocator.hpp"
#include "MNN_generated.h"

#ifdef MNN_KLEIDIAI_ENABLED
#include "arm/kleidiAI/mnn_kleidiai.h"
#endif

namespace MNN {
class CPURuntime : public Runtime {
public:
Expand Down
63 changes: 63 additions & 0 deletions source/backend/cpu/arm/kleidiAI/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
#
# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#

project(MNN_KleidiAI
LANGUAGES C CXX ASM
)

set(KLEIDIAI_MIN_CLANG_VERSION 11)
set(KLEIDIAI_MIN_GNU_VERSION 11)

if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_CLANG_VERSION})
message(WARNING "KleidiAI: Using non-supported Clang version. Expected ${KLEIDIAI_MIN_CLANG_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()

if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_GNU_VERSION})
message(WARNING "KleidiAI: Using non-supported GCC version. Expected ${KLEIDIAI_MIN_GNU_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.")
endif()

list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp)
list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h)

add_library(
MNN_KleidiAI
SHARED
${MNN_KleidiAI_SOURCES} ${MNN_KleidiAI_HEADERS}
)

set(KLEIDIAI_SRC ${CMAKE_CURRENT_LIST_DIR})

include_directories(
${KLEIDIAI_SRC}/
${KLEIDIAI_SRC}/kai/
${KLEIDIAI_SRC}/kai/ukernels/
${KLEIDIAI_SRC}/kai/ukernels/matmul/
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)

set(KLEIDIAI_FILES_SCALAR
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c
)

set(KLEIDIAI_FILES_NEON_DOTPROD
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
)

set(KLEIDIAI_FILES_NEON_I8MM
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
)

# Selectively enable architecture features.
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR})
if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC)
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD})
target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_I8MM})

set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a)
set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod)
set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm)
endif()
194 changes: 194 additions & 0 deletions source/backend/cpu/arm/kleidiAI/kai/kai_common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,194 @@
//
// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
//
// SPDX-License-Identifier: Apache-2.0
//
#pragma once

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef __cplusplus
extern "C" {
#endif

// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)
//
// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros.
// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected.
// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting.
#define KAI_ERROR(msg) \
do { \
fflush(stdout); \
fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \
exit(EXIT_FAILURE); \
} while (0)

#define KAI_ASSERT_MSG(cond, msg) \
do { \
if (!(cond)) { \
KAI_ERROR(msg); \
} \
} while (0)

// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c)

#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond)

#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg)
#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond)

#define KAI_ASSUME_MSG KAI_ASSERT_MSG
#define KAI_ASSUME KAI_ASSERT
#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG
#define KAI_ASSUME_IF KAI_ASSERT_IF

#define KAI_UNUSED(x) (void)(x)
#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b))
#define KAI_MAX(a, b) (((a) > (b)) ? (a) : (b))

/// KleidiAI data types
/// Format: <byte 3>(reserved)|<byte 2>(num-bytes)|<byte 1>(type)|<byte 0>(variant-type)
enum kai_datatype {
kai_dt_unknown = 0x0000,
kai_dt_f32 = 0x0411,
kai_dt_f16 = 0x0212,
kai_dt_bf16 = 0x0213,
kai_dt_int32 = 0x0421,
kai_dt_int16 = 0x0222,
kai_dt_int8 = 0x0124,
kai_dt_uint32 = 0x0431,
kai_dt_uint16 = 0x0232,
kai_dt_uint8 = 0x0134,
kai_dt_bool = 0x0441
};

/// Gets number of bytes for a given data type
/// @param[in] dt KleidiAI data type
///
/// @return the numbers of bytes for the data type
inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) {
return (size_t)(dt >> 8);
}

/// Converts a scalar f16 value to f32
/// @param[in] f16 The f16 value
///
/// @return the f32 value
inline static float kai_cast_f32_f16(uint16_t f16) {
#if defined(__ARM_NEON)
__fp16 f32 = 0;
memcpy(&f32, &f16, sizeof(uint16_t));
return (float)f32;
#endif
}

/// Converts a scalar bf16 value to f32
/// @param[in] bf16 The f16 value
///
/// @return the f32 value
inline static float kai_cast_f32_bf16(uint16_t bf16) {
const uint32_t i32 = (bf16 << 16);
float f32;
memcpy(&f32, &i32, sizeof(i32));
return f32;
}

/// Converts a f32 value to bf16
/// @param[in] f32 The f32 value
///
/// @return the bf16 value
inline static uint16_t kai_cast_bf16_f32(float f32) {
uint16_t bf16 = 0;
#ifdef __ARM_FEATURE_BF16
__asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32));
#else
const uint32_t* i32 = (uint32_t*)(&f32);
bf16 = (*i32 >> 16);
#endif
return bf16;
}

/// Converts a scalar f32 value to f16
/// @param[in] f32 The f32 value
///
/// @return the f16 value
inline static uint16_t kai_cast_f16_f32(float f32) {
#if defined(__ARM_NEON)
uint16_t f16 = 0;
__fp16 tmp = f32;
memcpy(&f16, &tmp, sizeof(uint16_t));
return f16;
#endif
}

inline static size_t kai_roundup(size_t a, size_t b) {
return ((a + b - 1) / b) * b;
}

#ifdef __ARM_FEATURE_SVE

/// Gets the SME vector length for 8-bit elements.
inline static uint64_t kai_get_sme_vector_length_u8(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntb %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

/// Gets the SME vector length for 16-bit elements.
inline static uint64_t kai_get_sme_vector_length_u16(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cnth %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

/// Gets the SME vector length for 32-bit elements.
inline static uint64_t kai_get_sme_vector_length_u32(void) {
uint64_t res = 0;

__asm__ __volatile__(
".inst 0xd503477f // SMSTART ZA\n"
"cntw %0\n"
".inst 0xd503467f // SMSTOP\n"
: "=r"(res)
:
: "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16",
"z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31");

return res;
}

#endif // __ARM_FEATURE_SVE

/// Extends the sign bit of int 4-bit value (stored in int8_t variable)
/// @param[in] value The 4-bit int value
///
/// @return the int8_t value with sign extended
inline static int8_t kai_ext_sign_i8_i4(int8_t value) {
return (value ^ 0x8) - 8;
}

#ifdef __cplusplus
}
#endif
Loading
Loading