Skip to content

Commit

Permalink
Use OV RTTI for ConversionExtensions (#28834)
Browse files Browse the repository at this point in the history
### Details:
 - Use OV RTTI for ConversionExtensions

### Tickets:
 - CVS-160510
  • Loading branch information
ilya-lavrenov authored Feb 8, 2025
1 parent ffc15bf commit b5a8d80
Show file tree
Hide file tree
Showing 38 changed files with 105 additions and 120 deletions.
1 change: 0 additions & 1 deletion cmake/developer_package/compile_flags/os_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,6 @@ endfunction()
# ov_target_link_libraries_as_system(<TARGET NAME> <PUBLIC | PRIVATE | INTERFACE> <target1 target2 ...>)
#
function(ov_target_link_libraries_as_system TARGET_NAME LINK_TYPE)
message("Link to ${TARGET_NAME} using ${LINK_TYPE} the following ${ARGN}")
target_link_libraries(${TARGET_NAME} ${LINK_TYPE} ${ARGN})

# include directories as SYSTEM
Expand Down
8 changes: 4 additions & 4 deletions cmake/developer_package/ncc_naming_style/openvino.style
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
# custom OpenVINO values
CppMethod: '^(operator\W+|[a-z_\d]+|signaling_NaN|quiet_NaN|OPENVINO_OP)$'
CppMethod: '^(operator\W+|[a-z_\d]+|signaling_NaN|quiet_NaN|OPENVINO_OP|OPENVINO_RTTI)$'
ClassName: '^([A-Z][\w]+|b?float16|float8_e4m3|float8_e5m2|float4_e2m1|float8_e8m0|numeric_limits|ngraph_error|stopwatch|unsupported_op)$'
StructName: '^([A-Z][\w]+|element_type_traits|hash|oi_pair|stat)$'
FunctionName: '^(operator\W+|[a-z_\d]+)|PrintTo$'
Namespace: '^([a-z\d_]*|InferenceEngine)$'
NamespaceAlias: '^([a-z\d_]+|InferenceEngine)$'
Namespace: '^([a-z\d_]*)$'
NamespaceAlias: '^([a-z\d_]+)$'
UnionName: '[A-Z][\w]+$'
TemplateTemplateParameter: '[A-Z][\w]+'
NamespaceReference: '^([a-z\d_]+|InferenceEngine|GPUContextParams)$'
NamespaceReference: '^([a-z\d_]+)$'
TemplateNonTypeParameter: '^\w*$'
ClassTemplate: '^([A-Z][\w]+|element_type_traits)$'
TemplateTypeParameter: '^\w*$'
Expand Down
12 changes: 11 additions & 1 deletion src/cmake/openvino.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,18 @@ ov_add_vs_version_file(NAME ${TARGET_NAME} FILEDESCRIPTION "OpenVINO runtime lib

target_include_directories(${TARGET_NAME} PUBLIC
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/core/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/inference/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/common/include>)

# to be aligned with OpenVINO archive, where all headers are located in the same folder and
# exposed via openvino::runtime
target_include_directories(${TARGET_NAME} INTERFACE
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/common/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/inference/include>)
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/onnx/frontend/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/paddle/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/pytorch/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/tensorflow/include>
$<BUILD_INTERFACE:${OpenVINO_SOURCE_DIR}/src/frontends/tensorflow_lite/include>)

target_link_libraries(${TARGET_NAME} PRIVATE openvino::reference
openvino::shape_inference
Expand Down
15 changes: 12 additions & 3 deletions src/core/include/openvino/core/extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "openvino/core/core_visibility.hpp"
#include "openvino/core/type.hpp"
#include "openvino/core/rtti.hpp"

#define OPENVINO_EXTENSION_C_API OPENVINO_EXTERN_C OPENVINO_CORE_EXPORTS
#define OPENVINO_EXTENSION_API OPENVINO_CORE_EXPORTS
Expand All @@ -24,6 +25,14 @@ class Extension;
*/
class OPENVINO_API Extension {
public:
_OPENVINO_HIDDEN_METHOD static const DiscreteTypeInfo& get_type_info_static() {
static const ::ov::DiscreteTypeInfo type_info_static{"Extension"};
return type_info_static;
}
virtual const DiscreteTypeInfo& get_type_info() const {
return get_type_info_static();
}

using Ptr = std::shared_ptr<Extension>;

virtual ~Extension();
Expand All @@ -37,15 +46,15 @@ class OPENVINO_API Extension {
/**
* @brief The entry point for library with OpenVINO extensions
*
* @param vector of extensions
* @param ext of extensions
*/
OPENVINO_EXTENSION_C_API
void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>&);
void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>& ext);

/**
* @brief Macro generates the entry point for the library
*
* @param vector of extensions
* @param ext of extensions
*/
#define OPENVINO_CREATE_EXTENSIONS(extensions) \
OPENVINO_EXTENSION_C_API void OV_CREATE_EXTENSION(std::vector<ov::Extension::Ptr>& ext); \
Expand Down
7 changes: 1 addition & 6 deletions src/core/include/openvino/core/op_extension.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@ namespace ov {
class OPENVINO_API BaseOpExtension : public Extension {
public:
using Ptr = std::shared_ptr<BaseOpExtension>;
/**
* @brief Returns the type info of operation
*
* @return ov::DiscreteTypeInfo
*/
virtual const ov::DiscreteTypeInfo& get_type_info() const = 0;

/**
* @brief Method creates an OpenVINO operation
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class OPENVINO_API TensorInfoMemoryType : public RuntimeAttribute {

TensorInfoMemoryType() = default;

~TensorInfoMemoryType() override;

explicit TensorInfoMemoryType(const std::string& value) : value(value) {}

bool visit_attributes(AttributeVisitor& visitor) override {
Expand Down
4 changes: 0 additions & 4 deletions src/core/include/openvino/core/rtti.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,3 @@
_OPENVINO_RTTI_WITH_TYPE_VERSION_PARENT, \
_OPENVINO_RTTI_WITH_TYPE_VERSION, \
_OPENVINO_RTTI_WITH_TYPE)(__VA_ARGS__))

/// Note: Please don't use this macros for new operations
#define BWDCMP_RTTI_DECLARATION
#define BWDCMP_RTTI_DEFINITION(CLASS)
4 changes: 1 addition & 3 deletions src/core/include/openvino/core/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,7 @@ OPENVINO_API
std::ostream& operator<<(std::ostream& s, const Shape& shape);

template <>
class OPENVINO_API AttributeAdapter<ov::Shape> : public IndirectVectorValueAccessor<ov::Shape, std::vector<int64_t>>

{
class OPENVINO_API AttributeAdapter<ov::Shape> : public IndirectVectorValueAccessor<ov::Shape, std::vector<int64_t>> {
public:
OPENVINO_RTTI("AttributeAdapter<Shape>");

Expand Down
33 changes: 20 additions & 13 deletions src/core/include/openvino/core/type.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,18 @@ struct OPENVINO_API DiscreteTypeInfo {
OPENVINO_API
std::ostream& operator<<(std::ostream& s, const DiscreteTypeInfo& info);

namespace frontend {
class ConversionExtensionBase;
} // frontend

template <typename T>
constexpr bool use_ov_dynamic_cast() {
#if defined(__ANDROID__) || defined(ANDROID)
# define OPENVINO_DYNAMIC_CAST
return true;
#else
return std::is_base_of_v<ov::frontend::ConversionExtensionBase, T>;
#endif
}

/// \brief Tests if value is a pointer/shared_ptr that can be statically cast to a
/// Type*/shared_ptr<Type>
Expand All @@ -97,11 +106,10 @@ template <typename Type, typename Value>
typename std::enable_if<std::is_convertible<decltype(static_cast<Type*>(std::declval<Value>())), Type*>::value,
Type*>::type
as_type(Value value) {
#ifdef OPENVINO_DYNAMIC_CAST
return ov::is_type<Type>(value) ? static_cast<Type*>(value) : nullptr;
#else
return dynamic_cast<Type*>(value);
#endif
if constexpr (use_ov_dynamic_cast<Type>())
return is_type<Type>(value) ? static_cast<Type*>(value) : nullptr;
else
return dynamic_cast<Type*>(value);
}

namespace util {
Expand All @@ -120,13 +128,12 @@ struct AsTypePtr<std::shared_ptr<In>> {

/// Casts a std::shared_ptr<Value> to a std::shared_ptr<Type> if it is of type
/// Type, nullptr otherwise
template <typename T, typename U>
auto as_type_ptr(const U& value) -> decltype(::ov::util::AsTypePtr<U>::template call<T>(value)) {
#ifdef OPENVINO_DYNAMIC_CAST
return ::ov::util::AsTypePtr<U>::template call<T>(value);
#else
return std::dynamic_pointer_cast<T>(value);
#endif
template <typename Type, typename Value>
auto as_type_ptr(const Value& value) -> decltype(::ov::util::AsTypePtr<Value>::template call<Type>(value)) {
if constexpr (use_ov_dynamic_cast<Type>())
return ::ov::util::AsTypePtr<Value>::template call<Type>(value);
else
return std::dynamic_pointer_cast<Type>(value);
}
} // namespace ov

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ class OPENVINO_API PrecisionSensitive : public RuntimeAttribute {

PrecisionSensitive() = default;

~PrecisionSensitive() override;

bool is_copyable() const override {
return false;
}
Expand Down
1 change: 1 addition & 0 deletions src/core/include/openvino/op/util/symbolic_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class OPENVINO_API SkipInvalidation : public RuntimeAttribute {
public:
OPENVINO_RTTI("SkipInvalidation", "0", RuntimeAttribute);
SkipInvalidation() = default;
~SkipInvalidation() override;
bool is_copyable() const override {
return false;
}
Expand Down
2 changes: 2 additions & 0 deletions src/core/src/op/util/precision_sensitive_attribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

#include "openvino/op/util/precision_sensitive_attribute.hpp"

ov::PrecisionSensitive::~PrecisionSensitive() = default;

void ov::mark_as_precision_sensitive(ov::Input<ov::Node> node_input) {
auto& rt_info = node_input.get_rt_info();
rt_info[PrecisionSensitive::get_type_info_static()] = PrecisionSensitive{};
Expand Down
2 changes: 2 additions & 0 deletions src/core/src/op/util/symbolic_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

#include "openvino/op/util/multi_subgraph_base.hpp"

ov::SkipInvalidation::~SkipInvalidation() = default;

void ov::skip_invalidation(const ov::Output<ov::Node>& output) {
output.get_tensor().get_rt_info()[ov::SkipInvalidation::get_type_info_static()] = nullptr;
}
Expand Down
3 changes: 3 additions & 0 deletions src/core/src/preprocess/pre_post_process.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ std::shared_ptr<Model> PrePostProcessor::build() {
return function;
}

// ------------------ TensorInfoMemoryType ----------------
TensorInfoMemoryType::~TensorInfoMemoryType() = default;

// --------------------- InputTensorInfo ------------------
InputTensorInfo::InputTensorInfo() : m_impl(std::unique_ptr<InputTensorInfoImpl>(new InputTensorInfoImpl())) {}
InputTensorInfo::~InputTensorInfo() = default;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
namespace ov {
namespace frontend {

class FRONTEND_API ConversionExtensionBase : public ov::Extension {
class FRONTEND_API ConversionExtensionBase : public Extension {
public:
OPENVINO_RTTI("ConversionExtensionBase", "0", Extension);

using Ptr = std::shared_ptr<ConversionExtensionBase>;
explicit ConversionExtensionBase(const std::string& op_type) : m_op_type(op_type) {}

Expand All @@ -28,6 +30,8 @@ class FRONTEND_API ConversionExtensionBase : public ov::Extension {

class FRONTEND_API ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("ConversionExtension", "", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;
ConversionExtension(const std::string& op_type, const CreatorFunction& converter)
: ConversionExtensionBase(op_type),
Expand Down
3 changes: 2 additions & 1 deletion src/frontends/ir/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,9 @@ void FrontEnd::add_extension(const ov::Extension::Ptr& ext) {
if (std::dynamic_pointer_cast<ov::BaseOpExtension>(so_ext->extension())) {
m_extensions.emplace_back(so_ext->extension());
}
} else if (std::dynamic_pointer_cast<ov::BaseOpExtension>(ext))
} else if (std::dynamic_pointer_cast<ov::BaseOpExtension>(ext)) {
m_extensions.emplace_back(ext);
}
}

InputModel::Ptr FrontEnd::load_impl(const std::vector<ov::Any>& variants) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ namespace ov {
namespace frontend {
namespace jax {

class JAX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("frontend::jax::ConversionExtension", "", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension() = delete;
Expand All @@ -27,8 +29,6 @@ class JAX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
return m_converter;
}

~ConversionExtension() override;

private:
ov::frontend::CreatorFunction m_converter;
};
Expand Down
7 changes: 0 additions & 7 deletions src/frontends/jax/src/extensions.cpp

This file was deleted.

4 changes: 2 additions & 2 deletions src/frontends/jax/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,12 @@ std::shared_ptr<Model> FrontEnd::decode(const InputModel::Ptr& model) const {
}

void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
if (auto conv_ext = ov::as_type_ptr<ov::frontend::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(conv_ext);
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
};
} else if (auto conv_ext = std::dynamic_pointer_cast<ov::frontend::jax::ConversionExtension>(extension)) {
} else if (auto conv_ext = ov::as_type_ptr<ov::frontend::jax::ConversionExtension>(extension)) {
m_conversion_extensions.push_back(conv_ext);
m_op_extension_translators[conv_ext->get_op_type()] = [=](const NodeContext& context) {
return conv_ext->get_converter()(context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,10 @@
namespace ov {
namespace frontend {
namespace onnx {
class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("frontend::onnx::ConversionExtension", "", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension(const std::string& op_type, const ov::frontend::CreatorFunction& converter)
Expand All @@ -26,8 +28,6 @@ class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
m_domain{domain},
m_converter(converter) {}

~ConversionExtension() override;

const std::string& get_domain() const {
return m_domain;
}
Expand All @@ -37,9 +37,10 @@ class ONNX_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
}

private:
std::string m_domain = "";
std::string m_domain;
ov::frontend::CreatorFunction m_converter;
};

} // namespace onnx
} // namespace frontend
} // namespace ov
4 changes: 2 additions & 2 deletions src/frontends/onnx/frontend/src/core/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ bool common_node_for_all_outputs(const ov::OutputVector& outputs) {
OperatorsBridge register_extensions(OperatorsBridge& bridge,
const std::vector<ov::frontend::ConversionExtensionBase::Ptr>& conversions) {
for (const auto& extension : conversions) {
if (const auto common_conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
if (const auto common_conv_ext = ov::as_type_ptr<ov::frontend::ConversionExtension>(extension)) {
bridge.overwrite_operator(
common_conv_ext->get_op_type(),
"",
[common_conv_ext](const ov::frontend::onnx::Node& node) -> ov::OutputVector {
return common_conv_ext->get_converter()(ov::frontend::onnx::NodeContext(node));
});
} else if (const auto onnx_conv_ext =
std::dynamic_pointer_cast<ov::frontend::onnx::ConversionExtension>(extension)) {
ov::as_type_ptr<ov::frontend::onnx::ConversionExtension>(extension)) {
bridge.overwrite_operator(onnx_conv_ext->get_op_type(),
onnx_conv_ext->get_domain(),
[onnx_conv_ext](const ov::frontend::onnx::Node& node) -> ov::OutputVector {
Expand Down
7 changes: 0 additions & 7 deletions src/frontends/onnx/frontend/src/extensions.cpp

This file was deleted.

4 changes: 2 additions & 2 deletions src/frontends/onnx/frontend/src/frontend.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ void FrontEnd::add_extension(const std::shared_ptr<ov::Extension>& extension) {
} else if (const auto& so_ext = std::dynamic_pointer_cast<ov::detail::SOExtension>(extension)) {
add_extension(so_ext->extension());
m_other_extensions.push_back(so_ext);
} else if (auto common_conv_ext = std::dynamic_pointer_cast<ov::frontend::ConversionExtension>(extension)) {
} else if (auto common_conv_ext = ov::as_type_ptr<ov::frontend::ConversionExtension>(extension)) {
m_extensions.conversions.push_back(common_conv_ext);
} else if (const auto onnx_conv_ext = std::dynamic_pointer_cast<onnx::ConversionExtension>(extension)) {
} else if (const auto onnx_conv_ext = ov::as_type_ptr<onnx::ConversionExtension>(extension)) {
m_extensions.conversions.push_back(onnx_conv_ext);
} else if (auto progress_reporter = std::dynamic_pointer_cast<ProgressReporterExtension>(extension)) {
m_extensions.progress_reporter = progress_reporter;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@ namespace ov {
namespace frontend {
namespace paddle {

class PADDLE_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
class ConversionExtension : public ConversionExtensionBase {
public:
OPENVINO_RTTI("frontend::paddle::ConversionExtension", "", ConversionExtensionBase);

using Ptr = std::shared_ptr<ConversionExtension>;

ConversionExtension() = delete;
Expand All @@ -23,8 +25,6 @@ class PADDLE_FRONTEND_API ConversionExtension : public ConversionExtensionBase {
: ConversionExtensionBase(op_type),
m_converter(converter) {}

~ConversionExtension() override;

const ov::frontend::CreatorFunctionNamed& get_converter() const {
return m_converter;
}
Expand Down
Loading

0 comments on commit b5a8d80

Please sign in to comment.