From 5c3bd50495c64a8d4071ffb0fca0c05e8667fbde Mon Sep 17 00:00:00 2001 From: Mateusz Date: Tue, 28 Nov 2023 13:10:08 +0100 Subject: [PATCH] Add shape tests --- .../scaled_dot_product_attention.cpp | 52 ++++++++++++++++++- 1 file changed, 50 insertions(+), 2 deletions(-) diff --git a/src/core/tests/type_prop/scaled_dot_product_attention.cpp b/src/core/tests/type_prop/scaled_dot_product_attention.cpp index 48fb1070ada873..cb401c02150f70 100644 --- a/src/core/tests/type_prop/scaled_dot_product_attention.cpp +++ b/src/core/tests/type_prop/scaled_dot_product_attention.cpp @@ -60,7 +60,7 @@ TEST(type_prop, scaled_dot_product_attention_static_3_inputs_causal) { EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 3, 6})); } -TEST(type_prop, scaled_dot_product_attention_static_iopored_attention_mask) { +TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask) { const auto query = std::make_shared(element::f32, PartialShape{2, 3, 4}); const auto key = std::make_shared(element::f32, PartialShape{2, 5, 4}); const auto value = std::make_shared(element::f32, PartialShape{2, 5, 6}); @@ -122,7 +122,7 @@ TEST(type_prop, scaled_dot_product_attention_static_3_inputs_extra_batch_causal_ EXPECT_EQ(op->get_output_partial_shape(0), (PartialShape{2, 7, 3, 6})); } -TEST(type_prop, scaled_dot_product_attention_static_iopored_attention_mask_extra_batch) { +TEST(type_prop, scaled_dot_product_attention_static_ignored_attention_mask_extra_batch) { const auto query = std::make_shared(element::f32, PartialShape{2, 7, 3, 4}); const auto key = std::make_shared(element::f32, PartialShape{2, 7, 5, 4}); const auto value = std::make_shared(element::f32, PartialShape{2, 7, 5, 6}); @@ -165,3 +165,51 @@ TEST(type_prop, scaled_dot_product_attention_dynamic_4d) { EXPECT_EQ(op->get_output_element_type(0), element::f32); EXPECT_EQ(op->get_output_partial_shape(0), (dynamic)); } +TEST(type_prop, scaled_dot_product_unsupported_key_shape) { + const auto query = std::make_shared(element::f32, PartialShape{2, 3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{3, 5, 4}); + const auto value = std::make_shared(element::f32, PartialShape{2, 5, 6}); + const auto attention_mask = std::make_shared(element::f32, PartialShape{3, 3, 3, 5}); + auto causal = false; + + OV_EXPECT_THROW(std::make_shared(query, key, value, attention_mask, causal), + AssertFailure, + testing::HasSubstr("Key input shape not compatible with other inputs.")); +} +TEST(type_prop, scaled_dot_product_unsupported_value_shape) { + const auto query = std::make_shared(element::f32, PartialShape{2, 3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{2, 5, 4}); + const auto value = std::make_shared(element::f32, PartialShape{3, 5, 6}); + const auto attention_mask = std::make_shared(element::f32, PartialShape{3, 3, 3, 5}); + auto causal = false; + + OV_EXPECT_THROW(std::make_shared(query, key, value, attention_mask, causal), + AssertFailure, + testing::HasSubstr("Value input shape not compatible with other inputs.")); +} + +TEST(type_prop, scaled_dot_product_unsupported_attention_shape) { + const auto query = std::make_shared(element::f32, PartialShape{2, 3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{2, 5, 4}); + const auto value = std::make_shared(element::f32, PartialShape{2, 5, 6}); + const auto attention_mask = std::make_shared(element::f32, PartialShape{3, 3, 3, 5}); + auto causal = false; + + OV_EXPECT_THROW(std::make_shared(query, key, value, attention_mask, causal), + AssertFailure, + testing::HasSubstr("Attention mask input shape not compatible with other inputs.")); +} + +TEST(type_prop, scaled_dot_product_unsupported_scale_shape) { + const auto query = std::make_shared(element::f32, PartialShape{2, 3, 4}); + const auto key = std::make_shared(element::f32, PartialShape{2, 5, 4}); + const auto value = std::make_shared(element::f32, PartialShape{2, 5, 6}); + const auto attention_mask = std::make_shared(element::f32, PartialShape{3, 3, 3, 5}); + const auto scale = std::make_shared(element::f32, PartialShape{1}); + auto causal = false; + + OV_EXPECT_THROW( + std::make_shared(query, key, value, attention_mask, scale, causal), + AssertFailure, + testing::HasSubstr("Scale input accepts only scalar tensor.")); +}