Skip to content

Commit

Permalink
[CPU] Fixed accuracy of compressed matmul with scalar scale (openvinotoolkit#28431)
Browse files Browse the repository at this point in the history

### Details:
 - This PR fixes the accuracy of compressed MatMuls with a scalar scale.
 - The issue was at the micro-kernel level: an incorrect offset was used when loading the scales.

oneDNN PR: openvinotoolkit/oneDNN#270

### Tickets:
 - [CVS-143420](https://jira.devtools.intel.com/browse/CVS-143420)
  • Loading branch information
dmitry-gorokhov authored and MirceaDan99 committed Jan 22, 2025
1 parent 139ae28 commit fe3ebb1
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 61 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights,
::testing::ValuesIn(decompression_precisions),
::testing::Values(ov::element::undefined),
::testing::Values(true),
::testing::Values(DecompressionSubtractType::full),
::testing::Values(DecompressionType::full),
::testing::Values(DecompressionType::full),
::testing::Values(false),
::testing::ValuesIn(filter_additional_config_basic()),
::testing::ValuesIn(fusing_params),
Expand All @@ -61,9 +62,9 @@ const std::vector<MatMulDecompressionShapeParams> input_shapes_corner_cases = {
};

const std::vector<bool> transpose_weights = {true, false};
const std::vector<DecompressionSubtractType> decompression_subtract_type = {DecompressionSubtractType::full,
DecompressionSubtractType::scalar,
DecompressionSubtractType::empty};
const std::vector<DecompressionType> decompression_subtract_type = {DecompressionType::full,
DecompressionType::scalar,
DecompressionType::empty};
const std::vector<bool> reshape_on_decompression = {true, false};
const std::vector<ov::test::ElementType> decompression_precisions_corner_cases = {ov::element::f16, ov::element::f32};

Expand All @@ -74,6 +75,7 @@ INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases,
::testing::ValuesIn(decompression_precisions_corner_cases),
::testing::Values(ov::element::undefined),
::testing::ValuesIn(transpose_weights),
::testing::Values(DecompressionType::full),
::testing::ValuesIn(decompression_subtract_type),
::testing::ValuesIn(reshape_on_decompression),
::testing::ValuesIn(filter_additional_config_basic()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ std::string MatmulWeightsDecompression::getTestCaseName(testing::TestParamInfo<M
ov::test::ElementType decompression_precision;
ov::test::ElementType scale_precision;
bool transpose;
DecompressionSubtractType decompression_subtract_type;
DecompressionType decompression_multiply_type;
DecompressionType decompression_subtract_type;
bool reshape_on_decompression;
ov::AnyMap additional_config;
fusingSpecificParams fusing_params;
Expand All @@ -27,6 +28,7 @@ std::string MatmulWeightsDecompression::getTestCaseName(testing::TestParamInfo<M
decompression_precision,
scale_precision,
transpose,
decompression_multiply_type,
decompression_subtract_type,
reshape_on_decompression,
additional_config,
Expand All @@ -39,6 +41,7 @@ std::string MatmulWeightsDecompression::getTestCaseName(testing::TestParamInfo<M
result << "decompression_precision=" << decompression_precision << "_";
result << "scale_precision=" << scale_precision << "_";
result << "transpose_weights=" << transpose << "_";
result << "decompression_multiply=" << decompression_multiply_type << "_";
result << "decompression_subtract=" << decompression_subtract_type << "_";
result << "reshape_on_decompression=" << reshape_on_decompression << "_";

Expand All @@ -60,7 +63,8 @@ std::shared_ptr<ov::Model> MatmulWeightsDecompression::initSubgraph(const ov::Pa
const ov::element::Type decompression_precision,
const ov::element::Type scale_precision,
const bool transpose_weights,
const DecompressionSubtractType decompression_subtract_type,
const DecompressionType decompression_multiply_type,
const DecompressionType decompression_subtract_type,
const bool reshape_on_decompression) {
ov::ParameterVector params{std::make_shared<ov::op::v0::Parameter>(data_precision, data_shape)};
const auto weights_subgraph = initMatMulDecompressionSubgraph(weights_shape,
Expand All @@ -70,6 +74,7 @@ std::shared_ptr<ov::Model> MatmulWeightsDecompression::initSubgraph(const ov::Pa
decompression_precision,
scale_precision,
transpose_weights,
decompression_multiply_type,
decompression_subtract_type,
reshape_on_decompression);
auto matMul = std::make_shared<ov::op::v0::MatMul>(params[0], weights_subgraph);
Expand All @@ -84,7 +89,8 @@ void MatmulWeightsDecompression::SetUp() {
ov::test::ElementType decompression_precision;
ov::test::ElementType scale_precision;
bool transpose_weights;
DecompressionSubtractType decompression_subtract_type;
DecompressionType decompression_multiply_type;
DecompressionType decompression_subtract_type;
bool reshape_on_decompression;
ov::AnyMap additional_config;
fusingSpecificParams fusing_params;
Expand All @@ -95,6 +101,7 @@ void MatmulWeightsDecompression::SetUp() {
decompression_precision,
scale_precision,
transpose_weights,
decompression_multiply_type,
decompression_subtract_type,
reshape_on_decompression,
additional_config,
Expand Down Expand Up @@ -131,14 +138,15 @@ void MatmulWeightsDecompression::SetUp() {
decompression_precision,
scale_precision,
transpose_weights,
decompression_multiply_type,
decompression_subtract_type,
reshape_on_decompression);
}

void MatmulWeightsDecompression::check_results() {
const auto& test_param = GetParam();
const ov::element::Type compressed_weights_precision = std::get<1>(test_param);
const bool use_matmul_decompression_impl = std::get<9>(test_param);
const bool use_matmul_decompression_impl = std::get<10>(test_param);

const auto runtime_model = compiledModel.get_runtime_model();
const auto result = runtime_model->get_result();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ typedef std::tuple<MatMulDecompressionShapeParams,
ov::test::ElementType, // decompression precision
ov::test::ElementType, // scale precision
bool, // transpose on weights
DecompressionSubtractType, // decompression subtract type
DecompressionType, // decompression multiply type
DecompressionType, // decompression subtract type
bool, // reshape on decompression constants
ov::AnyMap, // additional config
fusingSpecificParams,
Expand All @@ -67,7 +68,8 @@ class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeig
const ov::element::Type decompression_precision,
const ov::element::Type scale_precision,
const bool transpose_weights,
const DecompressionSubtractType decompression_subtract_type,
const DecompressionType decompression_multiply_type,
const DecompressionType decompression_subtract_type,
const bool reshape_on_decompression);

void SetUp() override;
Expand Down
Loading

0 comments on commit fe3ebb1

Please sign in to comment.