Skip to content

Commit

Permalink
write tests
Browse files Browse the repository at this point in the history
Former-commit-id: a0ba8de
Former-commit-id: fd09617aae231bb65c51eff296590d398452e30f
  • Loading branch information
akshayballal95 committed Nov 5, 2024
1 parent 158c0e7 commit 195bae9
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 4 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion python/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "embed_anything_python"
version = "0.4.10"
version = "0.4.12"
edition = "2021"

[lib]
Expand Down
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "embed_anything"
version = "0.3.4"
version = "0.4.12"
edition.workspace = true
license.workspace = true
description.workspace = true
Expand Down
11 changes: 11 additions & 0 deletions tests/model_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
EmbedData,
EmbeddingModel,
WhichModel,
ColpaliModel
)


Expand Down Expand Up @@ -90,6 +91,16 @@ def onnx_model() -> EmbeddingModel:
model = EmbeddingModel.from_pretrained_onnx(WhichModel.Bert, "BGESmallENV15Q")
return model

@pytest.fixture
def colpali_onnx_model() -> ColpaliModel:
model = ColpaliModel.from_pretrained_onnx(model_id="akshayballal/colpali-v1.2-merged-onnx")
return model

@pytest.fixture
def colpali_model() -> ColpaliModel:
model = ColpaliModel.from_pretrained("vidore/colpali-v1.2-merged")
return model

class DummyAdapter(Adapter):

def create_index(self, dimension: int, metric: str, index_name: str, **kwargs):
Expand Down
8 changes: 8 additions & 0 deletions tests/model_tests/test_colpali.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
import pytest
from embed_anything import ColpaliModel

@pytest.mark.parametrize("model_fixture", ['colpali_model', 'colpali_onnx_model'])
def test_colpali_model_file(model_fixture, test_pdf_file, request):
model:ColpaliModel = request.getfixturevalue(model_fixture)
data = model.embed_file(test_pdf_file, batch_size=1)
assert len(data) == 1

0 comments on commit 195bae9

Please sign in to comment.