Skip to content

Commit

Permalink
Defer weights loading by default in case of CiP (compiler-in-plugin)
Browse files Browse the repository at this point in the history
Signed-off-by: Bogdan Pereanu <bogdan.pereanu@intel.com>
  • Loading branch information
pereanub committed Jan 10, 2025
1 parent 65e6ab4 commit 2995fd4
Show file tree
Hide file tree
Showing 7 changed files with 21 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@ namespace intel_npu {

// Abstract interface exposing three pure-virtual operations: compile(),
// parse() and query(). DriverCompilerAdapter and PluginCompilerAdapter are
// declared as `final` implementations of this interface.
//
// NOTE(review): this text is a rendered diff with the +/- markers stripped.
// Going by the hunk header (-10,10 +10,9), the three declarations taking
// `const Config&` look like the removed pre-change lines and the three taking
// `Config&` look like their replacements - confirm against the real header
// before treating both sets as coexisting overloads.
class ICompilerAdapter {
public:
// Pre-change signatures (config passed as `const Config&`).
virtual std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model,
const Config& config) const = 0;
virtual std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, const Config& config) const = 0;
virtual ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const = 0;
// Post-change signatures: config is now a mutable reference, so the caller's
// Config object may be modified by the call. Presumably this is what lets the
// PluginGraph constructor in the same commit call config.update(...) - verify
// that every caller tolerates its config being changed.
// compile(): takes a model and a config, returns a shared IGraph.
virtual std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model, Config& config) const = 0;
// parse(): takes the network bytes by value and a config, returns a shared IGraph.
virtual std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, Config& config) const = 0;
// query(): takes a model and a config, returns an ov::SupportedOpsMap.
virtual ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, Config& config) const = 0;

// Virtual destructor so implementations can be destroyed through a base pointer.
virtual ~ICompilerAdapter() = default;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class DriverCompilerAdapter final : public ICompilerAdapter {
public:
DriverCompilerAdapter(const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct);

std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model, Config& config) const override;

std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, const Config& config) const override;
std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, Config& config) const override;

ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, Config& config) const override;

private:
/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ class PluginCompilerAdapter final : public ICompilerAdapter {
public:
PluginCompilerAdapter(const std::shared_ptr<ZeroInitStructsHolder>& zeroInitStruct);

std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
std::shared_ptr<IGraph> compile(const std::shared_ptr<const ov::Model>& model, Config& config) const override;

std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, const Config& config) const override;
std::shared_ptr<IGraph> parse(std::vector<uint8_t> network, Config& config) const override;

ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, const Config& config) const override;
ov::SupportedOpsMap query(const std::shared_ptr<const ov::Model>& model, Config& config) const override;

private:
std::shared_ptr<ZeroInitStructsHolder> _zeroInitStruct;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class PluginGraph final : public IGraph {
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
std::vector<uint8_t> blob,
const Config& config);
Config& config);

void export_blob(std::ostream& stream) const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ DriverCompilerAdapter::DriverCompilerAdapter(const std::shared_ptr<ZeroInitStruc
}

std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<const ov::Model>& model,
const Config& config) const {
Config& config) const {
OV_ITT_TASK_CHAIN(COMPILE_BLOB, itt::domains::NPUPlugin, "DriverCompilerAdapter", "compile");

const ze_graph_compiler_version_info_t& compilerVersion = _deviceGraphProperties.compilerVersion;
Expand Down Expand Up @@ -203,7 +203,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::compile(const std::shared_ptr<con
std::nullopt);
}

std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::vector<uint8_t> network, const Config& config) const {
std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::vector<uint8_t> network, Config& config) const {
OV_ITT_TASK_CHAIN(PARSE_BLOB, itt::domains::NPUPlugin, "DriverCompilerAdapter", "parse");

_logger.debug("parse start");
Expand All @@ -221,8 +221,7 @@ std::shared_ptr<IGraph> DriverCompilerAdapter::parse(std::vector<uint8_t> networ
std::optional<std::vector<uint8_t>>(std::move(network)));
}

ov::SupportedOpsMap DriverCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model,
const Config& config) const {
ov::SupportedOpsMap DriverCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model, Config& config) const {
OV_ITT_TASK_CHAIN(query_BLOB, itt::domains::NPUPlugin, "DriverCompilerAdapter", "query");

const ze_graph_compiler_version_info_t& compilerVersion = _deviceGraphProperties.compilerVersion;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ PluginCompilerAdapter::PluginCompilerAdapter(const std::shared_ptr<ZeroInitStruc
}

std::shared_ptr<IGraph> PluginCompilerAdapter::compile(const std::shared_ptr<const ov::Model>& model,
const Config& config) const {
Config& config) const {
OV_ITT_TASK_CHAIN(COMPILE_BLOB, itt::domains::NPUPlugin, "PluginCompilerAdapter", "compile");

_logger.debug("compile start");
Expand Down Expand Up @@ -103,7 +103,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::compile(const std::shared_ptr<con
config);
}

std::shared_ptr<IGraph> PluginCompilerAdapter::parse(std::vector<uint8_t> network, const Config& config) const {
std::shared_ptr<IGraph> PluginCompilerAdapter::parse(std::vector<uint8_t> network, Config& config) const {
OV_ITT_TASK_CHAIN(PARSE_BLOB, itt::domains::NPUPlugin, "PluginCompilerAdapter", "parse");

_logger.debug("parse start");
Expand All @@ -125,8 +125,7 @@ std::shared_ptr<IGraph> PluginCompilerAdapter::parse(std::vector<uint8_t> networ
config);
}

ov::SupportedOpsMap PluginCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model,
const Config& config) const {
ov::SupportedOpsMap PluginCompilerAdapter::query(const std::shared_ptr<const ov::Model>& model, Config& config) const {
OV_ITT_TASK_CHAIN(QUERY_BLOB, itt::domains::NPUPlugin, "PluginCompilerAdapter", "query");

return _compiler->query(model, config);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,16 @@ PluginGraph::PluginGraph(const std::shared_ptr<ZeGraphExtWrappers>& zeGraphExt,
ze_graph_handle_t graphHandle,
NetworkMetadata metadata,
std::vector<uint8_t> blob,
const Config& config)
Config& config)
: IGraph(graphHandle, std::move(metadata), config, std::optional<std::vector<uint8_t>>(std::move(blob))),
_zeGraphExt(zeGraphExt),
_zeroInitStruct(zeroInitStruct),
_compiler(compiler),
_logger("PluginGraph", config.get<LOG_LEVEL>()) {
if (!(config.has<CREATE_EXECUTOR>() || config.has<DEFER_WEIGHTS_LOAD>())) {
config.update({{ov::intel_npu::defer_weights_load.name(), "YES"}});
}

if (!config.get<CREATE_EXECUTOR>() || config.get<DEFER_WEIGHTS_LOAD>()) {
_logger.info("Graph initialize is deferred from the \"Graph\" constructor");
return;
Expand Down

0 comments on commit 2995fd4

Please sign in to comment.