Skip to content

Commit

Permalink
improve onnx tests
Browse files Browse the repository at this point in the history
Former-commit-id: a1c496b
Former-commit-id: e7afaf8c6d7068597fe0c48fa6a5ff03a4d9c86f
  • Loading branch information
akshayballal95 committed Nov 5, 2024
1 parent 98b4360 commit 158c0e7
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 29 deletions.
2 changes: 1 addition & 1 deletion rust/examples/audio.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::sync::Arc;

use embed_anything::{
config::{self, TextEmbedConfig},
config::TextEmbedConfig,
emb_audio,
embeddings::embed::Embedder,
file_processor::audio::audio_processor::AudioDecoderModel,
Expand Down
4 changes: 4 additions & 0 deletions tests/model_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def openai_model() -> EmbeddingModel:
)
return model

@pytest.fixture
def onnx_model() -> EmbeddingModel:
    """Provide the ONNX BERT embedding model ("BGESmallENV15Q") to tests."""
    return EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q")

class DummyAdapter(Adapter):

Expand Down
65 changes: 37 additions & 28 deletions tests/model_tests/test_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,31 @@
import os
import pytest
import tempfile


@pytest.mark.parametrize(
"config", [None, TextEmbedConfig(batch_size=32, chunk_size=256)]
import itertools

# Shared test parameters.
# MODEL_FIXTURES holds fixture *names*; tests resolve them to model objects
# at run time through ``request.getfixturevalue``.
MODEL_FIXTURES = ['bert_model', 'onnx_model']
# ``None`` exercises the call without an explicit config.
CONFIGS = [None, TextEmbedConfig(batch_size=32, chunk_size=256)]
# Every (fixture name, config) pair, with the fixture name varying slowest.
ALL_COMBINATIONS = [
    (fixture_name, embed_config)
    for fixture_name in MODEL_FIXTURES
    for embed_config in CONFIGS
]

# Reusable parametrize decorators built from the lists above.
model_fixture_parametrize = pytest.mark.parametrize(
    "model_fixture", MODEL_FIXTURES
)
model_and_config_parametrize = pytest.mark.parametrize(
    "model_fixture,config", ALL_COMBINATIONS
)
def test_bert_model_file(bert_model, config, test_pdf_file):
data = embed_file(test_pdf_file, bert_model, config)
path = os.path.abspath(test_pdf_file)

@model_and_config_parametrize
def test_bert_model_file(model_fixture, config, test_pdf_file, request):
    """Embed the test PDF with each model/config pair and check the first item.

    The first returned item must carry a 384-element embedding and record the
    absolute path of the input file under ``metadata["file_name"]``.
    """
    # The parametrized value is a fixture name; resolve it to the model object.
    embedder = request.getfixturevalue(model_fixture)
    results = embed_file(test_pdf_file, embedder, config)
    expected_path = os.path.abspath(test_pdf_file)

    assert len(results) > 0
    first = results[0]
    assert first.embedding is not None
    assert len(first.embedding) == 384
    assert first.metadata["file_name"] == expected_path


def test_bert_model_creation():

model = EmbeddingModel.from_pretrained_hf(
Expand All @@ -34,60 +44,59 @@ def test_bert_model_creation():
)
assert model is not None

def test_onnx_model_creation():
    """Loading the ONNX BERT model must produce a non-None object."""
    onnx_embedder = EmbeddingModel.from_pretrained_onnx(
        WhichModel.Bert, "BGESmallENV15Q"
    )
    assert onnx_embedder is not None

def test_bert_model_query(bert_model):

data = embed_query(["Photo of a monkey?"], bert_model)
@model_fixture_parametrize
def test_bert_model_query(model_fixture, request):
    """Embed one short query with each model and check the single result."""
    # The parametrized value is a fixture name; resolve it to the model object.
    embedder = request.getfixturevalue(model_fixture)
    results = embed_query(["Photo of a monkey?"], embedder)

    assert len(results) == 1
    first = results[0]
    assert first.embedding is not None
    assert len(first.embedding) == 384


@pytest.mark.parametrize(
"config", [None, TextEmbedConfig(batch_size=32, chunk_size=256)]
)
def test_bert_model_directory(bert_model, config, test_text_directory):

data = embed_directory(test_text_directory, bert_model, config=config)
@model_and_config_parametrize
def test_bert_model_directory(model_fixture, config, test_text_directory, request):
    """Embed the test text directory with each model/config pair.

    The first returned item must carry a 384-element embedding.
    """
    # The parametrized value is a fixture name; resolve it to the model object.
    model = request.getfixturevalue(model_fixture)
    data = embed_directory(test_text_directory, model, config=config)

    # Check for results before indexing, as the file-embedding test does, so
    # an empty result fails with a clear assertion instead of an IndexError.
    assert len(data) > 0
    assert data[0].embedding is not None
    assert len(data[0].embedding) == 384


def test_bert_model_empty_query(bert_model):
data = embed_query([""], bert_model)
@model_fixture_parametrize
def test_bert_model_empty_query(model_fixture, request):
    """An empty-string query must still yield one 384-element embedding."""
    # The parametrized value is a fixture name; resolve it to the model object.
    embedder = request.getfixturevalue(model_fixture)
    results = embed_query([""], embedder)

    assert len(results) == 1
    first = results[0]
    assert first.embedding is not None
    assert len(first.embedding) == 384


def test_bert_model_long_query(bert_model):
@model_fixture_parametrize
def test_bert_model_long_query(model_fixture, request):
    """A 1000-word query must still yield one 384-element embedding."""
    # The parametrized value is a fixture name; resolve it to the model object.
    model = request.getfixturevalue(model_fixture)
    long_text = " ".join(["long"] * 1000)

    # Embed once, with the resolved model. The leftover call that passed
    # ``bert_model`` is gone: that name is not a parameter of this function.
    data = embed_query([long_text], model)

    assert len(data) == 1
    assert data[0].embedding is not None
    assert len(data[0].embedding) == 384


def test_bert_model_non_ascii_query(bert_model):
    """A non-ASCII query must yield one 384-element embedding."""
    results = embed_query(["こんにちは世界"], bert_model)

    assert len(results) == 1
    first = results[0]
    assert first.embedding is not None
    assert len(first.embedding) == 384


def test_bert_model_nonexistent_file(bert_model):
    """Embedding a path that does not exist must raise FileNotFoundError."""
    missing_path = "nonexistent_file.txt"
    with pytest.raises(FileNotFoundError):
        embed_file(missing_path, bert_model)


def test_bert_model_empty_directory(bert_model, tmp_path):
    """Embedding a freshly created, empty directory must return no items."""
    target_dir = tmp_path / "empty_dir"
    target_dir.mkdir()

    results = embed_directory(str(target_dir), bert_model)
    assert len(results) == 0


def test_bert_model_unsupported_file_type(bert_model, tmp_path):

# Create a file with an unsupported extension
Expand Down

0 comments on commit 158c0e7

Please sign in to comment.