Skip to content

Commit

Permalink
Add shape tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mmikolajcz committed Nov 28, 2023
1 parent fe7f486 commit 5c3bd50
Showing 1 changed file with 50 additions and 2 deletions.
52 changes: 50 additions & 2 deletions src/core/tests/type_prop/scaled_dot_product_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ TEST(type_prop, scaled_dot_product_attention_static_3_inputs_causal) {
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_iopored_attention_mask) {
TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
Expand Down Expand Up @@ -122,7 +122,7 @@ TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch_causal_
EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6}));
}

TEST(type_prop, scaled_dot_product_attention_static_iopored_attention_mask_extra_batch) {
TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask_extra_batch) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 7, 5, 6});
Expand Down Expand Up @@ -165,3 +165,51 @@ TEST(type_prop, scaled_dot_product_attention_dynamic_4d) {
EXPECT_EQ(op->get_output_element_type(0), element::f32);
EXPECT_EQ(op->get_output_partial_shape(0), (dynamic));
}
TEST(type_prop, scaled_dot_product_unsupported_key_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;

OV_EXPECT_THROW(std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Key input shape not compatible with other inputs."));
}
TEST(type_prop, scaled_dot_product_unsupported_value_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;

OV_EXPECT_THROW(std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Value input shape not compatible with other inputs."));
}

TEST(type_prop, scaled_dot_product_unsupported_attention_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
auto causal = false;

OV_EXPECT_THROW(std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, causal),
AssertFailure,
testing::HasSubstr("Attention mask input shape not compatible with other inputs."));
}

TEST(type_prop, scaled_dot_product_unsupported_scale_shape) {
const auto query = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 3, 4});
const auto key = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 4});
const auto value = std::make_shared<opset13::Parameter>(element::f32, PartialShape{2, 5, 6});
const auto attention_mask = std::make_shared<opset13::Parameter>(element::f32, PartialShape{3, 3, 3, 5});
const auto scale = std::make_shared<opset13::Parameter>(element::f32, PartialShape{1});
auto causal = false;

OV_EXPECT_THROW(
std::make_shared<opset13::ScaledDotProductAttention>(query, key, value, attention_mask, scale, causal),
AssertFailure,
testing::HasSubstr("Scale input accepts only scalar tensor."));
}

0 comments on commit 5c3bd50

Please sign in to comment.