Skip to content

Commit

Permalink
Fix CPU test weights
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed May 16, 2024
1 parent 329755e commit e631787
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -128,23 +128,35 @@ const std::vector<InputShape> input_shapes = {
const std::vector<std::vector<size_t>> indices = {{0, 1, 2, 2, 3}, {4, 4, 3, 1, 0}, {1, 2, 1, 2, 1, 2, 1, 2, 1, 2}};
const std::vector<std::vector<size_t>> offsets = {{0, 2}, {0, 0, 2, 2}, {2, 4}};
const std::vector<size_t> default_index = {0, 4};
const std::vector<bool> with_weights = {false};
const std::vector<bool> with_default_index = {false, true};
const std::vector<ov::op::util::EmbeddingBagOffsetsBase::Reduction> reduction = {
ov::op::util::EmbeddingBagOffsetsBase::Reduction::SUM,
ov::op::util::EmbeddingBagOffsetsBase::Reduction::MEAN};

const auto embBagOffsetArgSet = ::testing::Combine(::testing::ValuesIn(input_shapes),
const auto embBagOffsetArgSetWthWeights = ::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(indices),
::testing::ValuesIn(offsets),
::testing::ValuesIn(default_index),
::testing::ValuesIn(with_weights),
::testing::Values(true),
::testing::ValuesIn(with_default_index),
::testing::Values(ov::op::util::EmbeddingBagOffsetsBase::Reduction::SUM));
const auto embBagOffsetArgSetNoWeights = ::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(indices),
::testing::ValuesIn(offsets),
::testing::ValuesIn(default_index),
::testing::Values(false),
::testing::ValuesIn(with_default_index),
::testing::ValuesIn(reduction));

INSTANTIATE_TEST_SUITE_P(smoke,
INSTANTIATE_TEST_SUITE_P(smoke_EmbeddingBagOffsets_With_Weights,
EmbeddingBagOffsetsLayerCPUTest,
::testing::Combine(embBagOffsetArgSetWthWeights,
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(indPrecisions),
::testing::Values(ov::test::utils::DEVICE_CPU)),
EmbeddingBagOffsetsLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_EmbeddingBagOffsets_No_Weights,
EmbeddingBagOffsetsLayerCPUTest,
::testing::Combine(embBagOffsetArgSet,
::testing::Combine(embBagOffsetArgSetNoWeights,
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(indPrecisions),
::testing::Values(ov::test::utils::DEVICE_CPU)),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,20 +119,31 @@ const std::vector<InputShape> input_shapes = {
const std::vector<std::vector<std::vector<size_t>>> indices = {{{0, 1}, {2, 2}, {3, 4}},
{{4, 4, 3}, {1, 0, 2}},
{{1, 2, 1, 2}, {1, 2, 1, 2}}};
const std::vector<bool> with_weights = {false};

const std::vector<ov::op::util::EmbeddingBagPackedBase::Reduction> reduction = {
ov::op::util::EmbeddingBagPackedBase::Reduction::SUM,
ov::op::util::EmbeddingBagPackedBase::Reduction::MEAN};

const auto embBagPackedArgSet = ::testing::Combine(::testing::ValuesIn(input_shapes),
const auto embBagPackedArgSetWthWeights = ::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(indices),
::testing::ValuesIn(with_weights),
::testing::Values(true),
::testing::Values(ov::op::util::EmbeddingBagPackedBase::Reduction::SUM));

const auto embBagPackedArgSetNoWeights = ::testing::Combine(::testing::ValuesIn(input_shapes),
::testing::ValuesIn(indices),
::testing::Values(false),
::testing::ValuesIn(reduction));

INSTANTIATE_TEST_SUITE_P(smoke,
INSTANTIATE_TEST_SUITE_P(smoke_EmbeddingBagPacked_With_Weights,
EmbeddingBagPackedLayerCPUTest,
::testing::Combine(embBagPackedArgSetWthWeights,
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(indPrecisions),
::testing::Values(ov::test::utils::DEVICE_CPU)),
EmbeddingBagPackedLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_EmbeddingBagPacked_No_Weights,
EmbeddingBagPackedLayerCPUTest,
::testing::Combine(embBagPackedArgSet,
::testing::Combine(embBagPackedArgSetNoWeights,
::testing::ValuesIn(netPrecisions),
::testing::ValuesIn(indPrecisions),
::testing::Values(ov::test::utils::DEVICE_CPU)),
Expand Down

0 comments on commit e631787

Please sign in to comment.