From 83d7867db638dfe788d8b6936a37efbb1b1196d8 Mon Sep 17 00:00:00 2001
From: Stefania Hergane
Date: Wed, 4 Dec 2024 10:15:32 +0000
Subject: [PATCH] [NPU] Add NF4 precision support in NPU plugin

Signed-off-by: Stefania Hergane
---
 .../intel_npu/src/backend/src/zero_infer_request.cpp     | 7 +++++--
 .../src/compiler_adapter/src/driver_compiler_adapter.cpp | 2 ++
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp b/src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
index a0e5d2d11c1fef..1b419f962d3d0a 100644
--- a/src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
+++ b/src/plugins/intel_npu/src/backend/src/zero_infer_request.cpp
@@ -647,6 +647,8 @@ void ZeroInferRequest::check_network_precision(const ov::element::Type_t precisi
         break;
     case ov::element::Type_t::bf16:
         break;
+    case ov::element::Type_t::nf4:
+        break;
     case ov::element::Type_t::u4:
         break;
     case ov::element::Type_t::i4:
@@ -670,8 +672,9 @@ void ZeroInferRequest::check_network_precision(const ov::element::Type_t precisi
     case ov::element::Type_t::f64:
         break;
     default:
-        OPENVINO_THROW("Unsupported tensor precision: " + ov::element::Type(precision).get_type_name() +
-                       "! Supported precisions: FP32, FP16, BF16, U4, I4, U8, I8, U16, I16, U32, I32, U64, I64, FP64");
+        OPENVINO_THROW(
+            "Unsupported tensor precision: " + ov::element::Type(precision).get_type_name() +
+            "! Supported precisions: FP32, FP16, BF16, NF4, U4, I4, U8, I8, U16, I16, U32, I32, U64, I64, FP64");
     }
 }
 
diff --git a/src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp b/src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp
index 9d634656db109a..7769841b287eb1 100644
--- a/src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp
+++ b/src/plugins/intel_npu/src/compiler_adapter/src/driver_compiler_adapter.cpp
@@ -81,6 +81,8 @@ std::string ovPrecisionToLegacyPrecisionString(const ov::element::Type& precisio
         return "FP64";
     case ov::element::Type_t::bf16:
         return "BF16";
+    case ov::element::Type_t::nf4:
+        return "NF4";
     case ov::element::Type_t::i4:
         return "I4";
     case ov::element::Type_t::i8: