Skip to content

Commit

Permalink
Skip SplitM unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Jan 29, 2025
1 parent df9d66a commit 491001b
Showing 1 changed file with 67 additions and 67 deletions.
134 changes: 67 additions & 67 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -157,69 +157,69 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Dynamic_Transpose_fusion) {
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(24);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM) {
// const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{2, 64, 12, 64}, {12, 1, 64, 128}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
// false);
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(24);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(16);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
// const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{4, 32, 12, 64}, {12, 1, 64, 128}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
// true);
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(16);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(60);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
// const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
// false);
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(60);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(60);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
// const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{1, 12, 32, 16, 64}, {1, 16, 1, 64, 384}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
// true);
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(60);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(18);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
// const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{10, 18, 512, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(18);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(32);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
// const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}},
// std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
// std::vector<Shape>{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}});
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(32);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter) {
const auto &f = MHAFunction(std::vector<PartialShape>{{1, 128, 16, 64}, {1, 128, 16, 64}, {1, 16, 128, -1}, {1, 128, 16, 64}},
Expand All @@ -230,15 +230,15 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_4D_SplitM_DynamicParameter)
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
{8, 1, 64, 512}, {8, 512, 512}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(16);
run();
}
// TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHASelect_SplitM) {
// const auto& f = MHASelectSplitMFunction(std::vector<PartialShape>{{8, 512, 18}, {8, 18, 64}, {1, 512, 64}, {1, 1, 64}, {8, 64, 512}},
// std::vector<Shape>{{8, 2, 256, 18}, {8, 1, 18, 64}, {1, 2, 256, 64}, {1, 1, 1, 64},
// {8, 1, 64, 512}, {8, 512, 512}});
// model = f.getOriginal();
// model_ref = f.getReference();
// config.set_concurrency(16);
// run();
// }

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Reshape_extraction) {
const auto& f = MHAWithExtractedReshapeFunction(std::vector<PartialShape>{{400, 196, 80},
Expand Down

0 comments on commit 491001b

Please sign in to comment.