diff --git a/rust/examples/audio.rs b/rust/examples/audio.rs index 55f114c4..1a3be6a9 100644 --- a/rust/examples/audio.rs +++ b/rust/examples/audio.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use embed_anything::{ - config::{self, TextEmbedConfig}, + config::TextEmbedConfig, emb_audio, embeddings::embed::Embedder, file_processor::audio::audio_processor::AudioDecoderModel, diff --git a/tests/model_tests/conftest.py b/tests/model_tests/conftest.py index 1d2bff8b..f3ab6a93 100644 --- a/tests/model_tests/conftest.py +++ b/tests/model_tests/conftest.py @@ -85,6 +85,10 @@ def openai_model() -> EmbeddingModel: ) return model +@pytest.fixture +def onnx_model() -> EmbeddingModel: + model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q") + return model class DummyAdapter(Adapter): diff --git a/tests/model_tests/test_bert.py b/tests/model_tests/test_bert.py index 98f3cadd..6e808d46 100644 --- a/tests/model_tests/test_bert.py +++ b/tests/model_tests/test_bert.py @@ -10,21 +10,31 @@ import os import pytest import tempfile - - -@pytest.mark.parametrize( - "config", [None, TextEmbedConfig(batch_size=32, chunk_size=256)] +import itertools + +# Global test parameters +MODEL_FIXTURES = ['bert_model', 'onnx_model'] +CONFIGS = [None, TextEmbedConfig(batch_size=32, chunk_size=256)] +ALL_COMBINATIONS = list(itertools.product(MODEL_FIXTURES, CONFIGS)) + +# Define common parametrize decorator +model_fixture_parametrize = pytest.mark.parametrize("model_fixture", MODEL_FIXTURES) +model_and_config_parametrize = pytest.mark.parametrize( + "model_fixture,config", + ALL_COMBINATIONS ) -def test_bert_model_file(bert_model, config, test_pdf_file): - data = embed_file(test_pdf_file, bert_model, config) - path = os.path.abspath(test_pdf_file) +@model_and_config_parametrize +def test_bert_model_file(model_fixture, config, test_pdf_file, request): + model = request.getfixturevalue(model_fixture) + data = embed_file(test_pdf_file, model, config) + path = os.path.abspath(test_pdf_file) + assert len(data) > 0 assert data[0].embedding is not None assert len(data[0].embedding) == 384 assert data[0].metadata["file_name"] == path - def test_bert_model_creation(): model = EmbeddingModel.from_pretrained_hf( @@ -34,40 +44,42 @@ def test_bert_model_creation(): ) assert model is not None +def test_onnx_model_creation(): + model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q") + assert model is not None -def test_bert_model_query(bert_model): - - data = embed_query(["Photo of a monkey?"], bert_model) +@model_fixture_parametrize +def test_bert_model_query(model_fixture, request): + model = request.getfixturevalue(model_fixture) + data = embed_query(["Photo of a monkey?"], model) assert len(data) == 1 assert data[0].embedding is not None assert len(data[0].embedding) == 384 - -@pytest.mark.parametrize( - "config", [None, TextEmbedConfig(batch_size=32, chunk_size=256)] -) -def test_bert_model_directory(bert_model, config, test_text_directory): - - data = embed_directory(test_text_directory, bert_model, config=config) +@model_and_config_parametrize +def test_bert_model_directory(model_fixture, config, test_text_directory, request): + model = request.getfixturevalue(model_fixture) + data = embed_directory(test_text_directory, model, config=config) assert data[0].embedding is not None assert len(data[0].embedding) == 384 - -def test_bert_model_empty_query(bert_model): - data = embed_query([""], bert_model) +@model_fixture_parametrize +def test_bert_model_empty_query(model_fixture, request): + model = request.getfixturevalue(model_fixture) + data = embed_query([""], model) assert len(data) == 1 assert data[0].embedding is not None assert len(data[0].embedding) == 384 - -def test_bert_model_long_query(bert_model): +@model_fixture_parametrize +def test_bert_model_long_query(model_fixture, request): + model = request.getfixturevalue(model_fixture) long_text = " ".join(["long"] * 1000) - data = embed_query([long_text], bert_model) + data = embed_query([long_text], model) assert len(data) == 1 assert data[0].embedding is not None assert len(data[0].embedding) == 384 - def test_bert_model_non_ascii_query(bert_model): non_ascii_text = "こんにちは世界" data = embed_query([non_ascii_text], bert_model) @@ -75,19 +87,16 @@ def test_bert_model_non_ascii_query(bert_model): assert data[0].embedding is not None assert len(data[0].embedding) == 384 - def test_bert_model_nonexistent_file(bert_model): with pytest.raises(FileNotFoundError): embed_file("nonexistent_file.txt", bert_model) - def test_bert_model_empty_directory(bert_model, tmp_path): empty_dir = tmp_path / "empty_dir" empty_dir.mkdir() data = embed_directory(str(empty_dir), bert_model) assert len(data) == 0 - def test_bert_model_unsupported_file_type(bert_model, tmp_path): # Create a file with an unsupported extension