From 93d1fffae175f52bbcc4b32e8d9ef21a6d5ad616 Mon Sep 17 00:00:00 2001 From: Kirill Suvorov Date: Mon, 27 Jan 2025 09:21:49 +0100 Subject: [PATCH] [JS API] Fix inference for outputs with missing names Signed-off-by: Kirill Suvorov --- src/bindings/js/node/src/infer_request.cpp | 16 ++++++- src/bindings/js/node/src/node_output.cpp | 17 ++++++- .../js/node/tests/unit/infer_request.test.js | 44 +++++++++++++++++++ src/bindings/js/node/tests/unit/setup.js | 1 + src/bindings/js/node/tests/unit/utils.js | 18 +++++--- 5 files changed, 87 insertions(+), 9 deletions(-) diff --git a/src/bindings/js/node/src/infer_request.cpp b/src/bindings/js/node/src/infer_request.cpp index 0b5629adac6fe9..7c4bc387071387 100644 --- a/src/bindings/js/node/src/infer_request.cpp +++ b/src/bindings/js/node/src/infer_request.cpp @@ -146,7 +146,13 @@ Napi::Value InferRequestWrap::get_output_tensors(const Napi::CallbackInfo& info) auto tensor = _infer_request.get_tensor(node); auto new_tensor = ov::Tensor(tensor.get_element_type(), tensor.get_shape()); tensor.copy_to(new_tensor); - outputs_obj.Set(node.get_any_name(), TensorWrap::wrap(info.Env(), new_tensor)); + std::string name; + if (node.get_names().empty()) { + name = node.get_node()->get_name(); + } else { + name = node.get_any_name(); + } + outputs_obj.Set(name, TensorWrap::wrap(info.Env(), new_tensor)); } return outputs_obj; } @@ -215,7 +221,13 @@ void performInferenceThread(TsfnContext* context) { const auto& tensor = context->_ir->get_tensor(node); auto new_tensor = ov::Tensor(tensor.get_element_type(), tensor.get_shape()); tensor.copy_to(new_tensor); - outputs.insert({node.get_any_name(), new_tensor}); + std::string name; + if (node.get_names().empty()) { + name = node.get_node()->get_name(); + } else { + name = node.get_any_name(); + } + outputs.insert({name, new_tensor}); } context->result = outputs; diff --git a/src/bindings/js/node/src/node_output.cpp b/src/bindings/js/node/src/node_output.cpp index 3b3087f463a38c..face597a92b396 100644 --- a/src/bindings/js/node/src/node_output.cpp +++ b/src/bindings/js/node/src/node_output.cpp @@ -45,7 +45,14 @@ Napi::Value Output::get_partial_shape(const Napi::CallbackInfo& info) } Napi::Value Output::get_any_name(const Napi::CallbackInfo& info) { - return Napi::String::New(info.Env(), _output.get_any_name()); + std::string name; + if (_output.get_names().empty()) { + name = _output.get_node()->get_name(); + } else { + name = _output.get_any_name(); + } + + return Napi::String::New(info.Env(), name); } Output::Output(const Napi::CallbackInfo& info) @@ -88,5 +95,11 @@ Napi::Value Output::get_partial_shape(const Napi::CallbackInfo& } Napi::Value Output::get_any_name(const Napi::CallbackInfo& info) { - return Napi::String::New(info.Env(), _output.get_any_name()); + std::string name; + if (_output.get_names().empty()) { + name = _output.get_node()->get_name(); + } else { + name = _output.get_any_name(); + } + return Napi::String::New(info.Env(), name); } diff --git a/src/bindings/js/node/tests/unit/infer_request.test.js b/src/bindings/js/node/tests/unit/infer_request.test.js index 224781d4b80431..fca3e31404ce12 100644 --- a/src/bindings/js/node/tests/unit/infer_request.test.js +++ b/src/bindings/js/node/tests/unit/infer_request.test.js @@ -300,3 +300,47 @@ describe('ov.InferRequest tests', () => { }); }); }); + +describe('ov.InferRequest tests with missing outputs names', () => { + const modelV3Small = testModels.modelV3Small; + let compiledModel = null; + let tensorData = null; + let tensor = null; + let inferRequest = null; + + before(async () => { + await isModelAvailable(modelV3Small); + + const fs = require('fs'); + const core = new ov.Core(); + + let model_data = fs.readFileSync(getModelPath(modelV3Small).xml, 'utf8'); + const weights = fs.readFileSync(getModelPath(modelV3Small).bin); + model_data = model_data.replace("names=\"MobilenetV3/Predictions/Softmax:0\"", ""); + const model = core.readModelSync(Buffer.from(model_data, 'utf8'), weights); + + compiledModel = core.compileModelSync(model, 'CPU'); + inferRequest = compiledModel.createInferRequest(); + + tensorData = Float32Array.from( + { length: 150528 }, + () => Math.random() + epsilon, + ); + tensor = new ov.Tensor(ov.element.f32, modelV3Small.inputShape, tensorData); + }); + + it('Test infer(inputData: Tensor[])', () => { + const outputLayer = compiledModel.outputs[0]; + const result = inferRequest.infer([tensor]); + assert.deepStrictEqual(Object.keys(result), [outputLayer.toString()]); + assert.ok(result[outputLayer] instanceof ov.Tensor); + }); + + it('Test inferAsync(inputData: Tensor[])', () => { + inferRequest.inferAsync([tensor]).then((result) => { + const outputLayer = compiledModel.outputs[0]; + assert.deepStrictEqual(Object.keys(result), [outputLayer.toString()]); + assert.ok(result[outputLayer] instanceof ov.Tensor); + }); + }); +}); diff --git a/src/bindings/js/node/tests/unit/setup.js b/src/bindings/js/node/tests/unit/setup.js index b4885d24157abe..34c110c04e9833 100644 --- a/src/bindings/js/node/tests/unit/setup.js +++ b/src/bindings/js/node/tests/unit/setup.js @@ -6,4 +6,5 @@ if (require.main === module) { async function main() { await downloadTestModel(testModels.testModelFP32); + await downloadTestModel(testModels.modelV3Small); } diff --git a/src/bindings/js/node/tests/unit/utils.js b/src/bindings/js/node/tests/unit/utils.js index 232e4939f421c9..525abc6d272685 100644 --- a/src/bindings/js/node/tests/unit/utils.js +++ b/src/bindings/js/node/tests/unit/utils.js @@ -19,6 +19,16 @@ const testModels = { binURL: 'https://media.githubusercontent.com/media/openvinotoolkit/testdata/master/models/test_model/test_model_fp32.bin', }, + modelV3Small: { + xml: 'v3-small_224_1.0_float.xml', + bin: 'v3-small_224_1.0_float.bin', + inputShape: [1, 224, 224, 3], + outputShape: [1, 1001], + xmlURL: + 'https://storage.openvinotoolkit.org/repositories/openvino_notebooks/models/mobelinet-v3-tf/FP32/v3-small_224_1.0_float.xml', + binURL: + 'https://storage.openvinotoolkit.org/repositories/openvino_notebooks/models/mobelinet-v3-tf/FP32/v3-small_224_1.0_float.bin', + }, }; module.exports = { @@ -59,12 +69,10 @@ function sleep(ms) { return new Promise((resolve) => setTimeout(resolve, ms)); } -function getModelPath(isFP16 = false) { - const modelName = `test_model_fp${isFP16 ? 16 : 32}`; - +function getModelPath(model=testModels.testModelFP32) { return { - xml: path.join(modelDir, `${modelName}.xml`), - bin: path.join(modelDir, `${modelName}.bin`), + xml: path.join(modelDir, model.xml), + bin: path.join(modelDir, model.bin), }; }