diff --git a/.vscode/launch.json b/.vscode/launch.json
index 52ce7b9..bc4742c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -1,5 +1,5 @@
 {
-    "version": "0.2.0",
+    "version": "0.2.1",
     "configurations": [
         {
             "name": "Matchbox: Debug",
diff --git a/docs/contributing.md b/docs/contributing.md
index 74ec2d0..2ac2f25 100644
--- a/docs/contributing.md
+++ b/docs/contributing.md
@@ -54,8 +54,8 @@ docker compose up -d --wait
 
 We have a VSCode default debugging profile called "API debug", which allows to e.g. set breakpoints on the API when running tests. After running this profile, change your `.env` file as follows:
 
-- change the `MB__CLIENT__API_ROOT` variable to redirect tests to use the debug port (8080)
-- disable time-outs by commenting out the `MB__CLIENT__TIMEOUT` variable
+- Change the `MB__CLIENT__API_ROOT` variable to redirect tests to use the debug port (8080)
+- Disable time-outs by commenting out the `MB__CLIENT__TIMEOUT` variable
 
 ## Standards
diff --git a/environments/dev_docker.env b/environments/dev_docker.env
index 393e54b..c25588e 100644
--- a/environments/dev_docker.env
+++ b/environments/dev_docker.env
@@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE=matchbox
 MB__POSTGRES__DB_SCHEMA=mb
 
 MB__CLIENT__API_ROOT=http://localhost:8000
-MB__CLIENT__TIMEOUT=10
\ No newline at end of file
+MB__CLIENT__TIMEOUT=10
+MB__CLIENT__RETRY_DELAY=2
\ No newline at end of file
diff --git a/environments/dev_local.env b/environments/dev_local.env
index dce50bf..be87d8a 100644
--- a/environments/dev_local.env
+++ b/environments/dev_local.env
@@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE=matchbox
 MB__POSTGRES__DB_SCHEMA=mb
 
 MB__CLIENT__API_ROOT=http://localhost:8000
-MB__CLIENT__TIMEOUT=10
\ No newline at end of file
+MB__CLIENT__TIMEOUT=10
+MB__CLIENT__RETRY_DELAY=2
diff --git a/environments/sample.env b/environments/sample.env
index 7ee538b..3dc3a93 100644
--- a/environments/sample.env
+++ b/environments/sample.env
@@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE=
 MB__POSTGRES__DB_SCHEMA=
 
 MB__CLIENT__API_ROOT=
-MB__CLIENT__TIMEOUT=
\ No newline at end of file
+MB__CLIENT__TIMEOUT=
+MB__CLIENT__RETRY_DELAY=
diff --git a/pyproject.toml b/pyproject.toml
index e0620bb..6875312 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -107,3 +107,9 @@ log_cli = false
 log_cli_level = "INFO"
 log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)"
 log_cli_date_format = "%Y-%m-%d %H:%M:%S"
+asyncio_mode = "strict"
+asyncio_default_fixture_loop_scope = "function"
+filterwarnings = [
+    'ignore:.*__fields.*:pydantic.PydanticDeprecatedSince20:unittest.mock',
+    'ignore:.*__fields_set__.*:pydantic.PydanticDeprecatedSince20:unittest.mock'
+]
diff --git a/src/matchbox/__init__.py b/src/matchbox/__init__.py
index d9e0a31..8622921 100644
--- a/src/matchbox/__init__.py
+++ b/src/matchbox/__init__.py
@@ -1,11 +1,13 @@
 from dotenv import find_dotenv, load_dotenv
 
-from matchbox.client.helpers.cleaner import process
-from matchbox.client.helpers.index import index
-from matchbox.client.helpers.selector import match, query
-from matchbox.client.models.models import make_model
-
-__all__ = ("make_model", "process", "query", "match", "index")
-
 dotenv_path = find_dotenv(usecwd=True)
 load_dotenv(dotenv_path)
+
+# Environment variables must be loaded first for other imports to work
+
+from matchbox.client.helpers.cleaner import process  # NoQA: E402
+from matchbox.client.helpers.index import index  # NoQA: E402
+from matchbox.client.helpers.selector import match, query  # NoQA: E402
+from matchbox.client.models.models import make_model  # NoQA: E402
+
+__all__ = ("make_model", "process", "query", "match", "index")
diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py
index 34cee5c..e5e2ddf 100644
--- a/src/matchbox/client/_handler.py
+++ b/src/matchbox/client/_handler.py
@@ -8,9 +8,17 @@ from pyarrow.parquet import read_table
 
 from matchbox.common.arrow import SCHEMA_MB_IDS, table_to_buffer
-from matchbox.common.dtos import BackendRetrievableType, NotFoundError, UploadStatus
+from matchbox.common.dtos import (
+    BackendRetrievableType,
+    ModelAncestor,
+    ModelMetadata,
+    ModelOperationStatus,
+    NotFoundError,
+    UploadStatus,
+)
 from matchbox.common.exceptions import (
     MatchboxClientFileError,
+    MatchboxDeletionNotConfirmed,
     MatchboxResolutionNotFoundError,
     MatchboxServerFileError,
     MatchboxSourceNotFoundError,
@@ -21,23 +29,6 @@ from matchbox.common.hash import hash_to_base64
 from matchbox.common.sources import Match, Source, SourceAddress
 
-if timeout := getenv("MB__CLIENT__TIMEOUT"):
-    CLIENT = httpx.Client(timeout=float(timeout))
-else:
-    CLIENT = httpx.Client(timeout=None)
-
-
-def url(path: str) -> str:
-    """Return path prefixed by API root, determined from environment."""
-    api_root = getenv("MB__CLIENT__API_ROOT")
-    if api_root is None:
-        raise RuntimeError(
-            "MB__CLIENT__API_ROOT needs to be defined in the environment"
-        )
-
-    return api_root + path
-
-
 URLEncodeHandledType = str | int | float | bytes
 
 
@@ -67,7 +58,9 @@ def url_params(
 
 def handle_http_code(res: httpx.Response) -> httpx.Response:
     """Handle HTTP status codes and raise appropriate exceptions."""
-    if res.status_code == 200:
+    res.read()
+
+    if 299 >= res.status_code >= 200:
         return res
 
     if res.status_code == 400:
@@ -86,58 +79,36 @@ def handle_http_code(res: httpx.Response) -> httpx.Response:
     else:
         raise RuntimeError(f"Unexpected 404 error: {error.details}")
 
+    if res.status_code == 409:
+        error = ModelOperationStatus.model_validate(res.json())
+        raise MatchboxDeletionNotConfirmed(message=error.details)
+
     if res.status_code == 422:
         raise MatchboxUnparsedClientRequest(res.content)
 
     raise MatchboxUnhandledServerResponse(res.content)
 
 
-def get_resolution_graph() -> ResolutionGraph:
-    """Get the resolution graph from Matchbox."""
-    res = handle_http_code(CLIENT.get(url("/report/resolutions")))
-    return ResolutionGraph.model_validate(res.json())
-
-
-def get_source(address: SourceAddress) -> Source:
-    warehouse_hash_b64 = hash_to_base64(address.warehouse_hash)
-    res = handle_http_code(
-        CLIENT.get(url(f"/sources/{warehouse_hash_b64}/{address.full_name}"))
-    )
-    return Source.model_validate(res.json())
-
-
-def index(source: Source, data_hashes: Table) -> UploadStatus:
-    """Index a Source in Matchbox."""
-    buffer = table_to_buffer(table=data_hashes)
-
-    # Upload metadata
-    metadata_res = handle_http_code(
-        CLIENT.post(url("/sources"), json=source.model_dump())
-    )
-    upload = UploadStatus.model_validate(metadata_res.json())
-
-    # Upload data
-    upload_res = handle_http_code(
-        CLIENT.post(
-            url(f"/upload/{upload.id}"),
-            files={
-                "file": (f"{upload.id}.parquet", buffer, "application/octet-stream")
-            },
+def create_client() -> httpx.Client:
+    """Create an HTTPX client with proper configuration."""
+    api_root = getenv("MB__CLIENT__API_ROOT")
+    timeout = getenv("MB__CLIENT__TIMEOUT")
+    if api_root is None:
+        raise RuntimeError(
+            "MB__CLIENT__API_ROOT needs to be defined in the environment"
         )
-    )
+    if timeout is not None:
+        timeout = float(timeout)
 
-    # Poll until complete with retry/timeout configuration
-    status = UploadStatus.model_validate(upload_res.json())
-    while status.status not in ["complete", "failed"]:
-        status_res = handle_http_code(CLIENT.get(url(f"/upload/{upload.id}/status")))
-        status = UploadStatus.model_validate(status_res.json())
+    return httpx.Client(
+        base_url=api_root, timeout=timeout, event_hooks={"response": [handle_http_code]}
     )
 
-        if status.status == "failed":
-            raise MatchboxServerFileError(status.details)
 
-        time.sleep(2)
+CLIENT = create_client()
+DELAY = int(getenv("MB__CLIENT__RETRY_DELAY", "2"))
 
-    return status
+# Retrieval
 
 
 def query(
@@ -146,20 +117,18 @@
     threshold: int | None = None,
     limit: int | None = None,
 ) -> BytesIO:
-    res = handle_http_code(
-        CLIENT.get(
-            url("/query"),
-            params=url_params(
-                {
-                    "full_name": source_address.full_name,
-                    # Converted to b64 by `url_params()`
-                    "warehouse_hash_b64": source_address.warehouse_hash,
-                    "resolution_name": resolution_name,
-                    "threshold": threshold,
-                    "limit": limit,
-                }
-            ),
-        )
+    res = CLIENT.get(
+        "/query",
+        params=url_params(
+            {
+                "full_name": source_address.full_name,
+                # Converted to b64 by `url_params()`
+                "warehouse_hash_b64": source_address.warehouse_hash,
+                "resolution_name": resolution_name,
+                "threshold": threshold,
+                "limit": limit,
+            }
+        ),
     )
 
     buffer = BytesIO(res.content)
@@ -185,23 +154,156 @@ def match(
     target_full_names = [t.full_name for t in targets]
     target_warehouse_hashes = [t.warehouse_hash for t in targets]
 
-    res = handle_http_code(
-        CLIENT.get(
-            url("/match"),
-            params=url_params(
-                {
-                    "target_full_names": target_full_names,
-                    # Converted to b64 by `url_params()`
-                    "target_warehouse_hashes_b64": target_warehouse_hashes,
-                    "source_full_name": source.full_name,
-                    # Converted to b64 by `url_params()`
-                    "source_warehouse_hash_b64": source.warehouse_hash,
-                    "source_pk": source_pk,
-                    "resolution_name": resolution_name,
-                    "threshold": threshold,
-                }
-            ),
-        )
+    res = CLIENT.get(
+        "/match",
+        params=url_params(
+            {
+                "target_full_names": target_full_names,
+                # Converted to b64 by `url_params()`
+                "target_warehouse_hashes_b64": target_warehouse_hashes,
+                "source_full_name": source.full_name,
+                # Converted to b64 by `url_params()`
+                "source_warehouse_hash_b64": source.warehouse_hash,
+                "source_pk": source_pk,
+                "resolution_name": resolution_name,
+                "threshold": threshold,
+            }
+        ),
     )
 
     return [Match.model_validate(m) for m in res.json()]
+
+
+# Data management
+
+
+def index(source: Source, data_hashes: Table) -> UploadStatus:
+    """Index a Source in Matchbox."""
+    buffer = table_to_buffer(table=data_hashes)
+
+    # Upload metadata
+    metadata_res = CLIENT.post("/sources", json=source.model_dump())
+
+    upload = UploadStatus.model_validate(metadata_res.json())
+
+    # Upload data
+    upload_res = CLIENT.post(
+        f"/upload/{upload.id}",
+        files={"file": (f"{upload.id}.parquet", buffer, "application/octet-stream")},
+    )
+
+    # Poll until complete with retry/timeout configuration
+    status = UploadStatus.model_validate(upload_res.json())
+    while status.status not in ["complete", "failed"]:
+        status_res = CLIENT.get(f"/upload/{upload.id}/status")
+        status = UploadStatus.model_validate(status_res.json())
+
+        if status.status == "failed":
+            raise MatchboxServerFileError(status.details)
+
+        time.sleep(DELAY)
+
+    return status
+
+
+def get_source(address: SourceAddress) -> Source:
+    warehouse_hash_b64 = hash_to_base64(address.warehouse_hash)
+    res = CLIENT.get(f"/sources/{warehouse_hash_b64}/{address.full_name}")
+
+    return Source.model_validate(res.json())
+
+
+def get_resolution_graph() -> ResolutionGraph:
+    """Get the resolution graph from Matchbox."""
+    res = CLIENT.get("/report/resolutions")
+    return ResolutionGraph.model_validate(res.json())
+
+
+# Model management
+
+
+def insert_model(model: ModelMetadata) -> ModelOperationStatus:
+    """Insert a model in Matchbox."""
+    res = CLIENT.post("/models", json=model.model_dump())
+    return ModelOperationStatus.model_validate(res.json())
+
+
+def get_model(name: str) -> ModelMetadata:
+    res = CLIENT.get(f"/models/{name}")
+    return ModelMetadata.model_validate(res.json())
+
+
+def add_model_results(name: str, results: Table) -> UploadStatus:
+    """Upload model results in Matchbox."""
+    buffer = table_to_buffer(table=results)
+
+    # Initialise upload
+    metadata_res = CLIENT.post(f"/models/{name}/results")
+
+    upload = UploadStatus.model_validate(metadata_res.json())
+
+    # Upload data
+    upload_res = CLIENT.post(
+        f"/upload/{upload.id}",
+        files={"file": (f"{upload.id}.parquet", buffer, "application/octet-stream")},
+    )
+
+    # Poll until complete with retry/timeout configuration
+    status = UploadStatus.model_validate(upload_res.json())
+    while status.status not in ["complete", "failed"]:
+        status_res = CLIENT.get(f"/upload/{upload.id}/status")
+        status = UploadStatus.model_validate(status_res.json())
+
+        if status.status == "failed":
+            raise MatchboxServerFileError(status.details)
+
+        time.sleep(DELAY)
+
+    return status
+
+
+def get_model_results(name: str) -> Table:
+    """Get model results from Matchbox."""
+    res = CLIENT.get(f"/models/{name}/results")
+    buffer = BytesIO(res.content)
+    return read_table(buffer)
+
+
+def set_model_truth(name: str, truth: float) -> ModelOperationStatus:
+    """Set the truth threshold for a model in Matchbox."""
+    res = CLIENT.patch(f"/models/{name}/truth", json=truth)
+    return ModelOperationStatus.model_validate(res.json())
+
+
+def get_model_truth(name: str) -> float:
+    """Get the truth threshold for a model in Matchbox."""
+    res = CLIENT.get(f"/models/{name}/truth")
+    return res.json()
+
+
+def get_model_ancestors(name: str) -> list[ModelAncestor]:
+    """Get the ancestors of a model in Matchbox."""
+    res = CLIENT.get(f"/models/{name}/ancestors")
+    return [ModelAncestor.model_validate(m) for m in res.json()]
+
+
+def set_model_ancestors_cache(
+    name: str, ancestors: list[ModelAncestor]
+) -> ModelOperationStatus:
+    """Set the ancestors cache for a model in Matchbox."""
+    res = CLIENT.patch(
+        f"/models/{name}/ancestors_cache", json=[a.model_dump() for a in ancestors]
+    )
+    return ModelOperationStatus.model_validate(res.json())
+
+
+def get_model_ancestors_cache(name: str) -> list[ModelAncestor]:
+    """Get the ancestors cache for a model in Matchbox."""
+    res = CLIENT.get(f"/models/{name}/ancestors_cache")
+    return [ModelAncestor.model_validate(m) for m in res.json()]
+
+
+def delete_model(name: str, certain: bool = False) -> ModelOperationStatus:
+    """Delete a model in Matchbox."""
+    res = CLIENT.delete(f"/models/{name}", params={"certain": certain})
+    return ModelOperationStatus.model_validate(res.json())
diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py
index ddef734..48312f5 100644
--- a/src/matchbox/client/models/models.py
+++ b/src/matchbox/client/models/models.py
@@ -2,12 +2,12 @@
 
 from pandas import DataFrame
 
+from matchbox.client import _handler
 from matchbox.client.models.dedupers.base import Deduper
 from matchbox.client.models.linkers.base import Linker
 from matchbox.client.results import Results
-from matchbox.common.dtos import ModelMetadata, ModelType
+from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType
 from matchbox.common.exceptions import MatchboxResolutionNotFoundError
-from matchbox.server import MatchboxDBAdapter, inject_backend
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -28,60 +28,65 @@ def __init__(
         self.left_data = left_data
         self.right_data = right_data
 
-    @inject_backend
-    def insert_model(self, backend: MatchboxDBAdapter) -> None:
+    def insert_model(self) -> None:
         """Insert the model into the backend database."""
-        backend.insert_model(model=self.metadata)
+        _handler.insert_model(model=self.metadata)
 
     @property
-    @inject_backend
-    def results(self, backend: MatchboxDBAdapter) -> Results:
+    def results(self) -> Results:
         """Retrieve results associated with the model from the database."""
-        results = backend.get_model_results(model=self.metadata.name)
+        results = _handler.get_model_results(name=self.metadata.name)
         return Results(probabilities=results, metadata=self.metadata)
 
     @results.setter
-    @inject_backend
-    def results(self, backend: MatchboxDBAdapter, results: Results) -> None:
+    def results(self, results: Results) -> None:
         """Write results associated with the model to the database."""
-        backend.set_model_results(
-            model=self.metadata.name, results=results.probabilities
-        )
+        if results.probabilities.shape[0] > 0:
+            _handler.add_model_results(
+                name=self.metadata.name, results=results.probabilities
+            )
 
     @property
-    @inject_backend
-    def truth(self, backend: MatchboxDBAdapter) -> float:
+    def truth(self) -> float:
         """Retrieve the truth threshold for the model."""
-        return backend.get_model_truth(model=self.metadata.name)
+        return _handler.get_model_truth(name=self.metadata.name)
 
     @truth.setter
-    @inject_backend
-    def truth(self, backend: MatchboxDBAdapter, truth: float) -> None:
+    def truth(self, truth: float) -> None:
         """Set the truth threshold for the model."""
-        backend.set_model_truth(model=self.metadata.name, truth=truth)
+        _handler.set_model_truth(name=self.metadata.name, truth=truth)
 
     @property
-    @inject_backend
-    def ancestors(self, backend: MatchboxDBAdapter) -> dict[str, float]:
+    def ancestors(self) -> dict[str, float]:
         """Retrieve the ancestors of the model."""
-        return backend.get_model_ancestors(model=self.metadata.name)
+        return {
+            ancestor.name: ancestor.truth
+            for ancestor in _handler.get_model_ancestors(name=self.metadata.name)
+        }
 
     @property
-    @inject_backend
-    def ancestors_cache(self, backend: MatchboxDBAdapter) -> dict[str, float]:
+    def ancestors_cache(self) -> dict[str, float]:
         """Retrieve the ancestors cache of the model."""
-        return backend.get_model_ancestors_cache(model=self.metadata.name)
+        return {
+            ancestor.name: ancestor.truth
+            for ancestor in _handler.get_model_ancestors_cache(name=self.metadata.name)
+        }
 
     @ancestors_cache.setter
-    @inject_backend
-    def ancestors_cache(
-        self, backend: MatchboxDBAdapter, ancestors_cache: dict[str, float]
-    ) -> None:
+    def ancestors_cache(self, ancestors_cache: dict[str, float]) -> None:
         """Set the ancestors cache of the model."""
-        backend.set_model_ancestors_cache(
-            model=self.metadata.name, ancestors_cache=ancestors_cache
+        _handler.set_model_ancestors_cache(
+            name=self.metadata.name,
+            ancestors=[
+                ModelAncestor(name=k, truth=v) for k, v in ancestors_cache.items()
+            ],
         )
 
+    def delete(self, certain: bool = False) -> bool:
+        """Delete the model from the database."""
+        result = _handler.delete_model(name=self.metadata.name, certain=certain)
+        return result.success
+
     def run(self) -> Results:
         """Execute the model pipeline and return results."""
         if self.metadata.type == ModelType.LINKER:
diff --git a/src/matchbox/client/results.py b/src/matchbox/client/results.py
index 776e64d..c98c484 100644
--- a/src/matchbox/client/results.py
+++ b/src/matchbox/client/results.py
@@ -11,7 +11,6 @@
 from matchbox.common.dtos import ModelMetadata
 from matchbox.common.hash import IntMap
 from matchbox.common.transform import to_clusters
-from matchbox.server.base import MatchboxDBAdapter, inject_backend
 
 if TYPE_CHECKING:
     from matchbox.client.models.models import Model
@@ -210,8 +209,7 @@ def inspect_clusters(
             right_merge_col="child",
         )
 
-    @inject_backend
-    def to_matchbox(self, backend: MatchboxDBAdapter) -> None:
+    def to_matchbox(self) -> None:
         """Writes the results to the Matchbox database."""
         self.model.insert_model()
         self.model.results = self
diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py
index 905e253..b4afb42 100644
--- a/src/matchbox/common/dtos.py
+++ b/src/matchbox/common/dtos.py
@@ -1,7 +1,7 @@
 from enum import StrEnum
 from typing import Literal
 
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
 
 from matchbox.common.arrow import SCHEMA_INDEX, SCHEMA_RESULTS
 
@@ -45,6 +45,15 @@ class ModelType(StrEnum):
     DEDUPER = "deduper"
 
 
+class ModelOperationType(StrEnum):
+    """Enumeration of supported model operations."""
+
+    INSERT = "insert"
+    UPDATE_TRUTH = "update_truth"
+    UPDATE_ANCESTOR_CACHE = "update_ancestor_cache"
+    DELETE = "delete"
+
+
 class ModelMetadata(BaseModel):
     """Metadata for a model."""
 
@@ -55,6 +64,75 @@ class ModelMetadata(BaseModel):
     right_resolution: str | None = None  # Only used for linker models
 
 
+class ModelAncestor(BaseModel):
+    """A model's ancestor and its truth value."""
+
+    name: str = Field(..., description="Name of the ancestor model")
+    truth: float | None = Field(
+        default=None, description="Truth threshold value", ge=0.0, le=1.0
+    )
+
+
+class ModelOperationStatus(BaseModel):
+    """Status response for any model operation."""
+
+    success: bool
+    model_name: str
+    operation: ModelOperationType
+    details: str | None = None
+
+    @classmethod
+    def status_409_examples(cls) -> dict:
+        return {
+            "content": {
+                "application/json": {
+                    "examples": {
+                        "confirm_delete": {
+                            "summary": "Delete operation requires confirmation. ",
+                            "value": cls(
+                                success=False,
+                                model_name="example_model",
+                                operation=ModelOperationType.DELETE,
+                                details=(
+                                    "This operation will delete the resolutions "
+                                    "deduper_1, deduper_2, linker_1, "
+                                    "as well as all probabilities they have created. "
+                                    "\n\n"
+                                    "It won't delete validation associated with these "
+                                    "probabilities. \n\n"
+                                    "If you're sure you want to continue, rerun with "
+                                    "certain=True"
+                                ),
+                            ).model_dump(),
+                        },
+                    },
+                }
+            }
+        }
+
+    @classmethod
+    def status_500_examples(cls) -> dict:
+        return {
+            "content": {
+                "application/json": {
+                    "examples": {
+                        "unhandled": {
+                            "summary": (
+                                "Unhandled exception encountered while updating the "
+                                "model's truth value."
+                            ),
+                            "value": cls(
+                                success=False,
+                                model_name="example_model",
+                                operation=ModelOperationType.UPDATE_TRUTH,
+                            ).model_dump(),
+                        },
+                    },
+                }
+            }
+        }
+
+
 class HealthCheck(BaseModel):
     """Response model to validate and return when performing a health check."""
 
@@ -77,6 +155,20 @@
     details: str | None = None
     entity: BackendUploadType | None = None
 
+    _status_code_mapping = {
+        "ready": 200,
+        "complete": 200,
+        "failed": 400,
+        "awaiting_upload": 202,
+        "queued": 200,
+        "processing": 200,
+    }
+
+    def get_http_code(self) -> int:
+        if self.status == "failed":
+            return 400
+        return self._status_code_mapping[self.status]
+
     @classmethod
     def status_400_examples(cls) -> dict:
         return {
diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py
index 2f11bde..8b7a6d8 100644
--- a/src/matchbox/common/exceptions.py
+++ b/src/matchbox/common/exceptions.py
@@ -136,3 +136,23 @@ def __init__(self, message: str | None = None):
 
 class MatchboxConnectionError(Exception):
     """Connection to Matchbox's backend database failed."""
+
+
+class MatchboxDeletionNotConfirmed(Exception):
+    """Deletion must be confirmed: if certain, rerun with certain=True."""
+
+    def __init__(self, message: str | None = None, children: list[str] | None = None):
+        if message is None:
+            message = "Deletion must be confirmed: if certain, rerun with certain=True."
+
+        if children is not None:
+            children_names = ", ".join(children)
+            message = (
+                f"This operation will delete the resolutions {children_names}, "
+                "as well as all probabilities they have created. \n\n"
+                "It won't delete validation associated with these "
+                "probabilities. \n\n"
+                "If you're sure you want to continue, rerun with certain=True. "
+            )
+
+        super().__init__(message)
diff --git a/src/matchbox/common/factories/models.py b/src/matchbox/common/factories/models.py
index 3a51ec5..87a1349 100644
--- a/src/matchbox/common/factories/models.py
+++ b/src/matchbox/common/factories/models.py
@@ -1,13 +1,20 @@
 from collections import Counter
+from functools import cache
 from textwrap import dedent
 from typing import Any, Literal
+from unittest.mock import Mock, PropertyMock, create_autospec
 
 import numpy as np
 import pyarrow as pa
 import rustworkx as rx
 from faker import Faker
+from pandas import DataFrame
 from pydantic import BaseModel, ConfigDict
 
+from matchbox.client.models.dedupers.base import Deduper
+from matchbox.client.models.linkers.base import Linker
+from matchbox.client.models.models import Model
+from matchbox.client.results import Results
 from matchbox.common.dtos import ModelMetadata, ModelType
 from matchbox.common.transform import graph_results
 
@@ -311,11 +318,40 @@ class ModelDummy(BaseModel):
 
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    model: ModelMetadata
+    model: Model
     data: pa.Table
     metrics: ModelMetrics
 
+    def to_mock(self) -> Mock:
+        """Create a mock Model object that mimics this dummy model's behavior."""
+        mock_model = create_autospec(Model)
+
+        # Set basic attributes
+        mock_model.metadata = self.model.metadata
+        mock_model.left_data = DataFrame()  # Default empty DataFrame
+        mock_model.right_data = (
+            DataFrame() if self.model.metadata.type == ModelType.LINKER else None
+        )
+
+        # Mock results property
+        mock_results = Results(probabilities=self.data, metadata=self.model.metadata)
+        type(mock_model).results = PropertyMock(return_value=mock_results)
+
+        # Mock run method
+        mock_model.run.return_value = mock_results
+
+        # Mock the model instance based on type
+        if self.model.metadata.type == ModelType.LINKER:
+            mock_model.model_instance = create_autospec(Linker)
+            mock_model.model_instance.link.return_value = self.data
+        else:
+            mock_model.model_instance = create_autospec(Deduper)
+            mock_model.model_instance.dedupe.return_value = self.data
+
+        return mock_model
+
 
+@cache
 def model_factory(
     name: str | None = None,
     description: str | None = None,
@@ -342,7 +378,7 @@ def model_factory(
 
     model_type = ModelType(model_type.lower() if model_type else "deduper")
 
-    model = ModelMetadata(
+    metadata = ModelMetadata(
         name=name or generator.word(),
         description=description or generator.sentence(),
         type=model_type,
@@ -350,14 +386,21 @@ def model_factory(
         right_resolution=generator.word() if model_type == ModelType.LINKER else None,
     )
 
-    if model.type == ModelType.LINKER:
+    if metadata.type == ModelType.LINKER:
         left_values = list(range(n_true_entities))
         right_values = list(range(n_true_entities, n_true_entities * 2))
-    elif model.type == ModelType.DEDUPER:
+    elif metadata.type == ModelType.DEDUPER:
         values_count = n_true_entities * 2  # So there's something to dedupe
         left_values = list(range(values_count))
         right_values = None
 
+    model = Model(
+        metadata=metadata,
+        model_instance=Mock(),
+        left_data=DataFrame({"id": left_values}),
+        right_data=DataFrame({"id": right_values}) if right_values else None,
+    )
+
     probabilities = generate_dummy_probabilities(
         left_values=left_values,
         right_values=right_values,
diff --git a/src/matchbox/common/factories/sources.py b/src/matchbox/common/factories/sources.py
index 4df0c4a..f064670 100644
--- a/src/matchbox/common/factories/sources.py
+++ b/src/matchbox/common/factories/sources.py
@@ -1,5 +1,7 @@
 from abc import ABC, abstractmethod
+from functools import cache, wraps
 from math import comb
+from typing import Callable, ParamSpec, TypeVar
 from unittest.mock import Mock, create_autospec
 from uuid import uuid4
 
@@ -12,10 +14,39 @@
 from matchbox.common.arrow import SCHEMA_INDEX
 from matchbox.common.sources import Source, SourceAddress, SourceColumn
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
+
+def make_features_hashable(func: Callable[P, R]) -> Callable[P, R]:
+    @wraps(func)
+    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
+        # Handle features in first positional arg
+        if args and args[0] is not None:
+            if isinstance(args[0][0], dict):
+                args = (tuple(FeatureConfig(**d) for d in args[0]),) + args[1:]
+            else:
+                args = (tuple(args[0]),) + args[1:]
+
+        # Handle features in kwargs
+        if "features" in kwargs and kwargs["features"] is not None:
+            if isinstance(kwargs["features"][0], dict):
+                kwargs["features"] = tuple(
+                    FeatureConfig(**d) for d in kwargs["features"]
+                )
+            else:
+                kwargs["features"] = tuple(kwargs["features"])
+
+        return func(*args, **kwargs)
+
+    return wrapper
+
 
 class VariationRule(BaseModel, ABC):
     """Abstract base class for variation rules."""
 
+    model_config = ConfigDict(frozen=True)
+
     @abstractmethod
     def apply(self, value: str) -> str:
         """Apply the variation to a value."""
@@ -71,10 +102,12 @@ def type(self) -> str:
 class FeatureConfig(BaseModel):
     """Configuration for generating a feature with variations."""
 
+    model_config = ConfigDict(frozen=True)
+
     name: str
     base_generator: str
-    parameters: dict = Field(default_factory=dict)
-    variations: list[VariationRule] = Field(default_factory=list)
+    parameters: tuple = Field(default_factory=tuple)
+    variations: tuple[VariationRule, ...] = Field(default_factory=tuple)
 
 
 class SourceMetrics(BaseModel):
@@ -138,7 +171,7 @@ class SourceDummy(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
     source: Source
-    features: list[FeatureConfig]
+    features: tuple[FeatureConfig, ...]
     data: pa.Table
     data_hashes: pa.Table
     metrics: SourceMetrics
@@ -166,7 +199,7 @@ def __init__(self, seed: int = 42):
         self.faker.seed_instance(seed)
 
     def generate_data(
-        self, n_true_entities: int, features: list[FeatureConfig], repetition: int
+        self, n_true_entities: int, features: tuple[FeatureConfig, ...], repetition: int
     ) -> tuple[pa.Table, pa.Table, SourceMetrics]:
         """Generate raw data as PyArrow tables.
 
@@ -186,7 +219,7 @@ def generate_data(
         for _ in range(n_true_entities):
             # Generate base values -- the raw row
             base_values = {
-                f.name: getattr(self.faker, f.base_generator)(**f.parameters)
+                f.name: getattr(self.faker, f.base_generator)(**dict(f.parameters))
                 for f in features
             }
 
@@ -235,6 +268,8 @@ def generate_data(
         return pa.Table.from_pandas(df), data_hashes, metrics
 
 
+@make_features_hashable
+@cache
 def source_factory(
     features: list[FeatureConfig] | list[dict] | None = None,
     full_name: str | None = None,
@@ -259,7 +294,7 @@ def source_factory(
     generator = SourceDataGenerator(seed)
 
     if features is None:
-        features = [
+        features = (
             FeatureConfig(
                 name="company_name",
                 base_generator="company",
@@ -267,9 +302,9 @@ def source_factory(
             FeatureConfig(
                 name="crn",
                 base_generator="bothify",
-                parameters={"text": "???-###-???-###"},
+                parameters=(("text", "???-###-???-###"),),
             ),
-        ]
+        )
 
     if full_name is None:
         full_name = generator.faker.word()
@@ -277,9 +312,6 @@ def source_factory(
     if engine is None:
        engine = create_engine("sqlite:///:memory:")
 
-    if features and isinstance(features[0], dict):
-        features = [FeatureConfig.model_validate(feature) for feature in features]
-
     data, data_hashes, metrics = generator.generate_data(
         n_true_entities=n_true_entities, features=features, repetition=repetition
     )
diff --git a/src/matchbox/server/__init__.py b/src/matchbox/server/__init__.py
index 88bf6d0..51b2461 100644
--- a/src/matchbox/server/__init__.py
+++ b/src/matchbox/server/__init__.py
@@ -3,9 +3,8 @@
     MatchboxDBAdapter,
     MatchboxSettings,
     initialise_matchbox,
-    inject_backend,
 )
 
-__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "inject_backend"]
+__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings"]
 
 initialise_matchbox()
diff --git a/src/matchbox/server/api/cache.py b/src/matchbox/server/api/cache.py
index 0e3734b..b8aad7f 100644
--- a/src/matchbox/server/api/cache.py
+++ b/src/matchbox/server/api/cache.py
@@ -6,10 +6,7 @@
 import pyarrow as pa
 from pydantic import BaseModel, ConfigDict
 
-from matchbox.common.dtos import (
-    BackendUploadType,
-    UploadStatus,
-)
+from matchbox.common.dtos import BackendUploadType, ModelMetadata, UploadStatus
 from matchbox.common.sources import Source
 from matchbox.server.api.arrow import s3_to_recordbatch
 from matchbox.server.base import MatchboxDBAdapter
@@ -18,7 +15,7 @@
 class MetadataCacheEntry(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)
 
-    metadata: Source
+    metadata: Source | ModelMetadata
     upload_type: BackendUploadType
     update_timestamp: datetime
     status: UploadStatus
@@ -68,9 +65,20 @@ def cache_source(self, metadata: Source) -> str:
         )
         return cache_id
 
-    def cache_model(self, metadata: object) -> str:
+    def cache_model(self, metadata: ModelMetadata) -> str:
         """Cache model results metadata and return ID."""
-        raise NotImplementedError
+        self._cleanup_if_needed()
+        cache_id = str(uuid.uuid4())
+
+        self._store[cache_id] = MetadataCacheEntry(
+            metadata=metadata,
+            upload_type=BackendUploadType.RESULTS,
+            update_timestamp=datetime.now(),
+            status=UploadStatus(
+                id=cache_id, status="awaiting_upload", entity=BackendUploadType.RESULTS
+            ),
+        )
+        return cache_id
 
     def get(self, cache_id: str) -> MetadataCacheEntry | None:
         """Retrieve metadata by ID if not expired. Updates timestamp on access."""
@@ -158,9 +166,10 @@ async def process_upload(
     """Background task to process uploaded file."""
     metadata_store.update_status(upload_id, "processing")
     client = backend.settings.datastore.get_client()
+    upload = metadata_store.get(upload_id)
 
     try:
-        data_hashes = pa.Table.from_batches(
+        data = pa.Table.from_batches(
             [
                 batch
                 async for batch in s3_to_recordbatch(
@@ -168,10 +177,14 @@ async def process_upload(
                 )
             ]
         )
+
         async with heartbeat(metadata_store, upload_id):
-            backend.index(
-                source=metadata_store.get(upload_id).metadata, data_hashes=data_hashes
-            )
+            if upload.upload_type == BackendUploadType.INDEX:
+                backend.index(source=upload.metadata, data_hashes=data)
+            elif upload.upload_type == BackendUploadType.RESULTS:
+                backend.set_model_results(model=upload.metadata.name, results=data)
+            else:
+                raise ValueError(f"Unknown upload type: {upload.upload_type}")
 
         metadata_store.update_status(upload_id, "complete")
diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py
index 09e6792..3f7b412 100644
--- a/src/matchbox/server/api/routes.py
+++ b/src/matchbox/server/api/routes.py
@@ -2,7 +2,16 @@
 from typing import TYPE_CHECKING, Annotated, Any, AsyncGenerator
 
 from dotenv import find_dotenv, load_dotenv
-from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, UploadFile
+from fastapi import (
+    BackgroundTasks,
+    Body,
+    Depends,
+    FastAPI,
+    HTTPException,
+    Query,
+    UploadFile,
+    status,
+)
 from fastapi.responses import JSONResponse, Response
 from starlette.exceptions import HTTPException as StarletteHTTPException
 
@@ -13,11 +22,15 @@
     BackendUploadType,
     CountResult,
     HealthCheck,
-    ModelResultsType,
+    ModelAncestor,
+    ModelMetadata,
+    ModelOperationStatus,
+    ModelOperationType,
     NotFoundError,
     UploadStatus,
 )
 from matchbox.common.exceptions import (
+    MatchboxDeletionNotConfirmed,
     MatchboxResolutionNotFoundError,
     MatchboxServerFileError,
     MatchboxSourceNotFoundError,
@@ -51,7 +64,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
 
 app = FastAPI(
     title="matchbox API",
-    version="0.2.0",
+    version="0.2.1",
     lifespan=lifespan,
 )
 
@@ -62,6 +75,9 @@ async def http_exception_handler(request, exc):
     return JSONResponse(content=exc.detail, status_code=exc.status_code)
 
 
+# General
+
+
 def get_backend() -> MatchboxDBAdapter:
     return BackendManager.get_backend()
 
@@ -77,6 +93,8 @@ async def count_backend_items(
     backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
     entity: BackendCountableType | None = None,
 ) -> CountResult:
+    """Count the number of various entities in the backend."""
+
     def get_count(e: BackendCountableType) -> int:
         return getattr(backend, str(e)).count()
 
@@ -87,47 +105,12 @@ def get_count(e: BackendCountableType) -> int:
 
     return CountResult(entities=res)
 
 
-@app.post("/testing/clear")
-async def clear_backend():
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/sources")
-async def list_sources():
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get(
-    "/sources/{warehouse_hash_b64}/{full_name}",
-    responses={404: {"model": NotFoundError}},
-)
-async def get_source(
-    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
-    warehouse_hash_b64: str,
-    full_name: str,
-) -> Source:
-    address = SourceAddress(full_name=full_name, warehouse_hash=warehouse_hash_b64)
-    try:
-        return backend.get_source(address)
-    except MatchboxSourceNotFoundError as e:
-        raise HTTPException(
-            status_code=404,
-            detail=NotFoundError(
-                details=str(e), entity=BackendRetrievableType.SOURCE
-            ).model_dump(),
-        ) from e
-
-
-@app.post("/sources")
-async def add_source(source: Source):
-    """Add a source to the backend."""
-    upload_id = metadata_store.cache_source(metadata=source)
-    return metadata_store.get(cache_id=upload_id).status
-
-
 @app.post(
     "/upload/{upload_id}",
-    responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}},
+    responses={
+        400: {"model": UploadStatus, **UploadStatus.status_400_examples()},
+    },
+    status_code=status.HTTP_202_ACCEPTED,
 )
 async def upload_file(
     backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
@@ -199,12 +182,24 @@
         metadata_store=metadata_store,
     )
 
-    return metadata_store.get(upload_id).status
+    source_cache = metadata_store.get(upload_id)
+
+    # Check for error in async task
+    if source_cache.status.status == "failed":
+        raise HTTPException(
+            status_code=400,
+            detail=source_cache.status.model_dump(),
+        )
+    else:
+        return source_cache.status
 
 
 @app.get(
     "/upload/{upload_id}/status",
-    responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}},
+    responses={
+        400: {"model": UploadStatus, **UploadStatus.status_400_examples()},
+    },
+    status_code=status.HTTP_200_OK,
 )
 async def get_upload_status(
     upload_id: str,
@@ -234,64 +229,11 @@
     return source_cache.status
 
 
-@app.get("/models")
-async def list_models():
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/resolution/{name}")
-async def get_resolution(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.post("/models/{name}")
-async def add_model(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.delete("/models/{name}")
-async def delete_model(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/models/{name}/results")
-async def get_results(name: str, result_type: ModelResultsType | None):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.post("/models/{name}/results")
-async def set_results(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/models/{name}/truth")
-async def get_truth(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.post("/models/{name}/truth")
-async def set_truth(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/models/{name}/ancestors")
-async def get_ancestors(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.get("/models/{name}/ancestors_cache")
-async def get_ancestors_cache(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
-
-
-@app.post("/models/{name}/ancestors_cache")
-async def set_ancestors_cache(name: str):
-    raise HTTPException(status_code=501, detail="Not implemented")
+# Retrieval
 
 
 @app.get(
     "/query",
-    response_class=ParquetResponse,
     responses={404: {"model": NotFoundError}},
 )
 async def query(
     backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
     full_name: str,
     warehouse_hash_b64: str,
     resolution_name: str | None = None,
     threshold: int | None = None,
     limit: int | None = None,
-):
+) -> ParquetResponse:
+    """Query Matchbox for matches based on a source address."""
     source_address = SourceAddress(
         full_name=full_name, warehouse_hash=warehouse_hash_b64
     )
@@ -345,6 +288,7 @@ async def match(
     resolution_name: str,
     threshold: int | None = None,
 ) -> list[Match]:
+    """Match a source primary key against a list of target addresses."""
     targets = [
         SourceAddress(full_name=n, warehouse_hash=w)
         for n, w in zip(target_full_names, target_warehouse_hashes_b64, strict=True)
     ]
@@ -378,9 +322,36 @@ async def match(
     return res
 
 
-@app.get("/validate/hash")
-async def validate_hashes():
-    raise HTTPException(status_code=501, detail="Not implemented")
+# Data management
+
+
+@app.post("/sources", status_code=status.HTTP_202_ACCEPTED)
+async def add_source(source: Source) -> UploadStatus:
+    """Create an upload and insert task for indexed source data."""
+    upload_id = metadata_store.cache_source(metadata=source)
+    return metadata_store.get(cache_id=upload_id).status
+
+
+@app.get(
+    "/sources/{warehouse_hash_b64}/{full_name}",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_source(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
+    warehouse_hash_b64: str,
+    full_name: str,
+) -> Source:
+    """Get a source from the backend."""
+    address = SourceAddress(full_name=full_name, warehouse_hash=warehouse_hash_b64)
+    try:
+        return backend.get_source(address)
+    except MatchboxSourceNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.SOURCE
+            ).model_dump(),
+        ) from e
 
 
 @app.get("/report/resolutions")
@@ -388,3 +359,285 @@ async def get_resolutions(
     backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
 ) -> ResolutionGraph:
     return backend.get_resolution_graph()
+
+
+# Model management
+
+
+@app.post(
+    "/models",
+    responses={
+        500: {
+            "model": ModelOperationStatus,
+            **ModelOperationStatus.status_500_examples(),
+        },
+    },
+    status_code=status.HTTP_201_CREATED,
+)
+async def insert_model(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], model: ModelMetadata
+) -> ModelOperationStatus:
+    """Insert a model into the backend."""
+    try:
+        backend.insert_model(model)
+        return ModelOperationStatus(
+            success=True,
+            model_name=model.name,
+            operation=ModelOperationType.INSERT,
+        )
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=ModelOperationStatus(
+                success=False,
+                model_name=model.name,
+                operation=ModelOperationType.INSERT,
+                details=str(e),
+            ).model_dump(),
+        ) from e
+
+
+@app.get(
+    "/models/{name}",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_model(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> ModelMetadata:
+    """Get a model from the backend."""
+    try:
+        return backend.get_model(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+
+@app.post(
+    "/models/{name}/results",
+    responses={404: {"model": NotFoundError}},
+    status_code=status.HTTP_202_ACCEPTED,
+)
+async def set_results(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> UploadStatus:
+    """Create an upload task for model results."""
+    try:
+        metadata = backend.get_model(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+    upload_id = metadata_store.cache_model(metadata=metadata)
+    return metadata_store.get(cache_id=upload_id).status
+
+
+@app.get(
+    "/models/{name}/results",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_results(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> ParquetResponse:
+    """Download results for a model as a parquet file."""
+    try:
+        res = backend.get_model_results(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+    buffer = table_to_buffer(res)
+    return ParquetResponse(buffer.getvalue())
+
+
+@app.patch(
+    "/models/{name}/truth",
+    responses={
+        404: {"model": NotFoundError},
+        500: {
+            "model": ModelOperationStatus,
+            **ModelOperationStatus.status_500_examples(),
+        },
+    },
+)
+async def set_truth(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
+    name: str,
+    truth: Annotated[float, Body(ge=0.0, le=1.0)],
+) -> ModelOperationStatus:
+    """Set truth data for a model."""
+    try:
+        backend.set_model_truth(model=name, truth=truth)
+        return ModelOperationStatus(
+            success=True,
+            model_name=name,
+            operation=ModelOperationType.UPDATE_TRUTH,
+        )
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=ModelOperationStatus(
+                success=False,
+                model_name=name,
+                operation=ModelOperationType.UPDATE_TRUTH,
+                details=str(e),
+            ).model_dump(),
+        ) from e
+
+
+@app.get(
+    "/models/{name}/truth",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_truth(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> float:
+    """Get truth data for a model."""
+    try:
+        return backend.get_model_truth(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+
+@app.get(
+    "/models/{name}/ancestors",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_ancestors(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> list[ModelAncestor]:
+    try:
+        return backend.get_model_ancestors(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+
+@app.patch(
+    "/models/{name}/ancestors_cache",
+    responses={
+        404: {"model": NotFoundError},
+        500: {
+            "model": ModelOperationStatus,
+            **ModelOperationStatus.status_500_examples(),
+        },
+    },
+)
+async def set_ancestors_cache(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
+    name: str,
+    ancestors: list[ModelAncestor],
+) -> ModelOperationStatus:
+    try:
+        backend.set_model_ancestors_cache(model=name, ancestors_cache=ancestors)
+        return ModelOperationStatus(
+            success=True,
+            model_name=name,
+            operation=ModelOperationType.UPDATE_ANCESTOR_CACHE,
+        )
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+    except Exception as e:
+        raise HTTPException(
+            status_code=500,
+            detail=ModelOperationStatus(
+                success=False,
+                model_name=name,
+                operation=ModelOperationType.UPDATE_ANCESTOR_CACHE,
+                details=str(e),
+            ).model_dump(),
+        ) from e
+
+
+@app.get(
+    "/models/{name}/ancestors_cache",
+    responses={404: {"model": NotFoundError}},
+)
+async def get_ancestors_cache(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str
+) -> list[ModelAncestor]:
+    try:
+        return backend.get_model_ancestors_cache(model=name)
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+
+
+@app.delete(
+    "/models/{name}",
+    responses={
+        404: {"model": NotFoundError},
+        409: {
+            "model": ModelOperationStatus,
+            **ModelOperationStatus.status_409_examples(),
+        },
+    },
+)
+async def delete_model(
+    backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
+    name: str,
+    certain: Annotated[
+        bool, Query(description="Confirm deletion of the model")
+    ] = False,
+) -> ModelOperationStatus:
+    """Delete a model from the backend."""
+    try:
+        backend.delete_model(model=name, certain=certain)
+        return ModelOperationStatus(
+            success=True,
+            model_name=name,
+            operation=ModelOperationType.DELETE,
+        )
+    except MatchboxResolutionNotFoundError as e:
+        raise HTTPException(
+            status_code=404,
+            detail=NotFoundError(
+                details=str(e), entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        ) from e
+    except MatchboxDeletionNotConfirmed as e:
+        raise HTTPException(
+            status_code=409,
+            detail=ModelOperationStatus(
+                success=False,
+                model_name=name,
+                operation=ModelOperationType.DELETE,
+                details=str(e),
+            ).model_dump(),
+        ) from e
diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py
index abd741f..da2825f 100644
--- a/src/matchbox/server/base.py
+++ b/src/matchbox/server/base.py
@@ -1,16 +1,12 @@
-import inspect
 from abc import ABC, abstractmethod
 from enum import StrEnum
-from functools import wraps
 from pathlib import Path
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
     ParamSpec,
     Protocol,
     TypeVar,
-    cast,
 )
 
 import boto3
@@ -20,7 +16,7 @@
 from pydantic import Field, SecretStr
 from pydantic_settings import BaseSettings, SettingsConfigDict
 
-from matchbox.common.dtos import ModelMetadata
+from matchbox.common.dtos import ModelAncestor, ModelMetadata
 from matchbox.common.graph import ResolutionGraph
 from matchbox.common.sources import Match, Source, SourceAddress
 
@@ -181,35 +177,6 @@ def initialise_matchbox() -> None:
     initialise_backend(settings)
 
 
-def inject_backend(func: Callable[..., R]) -> Callable[..., R]:
-    """Decorator to inject the Matchbox backend into functions.
-
-    Used to allow user-facing functions to access the backend without needing to
-    pass it in. The backend is defined by the MB__BACKEND_TYPE environment variable.
-
-    Can be used for both functions and methods.
-
-    If the user specifies a backend, it will be used instead of the injection.
-    """
-
-    @wraps(func)
-    def _inject_backend(
-        *args: P.args, backend: "MatchboxDBAdapter | None" = None, **kwargs: P.kwargs
-    ) -> R:
-        if backend is None:
-            backend = BackendManager.get_backend()
-
-        sig = inspect.signature(func)
-        params = list(sig.parameters.values())
-
-        if params and params[0].name in ("self", "cls"):
-            return cast(R, func(args[0], backend, *args[1:], **kwargs))
-        else:
-            return cast(R, func(backend, *args, **kwargs))
-
-    return cast(Callable[..., R], _inject_backend)
-
-
 class Countable(Protocol):
     """A protocol for objects that can be counted."""
 
@@ -241,6 +208,8 @@ class MatchboxDBAdapter(ABC):
     merges: Countable
     proposes: Countable
 
+    # Retrieval
+
     @abstractmethod
     def query(
         self,
@@ -260,6 +229,8 @@ def match(
         threshold: int | None = None,
     ) -> list[Match]: ...
 
+    # Data management
+
     @abstractmethod
     def index(self, source: Source, data_hashes: Table) -> None: ...
 
@@ -281,7 +252,8 @@ def get_resolution_graph(self) -> ResolutionGraph: ...
     @abstractmethod
     def clear(self, certain: bool) -> None: ...
 
-    # Model methods
+    # Model management
+
     @abstractmethod
     def insert_model(self, model: ModelMetadata) -> None: ...
 
@@ -301,15 +273,15 @@ def set_model_truth(self, model: str, truth: float) -> None: ...
     def get_model_truth(self, model: str) -> float: ...
 
     @abstractmethod
-    def get_model_ancestors(self, model: str) -> dict[str, float]: ...
+    def get_model_ancestors(self, model: str) -> list[ModelAncestor]: ...
 
     @abstractmethod
     def set_model_ancestors_cache(
-        self, model: str, ancestors_cache: dict[str, float]
+        self, model: str, ancestors_cache: list[ModelAncestor]
     ) -> None: ...
 
     @abstractmethod
-    def get_model_ancestors_cache(self, model: str) -> dict[str, float]: ...
+    def get_model_ancestors_cache(self, model: str) -> list[ModelAncestor]: ...
 
     @abstractmethod
     def delete_model(self, model: str, certain: bool) -> None: ...
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index 935ccec..afeb941 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -5,9 +5,10 @@
 from sqlalchemy import and_, bindparam, delete, func, or_, select
 from sqlalchemy.orm import Session
 
-from matchbox.common.dtos import ModelMetadata, ModelType
+from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType
 from matchbox.common.exceptions import (
     MatchboxDataNotFound,
+    MatchboxDeletionNotConfirmed,
     MatchboxResolutionNotFoundError,
     MatchboxSourceNotFoundError,
 )
@@ -37,13 +38,10 @@
 T = TypeVar("T")
 P = ParamSpec("P")
 
+
 if TYPE_CHECKING:
-    from pandas import DataFrame as PandasDataFrame
-    from polars import DataFrame as PolarsDataFrame
     from pyarrow import Table as ArrowTable
 else:
-    PandasDataFrame = Any
-    PolarsDataFrame = Any
     ArrowTable = Any
 
@@ -128,6 +126,8 @@ def __init__(self, settings: MatchboxPostgresSettings):
         self.creates = FilteredProbabilities(over_truth=True)
         self.proposes = FilteredProbabilities()
 
+    # Retrieval
+
     def query(
         self,
         source_address: SourceAddress,
@@ -188,6 +188,8 @@ def match(
             threshold=threshold,
         )
 
+    # Data management
+
     def index(self, source: Source, data_hashes: Table) -> None:
         """Indexes to Matchbox a source dataset in your warehouse.
 
@@ -357,7 +359,7 @@ def clear(self, certain: bool = False) -> None:
                 "If you're sure you want to continue, rerun with certain=True"
             )
 
-    # Model methods
+    # Model management
 
     def insert_model(self, model: ModelMetadata) -> None:
         """Writes a model to Matchbox.
@@ -434,57 +436,61 @@ def get_model_truth(self, model: str) -> float:
         resolution = resolve_model_name(model=model, engine=MBDB.get_engine())
         return resolution.truth
 
-    def get_model_ancestors(self, model: str) -> dict[str, float]:
+    def get_model_ancestors(self, model: str) -> list[ModelAncestor]:
         """Gets the current truth values of all ancestors.
 
-        Returns a dict mapping model names to their current truth thresholds.
+        Returns a list of ModelAncestor objects mapping model names to their current
+        truth thresholds.
 
         Unlike ancestors_cache which returns cached values, this property returns
         the current truth values of all ancestor models.
         """
         resolution = resolve_model_name(model=model, engine=MBDB.get_engine())
-        return {
-            resolution.name: resolution.truth for resolution in resolution.ancestors
-        }
+        return [
+            ModelAncestor(name=resolution.name, truth=resolution.truth)
+            for resolution in resolution.ancestors
+        ]
 
     def set_model_ancestors_cache(
         self,
         model: str,
-        ancestors_cache: dict[str, float],
+        ancestors_cache: list[ModelAncestor],
     ) -> None:
         """Updates the cached ancestor thresholds.
 
         Args:
-            ancestors_cache: Dictionary mapping model names to their truth thresholds
+            ancestors_cache: List of ModelAncestor objects mapping model names to
+                their truth thresholds
         """
         resolution = resolve_model_name(model=model, engine=MBDB.get_engine())
         with Session(MBDB.get_engine()) as session:
             session.add(resolution)
 
-            model_names = list(ancestors_cache.keys())
+            ancestor_names = [ancestor.name for ancestor in ancestors_cache]
             name_to_id = dict(
                 session.query(Resolutions.name, Resolutions.resolution_id)
-                .filter(Resolutions.name.in_(model_names))
+                .filter(Resolutions.name.in_(ancestor_names))
                 .all()
             )
 
-            for model_name, truth_value in ancestors_cache.items():
-                parent_id = name_to_id.get(model_name)
+            for ancestor in ancestors_cache:
+                parent_id = name_to_id.get(ancestor.name)
                 if parent_id is None:
-                    raise ValueError(f"Model '{model_name}' not found in database")
+                    raise ValueError(f"Model '{ancestor.name}' not found in database")
 
                 session.execute(
                     ResolutionFrom.__table__.update()
                     .where(ResolutionFrom.parent == parent_id)
                     .where(ResolutionFrom.child == resolution.resolution_id)
-                    .values(truth_cache=truth_value)
+                    .values(truth_cache=ancestor.truth)
                 )
 
             session.commit()
 
-    def get_model_ancestors_cache(self, model: str) -> dict[str, float]:
+    def get_model_ancestors_cache(self, model: str) -> list[ModelAncestor]:
         """Gets the cached ancestor thresholds, converting hashes to model names.
 
-        Returns a dictionary mapping model names to their truth thresholds.
+        Returns a list of ModelAncestor objects mapping model names to their cached
+        truth thresholds.
 
         This is required because each point of truth needs to be stable, so we choose
         when to update it, caching the ancestor's values in the model itself.
@@ -499,9 +505,10 @@ def get_model_ancestors_cache(self, model: str) -> dict[str, float]:
                 .where(ResolutionFrom.truth_cache.isnot(None))
             )
 
-            return {
-                name: truth_cache for name, truth_cache in session.execute(query).all()
-            }
+            return [
+                ModelAncestor(name=name, truth=truth)
+                for name, truth in session.execute(query).all()
+            ]
 
     def delete_model(self, model: str, certain: bool = False) -> None:
         """Delete a model from the database.
@@ -525,12 +532,5 @@ def delete_model(self, model: str, certain: bool = False) -> None: session.delete(resolution) session.commit() else: - childen = resolution.descendants - children_names = ", ".join([r.name for r in childen]) - raise ValueError( - f"This operation will delete the resolutions {children_names}, " - "as well as all probabilities they have created. \n\n" - "It won't delete validation associated with these " - "probabilities. \n\n" - "If you're sure you want to continue, rerun with certain=True" - ) + children = [r.name for r in resolution.descendants] + raise MatchboxDeletionNotConfirmed(childen=children) diff --git a/src/matchbox/server/postgresql/db.py b/src/matchbox/server/postgresql/db.py index 043403f..4e1ae7c 100644 --- a/src/matchbox/server/postgresql/db.py +++ b/src/matchbox/server/postgresql/db.py @@ -93,6 +93,8 @@ def clear_database(self): ) conn.commit() + self.engine.dispose() + self.create_database() diff --git a/test/client/test_dedupers.py b/test/client/test_dedupers.py index 2046e58..e283944 100644 --- a/test/client/test_dedupers.py +++ b/test/client/test_dedupers.py @@ -108,7 +108,7 @@ def test_dedupers( # 4. Probabilities and clusters are inserted correctly - results.to_matchbox(backend=matchbox_postgres) + results.to_matchbox() retrieved_results = matchbox_postgres.get_model_results(model=deduper_name) assert retrieved_results.shape[0] == fx_data.tgt_prob_n diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index 5a0b0ee..194e4b3 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -3,14 +3,13 @@ import pyarrow as pa import pytest -import respx from dotenv import find_dotenv, load_dotenv from httpx import Response from pandas import DataFrame +from respx import MockRouter from sqlalchemy import Engine, create_engine from matchbox import index, match, process, query -from matchbox.client._handler import url from matchbox.client.clean import company_name, company_number from matchbox.client.helpers import cleaner, cleaners, comparison, select from matchbox.client.helpers.selector import Match, Selector @@ -46,9 +45,7 @@ def test_cleaners(): assert cleaner_name_number is not None -def test_process( - warehouse_data: list[Source], -): +def test_process(warehouse_data: list[Source]): crn = warehouse_data[0].to_arrow() cleaner_name = cleaner( @@ -77,8 +74,7 @@ def test_comparisons(): assert comparison_name_id is not None -@respx.mock -def test_select_mixed_style(warehouse_engine: Engine): +def test_select_mixed_style(matchbox_api: MockRouter, warehouse_engine: Engine): """We can select select specific columns from some of the sources""" # Set up mocks and test data source1 = Source( @@ -91,11 +87,11 @@ def test_select_mixed_style(warehouse_engine: Engine): db_pk="pk", ) - respx.get( - url(f"/sources/{hash_to_base64(source1.address.warehouse_hash)}/test.foo") + matchbox_api.get( + f"/sources/{hash_to_base64(source1.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source1.model_dump_json())) - respx.get( - url(f"/sources/{hash_to_base64(source2.address.warehouse_hash)}/test.bar") + matchbox_api.get( + f"/sources/{hash_to_base64(source2.address.warehouse_hash)}/test.bar" ).mock(return_value=Response(200, content=source2.model_dump_json())) df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}]) @@ -129,15 +125,14 @@ def test_select_mixed_style(warehouse_engine: Engine): assert selection[1].source.engine == warehouse_engine -@respx.mock -def 
+def test_select_non_indexed_columns(matchbox_api: MockRouter, warehouse_engine: Engine):
     """Selecting columns not declared to backend generates warning."""
     source = Source(
         address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"),
         db_pk="pk",
     )
 
-    respx.get(
-        url(f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo")
+    matchbox_api.get(
+        f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo"
     ).mock(return_value=Response(200, content=source.model_dump_json()))
 
     df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}])
@@ -154,16 +149,15 @@
         select({"test.foo": ["a", "b"]}, engine=warehouse_engine)
 
 
-@respx.mock
-def test_select_missing_columns(warehouse_engine: Engine):
+def test_select_missing_columns(matchbox_api: MockRouter, warehouse_engine: Engine):
     """Selecting columns not in the warehouse errors."""
     source = Source(
         address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"),
         db_pk="pk",
     )
 
-    respx.get(
-        url(f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo")
+    matchbox_api.get(
+        f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo"
     ).mock(return_value=Response(200, content=source.model_dump_json()))
 
     df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}])
@@ -205,12 +199,13 @@ def test_query_no_resolution_fail():
         query(sels)
 
 
-@respx.mock
 @patch.object(Source, "to_arrow")
-def test_query_no_resolution_ok_various_params(to_arrow: Mock):
+def test_query_no_resolution_ok_various_params(
+    to_arrow: Mock, matchbox_api: MockRouter
+):
     """Tests that we can avoid passing resolution name, with a variety of parameters."""
     # Mock API
-    query_route = respx.get(url("/query")).mock(
+    query_route = matchbox_api.get("/query").mock(
         return_value=Response(
             200,
             content=table_to_buffer(
@@ -273,12 +268,11 @@ def test_query_no_resolution_ok_various_params(to_arrow: Mock):
     }
 
 
-@respx.mock
 @patch.object(Source, "to_arrow")
-def test_query_multiple_sources_with_limits(to_arrow: Mock):
+def test_query_multiple_sources_with_limits(to_arrow: Mock, matchbox_api: MockRouter):
     """Tests that we can query multiple sources and distribute the limit among them."""
     # Mock API
-    query_route = respx.get(url("/query")).mock(
+    query_route = matchbox_api.get("/query").mock(
         side_effect=[
             Response(
                 200,
@@ -375,10 +369,9 @@ def test_query_multiple_sources_with_limits(to_arrow: Mock):
         query([sels[0]], [sels[1]], resolution_name="link", limit=7)
 
 
-@respx.mock
-def test_query_404_resolution():
+def test_query_404_resolution(matchbox_api: MockRouter):
     # Mock API
-    respx.get(url("/query")).mock(
+    matchbox_api.get("/query").mock(
         return_value=Response(
             404,
             json=NotFoundError(
@@ -407,10 +400,9 @@ def test_query_404_resolution():
         query(sels)
 
 
-@respx.mock
-def test_query_404_source():
+def test_query_404_source(matchbox_api: MockRouter):
     # Mock API
-    respx.get(url("/query")).mock(
+    matchbox_api.get("/query").mock(
         return_value=Response(
             404,
             json=NotFoundError(
@@ -439,9 +431,8 @@ def test_query_404_source():
         query(sels)
 
 
-@respx.mock
 @patch("matchbox.client.helpers.index.Source")
-def test_index_success(MockSource: Mock):
+def test_index_success(MockSource: Mock, matchbox_api: MockRouter):
     """Test successful indexing flow through the API."""
     engine = create_engine("sqlite:///:memory:")
 
@@ -453,9 +444,9 @@ def test_index_success(MockSource: Mock):
     MockSource.return_value = mock_source_instance
 
     # Mock the initial source metadata upload
-    source_route = respx.post(url("/sources")).mock(
+    source_route = matchbox_api.post("/sources").mock(
         return_value=Response(
-            200,
+            202,
             json=UploadStatus(
                 id="test-upload-id",
                 status="awaiting_upload",
@@ -465,9 +456,9 @@
         )
     )
 
     # Mock the data upload
-    upload_route = respx.post(url("/upload/test-upload-id")).mock(
+    upload_route = matchbox_api.post("/upload/test-upload-id").mock(
         return_value=Response(
-            200,
+            202,
             json=UploadStatus(
                 id="test-upload-id", status="complete", entity=BackendUploadType.INDEX
             ).model_dump(),
@@ -491,7 +482,6 @@
     assert b"PAR1" in upload_route.calls.last.request.content
 
 
-@respx.mock
 @patch("matchbox.client.helpers.index.Source")
 @pytest.mark.parametrize(
     "columns",
@@ -508,7 +498,9 @@ def test_index_success(MockSource: Mock):
     ],
 )
 def test_index_with_columns(
-    MockSource: Mock, columns: list[str] | list[dict[str, str]]
+    MockSource: Mock,
+    matchbox_api: MockRouter,
+    columns: list[str] | list[dict[str, str]],
 ):
     """Test indexing with different column definition formats."""
     engine = create_engine("sqlite:///:memory:")
@@ -525,9 +517,9 @@ def test_index_with_columns(
     MockSource.return_value = mock_source_instance
 
     # Mock the API endpoints
-    source_route = respx.post(url("/sources")).mock(
+    source_route = matchbox_api.post("/sources").mock(
         return_value=Response(
-            200,
+            202,
             json=UploadStatus(
                 id="test-upload-id",
                 status="awaiting_upload",
@@ -536,9 +528,9 @@
         )
     )
 
-    upload_route = respx.post(url("/upload/test-upload-id")).mock(
+    upload_route = matchbox_api.post("/upload/test-upload-id").mock(
         return_value=Response(
-            200,
+            202,
             json=UploadStatus(
                 id="test-upload-id", status="complete", entity=BackendUploadType.INDEX
             ).model_dump(),
@@ -569,9 +561,8 @@
     mock_source_instance.default_columns.assert_called_once()
 
 
-@respx.mock
 @patch("matchbox.client.helpers.index.Source")
-def test_index_upload_failure(MockSource: Mock):
+def test_index_upload_failure(MockSource: Mock, matchbox_api: MockRouter):
     """Test handling of upload failures."""
     engine = create_engine("sqlite:///:memory:")
 
@@ -583,9 +574,9 @@ def test_index_upload_failure(MockSource: Mock):
     MockSource.return_value = mock_source_instance
 
     # Mock successful source creation
-    source_route = respx.post(url("/sources")).mock(
+    source_route = matchbox_api.post("/sources").mock(
         return_value=Response(
-            200,
+            202,
             json=UploadStatus(
                 id="test-upload-id",
                 status="awaiting_upload",
@@ -595,7 +586,7 @@
     )
 
     # Mock failed upload
-    upload_route = respx.post(url("/upload/test-upload-id")).mock(
+    upload_route = matchbox_api.post("/upload/test-upload-id").mock(
         return_value=Response(
             400,
             json=UploadStatus(
@@ -625,8 +616,7 @@
     assert b"PAR1" in upload_route.calls.last.request.content
 
 
-@respx.mock
-def test_match_ok():
+def test_match_ok(matchbox_api: MockRouter):
     """The client can perform the right call for matching."""
     # Set up mocks
     mock_match1 = Match(
@@ -648,7 +638,7 @@
         f"[{mock_match1.model_dump_json()}, {mock_match2.model_dump_json()}]"
     )
 
-    match_route = respx.get(url("/match")).mock(
+    match_route = matchbox_api.get("/match").mock(
         return_value=Response(200, content=serialised_matches)
     )
 
@@ -717,11 +707,10 @@
     )
 
 
-@respx.mock
-def test_match_404_resolution():
+def test_match_404_resolution(matchbox_api: MockRouter):
     """The client can handle a resolution not found error."""
     # Set up mocks
-    respx.get(url("/match")).mock(
+    matchbox_api.get("/match").mock(
         return_value=Response(
             404,
             json=NotFoundError(
@@ -766,11 +755,10 @@
     )
 
 
-@respx.mock
-def test_match_404_source():
+def test_match_404_source(matchbox_api: MockRouter):
     """The client can handle a source not found error."""
     # Set up mocks
-    respx.get(url("/match")).mock(
+    matchbox_api.get("/match").mock(
         return_value=Response(
             404,
             json=NotFoundError(
diff --git a/test/client/test_linkers.py b/test/client/test_linkers.py
index 914b9a3..9204207 100644
--- a/test/client/test_linkers.py
+++ b/test/client/test_linkers.py
@@ -174,7 +174,7 @@ def unique_non_null(s):
 
     # 4. Probabilities and clusters are inserted correctly
 
-    results.to_matchbox(backend=matchbox_postgres)
+    results.to_matchbox()
 
     retrieved_results = matchbox_postgres.get_model_results(model=linker_name)
     assert retrieved_results.shape[0] == fx_data.tgt_prob_n
diff --git a/test/client/test_model.py b/test/client/test_model.py
new file mode 100644
index 0000000..d192d10
--- /dev/null
+++ b/test/client/test_model.py
@@ -0,0 +1,406 @@
+import json
+
+import pytest
+from httpx import Response
+from respx.router import MockRouter
+
+from matchbox.client.results import Results
+from matchbox.common.arrow import SCHEMA_RESULTS, table_to_buffer
+from matchbox.common.dtos import (
+    BackendRetrievableType,
+    BackendUploadType,
+    ModelAncestor,
+    ModelOperationStatus,
+    ModelOperationType,
+    NotFoundError,
+    UploadStatus,
+)
+from matchbox.common.exceptions import (
+    MatchboxDeletionNotConfirmed,
+    MatchboxResolutionNotFoundError,
+    MatchboxServerFileError,
+    MatchboxUnhandledServerResponse,
+    MatchboxUnparsedClientRequest,
+)
+from matchbox.common.factories.models import model_factory
+
+
+def test_insert_model(matchbox_api: MockRouter):
+    """Test inserting a model via the API."""
+    # Create test model using factory
+    dummy = model_factory(model_type="linker")
+
+    # Mock the POST /models endpoint
+    route = matchbox_api.post("/models").mock(
+        return_value=Response(
+            201,
+            json=ModelOperationStatus(
+                success=True,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.INSERT,
+            ).model_dump(),
+        )
+    )
+
+    # Call insert_model
+    dummy.model.insert_model()
+
+    # Verify the API call
+    assert route.called
+    assert (
+        route.calls.last.request.content.decode()
+        == dummy.model.metadata.model_dump_json()
+    )
+
+
+def test_insert_model_error(matchbox_api: MockRouter):
+    """Test handling of model insertion errors."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the POST /models endpoint with an error response
+    route = matchbox_api.post("/models").mock(
+        return_value=Response(
+            500,
+            json=ModelOperationStatus(
+                success=False,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.INSERT,
+                details="Internal server error",
+            ).model_dump(),
+        )
+    )
+
+    # Call insert_model and verify it raises an exception
+    with pytest.raises(MatchboxUnhandledServerResponse, match="Internal server error"):
+        dummy.model.insert_model()
+
+    assert route.called
+
+
+def test_results_getter(matchbox_api: MockRouter):
+    """Test getting model results via the API."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the GET /models/{name}/results endpoint
+    route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock(
+        return_value=Response(200, content=table_to_buffer(dummy.data).read())
+    )
+
+    # Get results
+    results = dummy.model.results
+
+    # Verify the API call
+    assert route.called
+    assert isinstance(results, Results)
+    assert results.probabilities.schema.equals(SCHEMA_RESULTS)
+
+
+def test_results_getter_not_found(matchbox_api: MockRouter):
+    """Test getting model results when they don't exist."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the GET endpoint with a 404 response
+    route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock(
+        return_value=Response(
+            404,
+            json=NotFoundError(
+                details="Results not found", entity=BackendRetrievableType.RESOLUTION
+            ).model_dump(),
+        )
+    )
+
+    # Verify that accessing results raises an exception
+    with pytest.raises(MatchboxResolutionNotFoundError, match="Results not found"):
+        _ = dummy.model.results
+
+    assert route.called
+
+
+def test_results_setter(matchbox_api: MockRouter):
+    """Test setting model results via the API."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the endpoints needed for results upload
+    init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock(
+        return_value=Response(
+            202,
+            json=UploadStatus(
+                id="test-upload-id",
+                status="awaiting_upload",
+                entity=BackendUploadType.RESULTS,
+            ).model_dump(),
+        )
+    )
+
+    upload_route = matchbox_api.post("/upload/test-upload-id").mock(
+        return_value=Response(
+            202,
+            json=UploadStatus(
+                id="test-upload-id",
+                status="processing",
+                entity=BackendUploadType.RESULTS,
+            ).model_dump(),
+        )
+    )
+
+    status_route = matchbox_api.get("/upload/test-upload-id/status").mock(
+        return_value=Response(
+            200,
+            json=UploadStatus(
+                id="test-upload-id",
+                status="complete",
+                entity=BackendUploadType.RESULTS,
+            ).model_dump(),
+        )
+    )
+
+    # Set results
+    test_results = Results(probabilities=dummy.data, metadata=dummy.model.metadata)
+    dummy.model.results = test_results
+
+    # Verify API calls
+    assert init_route.called
+    assert upload_route.called
+    assert status_route.called
+    assert (
+        b"PAR1" in upload_route.calls.last.request.content
+    )  # Check for parquet file signature
+
+
+def test_results_setter_upload_failure(matchbox_api: MockRouter):
+    """Test handling of upload failures when setting results."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the initial POST endpoint
+    init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock(
+        return_value=Response(
+            202,
+            json=UploadStatus(
+                id="test-upload-id",
+                status="awaiting_upload",
+                entity=BackendUploadType.RESULTS,
+            ).model_dump(),
+        )
+    )
+
+    # Mock the upload endpoint with a failure
+    upload_route = matchbox_api.post("/upload/test-upload-id").mock(
+        return_value=Response(
+            400,
+            json=UploadStatus(
+                id="test-upload-id",
+                status="failed",
+                entity=BackendUploadType.RESULTS,
+                details="Invalid data format",
+            ).model_dump(),
+        )
+    )
+
+    # Attempt to set results and verify it raises an exception
+    test_results = Results(probabilities=dummy.data, metadata=dummy.model.metadata)
+    with pytest.raises(MatchboxServerFileError, match="Invalid data format"):
+        dummy.model.results = test_results
+
+    assert init_route.called
+    assert upload_route.called
+
+
+def test_truth_getter(matchbox_api: MockRouter):
+    """Test getting model truth threshold via the API."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the GET /models/{name}/truth endpoint
+    route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/truth").mock(
+        return_value=Response(200, json=0.9)
+    )
+
+    # Get truth
+    truth = dummy.model.truth
+
+    # Verify the API call
+    assert route.called
+    assert truth == 0.9
+
+
+def test_truth_setter(matchbox_api: MockRouter):
+    """Test setting model truth threshold via the API."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the PATCH /models/{name}/truth endpoint
+    route = matchbox_api.patch(f"/models/{dummy.model.metadata.name}/truth").mock(
+        return_value=Response(
+            200,
+            json=ModelOperationStatus(
+                success=True,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.UPDATE_TRUTH,
+            ).model_dump(),
+        )
+    )
+
+    # Set truth
+    dummy.model.truth = 0.9
+
+    # Verify the API call
+    assert route.called
+    assert float(route.calls.last.request.read()) == 0.9
+
+
+def test_truth_setter_validation_error(matchbox_api: MockRouter):
+    """Test setting invalid truth values."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the PATCH endpoint with a validation error
+    route = matchbox_api.patch(f"/models/{dummy.model.metadata.name}/truth").mock(
+        return_value=Response(422)
+    )
+
+    # Attempt to set an invalid truth value
+    with pytest.raises(MatchboxUnparsedClientRequest):
+        dummy.model.truth = 1.5
+
+    assert route.called
+
+
+def test_ancestors_getter(matchbox_api: MockRouter):
+    """Test getting model ancestors via the API."""
+    dummy = model_factory(model_type="linker")
+
+    ancestors_data = [
+        ModelAncestor(name="model1", truth=0.9).model_dump(),
+        ModelAncestor(name="model2", truth=0.8).model_dump(),
+    ]
+
+    # Mock the GET /models/{name}/ancestors endpoint
+    route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/ancestors").mock(
+        return_value=Response(200, json=ancestors_data)
+    )
+
+    # Get ancestors
+    ancestors = dummy.model.ancestors
+
+    # Verify the API call
+    assert route.called
+    assert ancestors == {"model1": 0.9, "model2": 0.8}
+
+
+def test_ancestors_cache_operations(matchbox_api: MockRouter):
+    """Test getting and setting model ancestors cache via the API."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the GET endpoint
+    get_route = matchbox_api.get(
+        f"/models/{dummy.model.metadata.name}/ancestors_cache"
+    ).mock(
+        return_value=Response(
+            200, json=[ModelAncestor(name="model1", truth=0.9).model_dump()]
+        )
+    )
+
+    # Mock the POST endpoint
+    set_route = matchbox_api.post(
+        f"/models/{dummy.model.metadata.name}/ancestors_cache"
+    ).mock(
+        return_value=Response(
+            200,
+            json=ModelOperationStatus(
+                success=True,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.UPDATE_ANCESTOR_CACHE,
+            ).model_dump(),
+        )
+    )
+
+    # Get ancestors cache
+    cache = dummy.model.ancestors_cache
+    assert get_route.called
+    assert cache == {"model1": 0.9}
+
+    # Set ancestors cache
+    dummy.model.ancestors_cache = {"model2": 0.8}
+    assert set_route.called
+    assert json.loads(set_route.calls.last.request.content.decode()) == [
+        ModelAncestor(name="model2", truth=0.8).model_dump()
+    ]
+
+
+def test_ancestors_cache_set_error(matchbox_api: MockRouter):
+    """Test error handling when setting ancestors cache."""
+    dummy = model_factory(model_type="linker")
+
+    # Mock the POST endpoint with an error
+    route = matchbox_api.post(
+        f"/models/{dummy.model.metadata.name}/ancestors_cache"
+    ).mock(
+        return_value=Response(
+            500,
+            json=ModelOperationStatus(
+                success=False,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.UPDATE_ANCESTOR_CACHE,
+                details="Database error",
+            ).model_dump(),
+        )
+    )
+
+    # Attempt to set ancestors cache
+    with pytest.raises(MatchboxUnhandledServerResponse, match="Database error"):
+        dummy.model.ancestors_cache = {"model1": 0.9}
+
+    assert route.called
+
+
+def test_delete_model(matchbox_api: MockRouter):
+    """Test successfully deleting a model."""
+    # Create test model using factory
+    dummy = model_factory()
+
+    # Mock the DELETE endpoint with success response
+    route = matchbox_api.delete(
+        f"/models/{dummy.model.metadata.name}", params={"certain": True}
+    ).mock(
+        return_value=Response(
+            200,
+            json=ModelOperationStatus(
+                success=True,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.DELETE,
+            ).model_dump(),
+        )
+    )
+
+    # Delete the model
+    response = dummy.model.delete(certain=True)
+
+    # Verify the response and API call
+    assert response
+    assert route.called
+    assert route.calls.last.request.url.params["certain"] == "true"
+
+
+def test_delete_model_needs_confirmation(matchbox_api: MockRouter):
+    """Test attempting to delete a model without confirmation returns 409."""
+    # Create test model using factory
+    dummy = model_factory()
+
+    # Mock the DELETE endpoint with 409 confirmation required response
+    error_details = "Cannot delete model with dependent models: dedupe1, dedupe2"
+    route = matchbox_api.delete(f"/models/{dummy.model.metadata.name}").mock(
+        return_value=Response(
+            409,
+            json=ModelOperationStatus(
+                success=False,
+                model_name=dummy.model.metadata.name,
+                operation=ModelOperationType.DELETE,
+                details=error_details,
+            ).model_dump(),
+        )
+    )
+
+    # Attempt to delete without certain=True
+    with pytest.raises(MatchboxDeletionNotConfirmed):
+        dummy.model.delete()
+
+    # Verify the response and API call
+    assert route.called
+    assert route.calls.last.request.url.params["certain"] == "false"
diff --git a/test/client/test_visualisation.py b/test/client/test_visualisation.py
index a9f2076..2b34119 100644
--- a/test/client/test_visualisation.py
+++ b/test/client/test_visualisation.py
@@ -1,14 +1,15 @@
-import respx
 from httpx import Response
 from matplotlib.figure import Figure
+from respx import MockRouter
 
-from matchbox.client._handler import url
 from matchbox.client.visualisation import draw_resolution_graph
+from matchbox.common.graph import ResolutionGraph
 
 
-@respx.mock
-def test_draw_resolution_graph(resolution_graph):
-    respx.get(url("/report/resolutions")).mock(
+def test_draw_resolution_graph(
+    matchbox_api: MockRouter, resolution_graph: ResolutionGraph
+):
+    matchbox_api.get("/report/resolutions").mock(
         return_value=Response(200, content=resolution_graph.model_dump_json()),
     )
diff --git a/test/common/test_factories.py b/test/common/test_factories.py
index 96b85c2..68c4d7f 100644
--- a/test/common/test_factories.py
+++ b/test/common/test_factories.py
@@ -285,15 +285,15 @@ def test_source_factory_metrics_with_multiple_features():
         FeatureConfig(
             name="company_name",
             base_generator="company",
-            variations=[
+            variations=(
                 SuffixRule(suffix=" Inc"),
                 SuffixRule(suffix=" Ltd"),
-            ],
+            ),
         ),
         FeatureConfig(
             name="email",
             base_generator="email",
-            variations=[ReplaceRule(old="@", new="+test@")],
+            variations=(ReplaceRule(old="@", new="+test@"),),
         ),
     ]
 
@@ -418,12 +418,12 @@ def test_source_factory_mock_properties():
         FeatureConfig(
             name="company_name",
             base_generator="company",
-            variations=[SuffixRule(suffix=" Ltd")],
+            variations=(SuffixRule(suffix=" Ltd"),),
         ),
         FeatureConfig(
             name="registration_id",
             base_generator="numerify",
-            parameters={"text": "######"},
+            parameters=(("text", "######"),),
         ),
     ]
 
@@ -462,15 +462,15 @@ def test_source_factory_mock_properties():
 
 
 def test_model_factory_default():
     """Test that model_factory generates a dummy model with default parameters."""
-    model = model_factory()
+    dummy = model_factory()
 
-    assert model.metrics.n_true_entities == 10
-    assert model.model.type == ModelType.DEDUPER
-    assert model.model.right_resolution is None
+    assert dummy.metrics.n_true_entities == 10
+    assert dummy.model.metadata.type == ModelType.DEDUPER
+    assert dummy.model.metadata.right_resolution is None
 
     # Check that probabilities table was generated correctly
-    assert len(model.data) > 0
-    assert model.data.schema.equals(SCHEMA_RESULTS)
+    assert len(dummy.data) > 0
+    assert dummy.data.schema.equals(SCHEMA_RESULTS)
 
 
 def test_model_factory_with_custom_params():
@@ -480,19 +480,19 @@ def test_model_factory_with_custom_params():
     n_true_entities = 5
     prob_range = (0.9, 1.0)
 
-    model = model_factory(
+    dummy = model_factory(
         name=name,
         description=description,
         n_true_entities=n_true_entities,
         prob_range=prob_range,
     )
 
-    assert model.model.name == name
-    assert model.model.description == description
-    assert model.metrics.n_true_entities == n_true_entities
+    assert dummy.model.metadata.name == name
+    assert dummy.model.metadata.description == description
+    assert dummy.metrics.n_true_entities == n_true_entities
 
     # Check probability range
-    probs = model.data.column("probability").to_pylist()
+    probs = dummy.data.column("probability").to_pylist()
     assert all(90 <= p <= 100 for p in probs)
 
 
@@ -505,16 +505,16 @@ def test_model_factory_with_custom_params():
 )
 def test_model_factory_different_types(model_type: str):
     """Test model_factory handles different model types correctly."""
-    model = model_factory(model_type=model_type)
+    dummy = model_factory(model_type=model_type)
 
-    assert model.model.type == model_type
+    assert dummy.model.metadata.type == model_type
 
     if model_type == ModelType.LINKER:
-        assert model.model.right_resolution is not None
+        assert dummy.model.metadata.right_resolution is not None
 
         # Check that left and right values are in different ranges
-        left_vals = model.data.column("left_id").to_pylist()
-        right_vals = model.data.column("right_id").to_pylist()
+        left_vals = dummy.data.column("left_id").to_pylist()
+        right_vals = dummy.data.column("right_id").to_pylist()
         left_min, left_max = min(left_vals), max(left_vals)
         right_min, right_max = min(right_vals), max(right_vals)
         assert (left_min < left_max < right_min < right_max) or (
@@ -531,14 +531,14 @@ def test_model_factory_different_types(model_type: str):
 )
 def test_model_factory_seed_behavior(seed1: int, seed2: int, should_be_equal: bool):
     """Test that model_factory handles seeds correctly for reproducibility."""
-    model1 = model_factory(seed=seed1)
-    model2 = model_factory(seed=seed2)
+    dummy1 = model_factory(seed=seed1)
+    dummy2 = model_factory(seed=seed2)
 
     if should_be_equal:
-        assert model1.model.name == model2.model.name
-        assert model1.model.description == model2.model.description
-        assert model1.data.equals(model2.data)
+        assert dummy1.model.metadata.name == dummy2.model.metadata.name
+        assert dummy1.model.metadata.description == dummy2.model.metadata.description
+        assert dummy1.data.equals(dummy2.data)
     else:
-        assert model1.model.name != model2.model.name
-        assert model1.model.description != model2.model.description
-        assert not model1.data.equals(model2.data)
+        assert dummy1.model.metadata.name != dummy2.model.metadata.name
+        assert dummy1.model.metadata.description != dummy2.model.metadata.description
+        assert not dummy1.data.equals(dummy2.data)
diff --git a/test/fixtures/db.py b/test/fixtures/db.py
index 65a951f..3436fce 100644
--- a/test/fixtures/db.py
+++ b/test/fixtures/db.py
@@ -1,12 +1,15 @@
 import os
+from os import getenv
 from typing import TYPE_CHECKING, Any, Callable, Generator, Literal
 
 import boto3
 import pytest
+import respx
 from _pytest.fixtures import FixtureRequest
 from dotenv import find_dotenv, load_dotenv
 from moto import mock_aws
 from pandas import DataFrame
+from respx import MockRouter
 from sqlalchemy import Engine, create_engine
 from sqlalchemy import text as sqltext
 
@@ -103,7 +106,7 @@ def _db_add_dedupe_models_and_data(
         )
 
         results = model.run()
-        results.to_matchbox(backend=backend)
+        results.to_matchbox()
         model.truth = 0.0
 
     return _db_add_dedupe_models_and_data
@@ -173,7 +176,7 @@ def _db_add_link_models_and_data(
         )
 
         results = model.run()
-        results.to_matchbox(backend=backend)
+        results.to_matchbox()
         model.truth = 0.0
 
     return _db_add_link_models_and_data
@@ -384,3 +387,14 @@ def s3(aws_credentials: None) -> Generator[S3Client, None, None]:
     """Return a mocked S3 client."""
     with mock_aws():
         yield boto3.client("s3", region_name="eu-west-2")
+
+
+# Mock API
+
+
+@pytest.fixture(scope="function")
+def matchbox_api() -> Generator[MockRouter, None, None]:
+    with respx.mock(
+        base_url=getenv("MB__CLIENT__API_ROOT"), assert_all_called=True
+    ) as respx_mock:
+        yield respx_mock
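The matchbox_api fixture above is what lets the client tests drop the per-test @respx.mock decorator and the url() helper: routes are registered against the fixture's base_url (MB__CLIENT__API_ROOT), and assert_all_called=True fails any test that registers a route it never exercises. A minimal usage sketch with an invented test name; the direct httpx call stands in for whatever client code a real test would exercise:

import httpx
from httpx import Response
from os import getenv
from respx import MockRouter


def test_example(matchbox_api: MockRouter):
    # Paths are resolved against the fixture's base_url, so no url() wrapper is needed
    route = matchbox_api.get("/health").mock(
        return_value=Response(200, json={"status": "OK"})
    )

    # Stand-in for client code under test: respx intercepts this request
    response = httpx.get(f"{getenv('MB__CLIENT__API_ROOT')}/health")

    assert response.json() == {"status": "OK"}
    assert route.called  # assert_all_called=True also enforces this at teardown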
client.get(f"/sources/{hash_to_base64(b'bar')}/foo") - assert response.status_code == 404 - assert response.json()["entity"] == BackendRetrievableType.SOURCE - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_add_source(get_backend: Mock): - """Test the source addition endpoint.""" - # Setup - mock_backend = Mock() - mock_backend.index = Mock(return_value=None) - get_backend.return_value = mock_backend - - dummy_source = source_factory() - - # Make request - response = client.post("/sources", json=dummy_source.source.model_dump()) - - # Validate response - assert UploadStatus.model_validate(response.json()) - assert response.status_code == 200, response.json() - assert response.json()["status"] == "awaiting_upload" - assert response.json().get("id") is not None - mock_backend.index.assert_not_called() - - @patch("matchbox.server.base.BackendManager.get_backend") @patch("matchbox.server.api.routes.metadata_store") @patch("matchbox.server.api.routes.BackgroundTasks.add_task") -def test_source_upload( +def test_upload( mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client ): """Test uploading a file, happy path.""" @@ -168,7 +124,7 @@ def test_source_upload( # Validate response assert UploadStatus.model_validate(response.json()) - assert response.status_code == 200, response.json() + assert response.status_code == 202, response.json() assert response.json()["status"] == "queued" # Updated to check for queued status # Check both status updates were called in correct order assert metadata_store.update_status.call_args_list == [ @@ -180,7 +136,49 @@ def test_source_upload( @patch("matchbox.server.base.BackendManager.get_backend") @patch("matchbox.server.api.routes.metadata_store") -def test_upload_status_check(metadata_store: Mock, get_backend: Mock): +@patch("matchbox.server.api.routes.BackgroundTasks.add_task") +def test_upload_wrong_schema( + mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client +): + """Test uploading a file with wrong schema.""" + # Setup + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + get_backend.return_value = mock_backend + + # Create source with results schema instead of index + dummy_source = source_factory() + + # Setup store + store = MetadataStore() + update_id = store.cache_source(dummy_source.source) + metadata_store.get.side_effect = store.get + metadata_store.update_status.side_effect = store.update_status + + # Make request with actual data instead of the hashes -- wrong schema + response = client.post( + f"/upload/{update_id}", + files={ + "file": ( + "hashes.parquet", + table_to_buffer(dummy_source.data), + "application/octet-stream", + ), + }, + ) + + # Should fail before background task starts + assert response.status_code == 400 + assert response.json()["status"] == "failed" + assert "schema mismatch" in response.json()["details"].lower() + metadata_store.update_status.assert_called_with(update_id, "failed", details=ANY) + mock_add_task.assert_not_called() # Background task should not be queued + + +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call +@patch("matchbox.server.api.routes.metadata_store") +def test_upload_status_check(metadata_store: Mock, _: Mock): """Test checking status of an upload using the status endpoint.""" # Setup store with a processing entry store = MetadataStore() @@ -200,9 +198,9 @@ def test_upload_status_check(metadata_store: 
Mock, get_backend: Mock): metadata_store.update_status.assert_not_called() -@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") -def test_upload_already_processing(metadata_store: Mock, get_backend: Mock): +def test_upload_already_processing(metadata_store: Mock, _: Mock): """Test attempting to upload when status is already processing.""" # Setup store with a processing entry store = MetadataStore() @@ -223,8 +221,9 @@ def test_upload_already_processing(metadata_store: Mock, get_backend: Mock): assert response.json()["status"] == "processing" +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") -def test_upload_already_queued(metadata_store: Mock): +def test_upload_already_queued(metadata_store: Mock, _: Mock): """Test attempting to upload when status is already queued.""" # Setup store with a queued entry store = MetadataStore() @@ -245,48 +244,6 @@ def test_upload_already_queued(metadata_store: Mock): assert response.json()["status"] == "queued" -@patch("matchbox.server.base.BackendManager.get_backend") -@patch("matchbox.server.api.routes.metadata_store") -@patch("matchbox.server.api.routes.BackgroundTasks.add_task") -def test_source_upload_wrong_schema( - mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client -): - """Test uploading a file with wrong schema.""" - # Setup - mock_backend = Mock() - mock_backend.settings.datastore.get_client.return_value = s3 - mock_backend.settings.datastore.cache_bucket_name = "test-bucket" - get_backend.return_value = mock_backend - - # Create source with results schema instead of index - dummy_source = source_factory() - - # Setup store - store = MetadataStore() - update_id = store.cache_source(dummy_source.source) - metadata_store.get.side_effect = store.get - metadata_store.update_status.side_effect = store.update_status - - # Make request with actual data instead of the hashes -- wrong schema - response = client.post( - f"/upload/{update_id}", - files={ - "file": ( - "hashes.parquet", - table_to_buffer(dummy_source.data), - "application/octet-stream", - ), - }, - ) - - # Should fail before background task starts - assert response.status_code == 400 - assert response.json()["status"] == "failed" - assert "schema mismatch" in response.json()["details"].lower() - metadata_store.update_status.assert_called_with(update_id, "failed", details=ANY) - mock_add_task.assert_not_called() # Background task should not be queued - - @patch("matchbox.server.api.routes.metadata_store") def test_status_check_not_found(metadata_store: Mock): """Test checking status for non-existent upload ID.""" @@ -299,120 +256,7 @@ def test_status_check_not_found(metadata_store: Mock): assert "not found or expired" in response.json()["details"].lower() -@pytest.mark.asyncio -@patch("matchbox.server.base.BackendManager.get_backend") -async def test_complete_upload_process(get_backend: Mock, s3: S3Client): - """Test the complete upload process from source creation through processing.""" - # Setup the backend - mock_backend = Mock() - mock_backend.settings.datastore.get_client.return_value = s3 - mock_backend.settings.datastore.cache_bucket_name = "test-bucket" - mock_backend.index = Mock(return_value=None) - get_backend.return_value = mock_backend - - # Create test bucket - s3.create_bucket( - Bucket="test-bucket", - 
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, - ) - - # Create test data - dummy_source = source_factory() - - # Step 1: Add source - response = client.post("/sources", json=dummy_source.source.model_dump()) - assert response.status_code == 200 - upload_id = response.json()["id"] - assert response.json()["status"] == "awaiting_upload" - - # Step 2: Upload file with real background tasks - response = client.post( - f"/upload/{upload_id}", - files={ - "file": ( - "hashes.parquet", - table_to_buffer(dummy_source.data_hashes), - "application/octet-stream", - ), - }, - ) - assert response.status_code == 200 - assert response.json()["status"] == "queued" - - # Step 3: Poll status until complete or timeout - max_attempts = 10 - current_attempt = 0 - while current_attempt < max_attempts: - response = client.get(f"/upload/{upload_id}/status") - assert response.status_code == 200 - - status = response.json()["status"] - if status == "complete": - break - elif status == "failed": - pytest.fail(f"Upload failed: {response.json().get('details')}") - elif status in ["processing", "queued"]: - await asyncio.sleep(0.1) # Small delay between polls - else: - pytest.fail(f"Unexpected status: {status}") - - current_attempt += 1 - - assert current_attempt < max_attempts, ( - "Timed out waiting for processing to complete" - ) - assert status == "complete" - - # Verify backend.index was called with correct arguments - mock_backend.index.assert_called_once() - call_args = mock_backend.index.call_args - assert call_args[1]["source"] == dummy_source.source # Check source matches - assert call_args[1]["data_hashes"].equals(dummy_source.data_hashes) # Check data - - -# def test_list_models(): -# response = client.get("/models") -# assert response.status_code == 200 - -# def test_get_resolution(): -# response = client.get("/models/test_resolution") -# assert response.status_code == 200 - -# def test_add_model(): -# response = client.post("/models") -# assert response.status_code == 200 - -# def test_delete_model(): -# response = client.delete("/models/test_model") -# assert response.status_code == 200 - -# def test_get_results(): -# response = client.get("/models/test_model/results") -# assert response.status_code == 200 - -# def test_set_results(): -# response = client.post("/models/test_model/results") -# assert response.status_code == 200 - -# def test_get_truth(): -# response = client.get("/models/test_model/truth") -# assert response.status_code == 200 - -# def test_set_truth(): -# response = client.post("/models/test_model/truth") -# assert response.status_code == 200 - -# def test_get_ancestors(): -# response = client.get("/models/test_model/ancestors") -# assert response.status_code == 200 - -# def test_get_ancestors_cache(): -# response = client.get("/models/test_model/ancestors_cache") -# assert response.status_code == 200 - -# def test_set_ancestors_cache(): -# response = client.post("/models/test_model/ancestors_cache") -# assert response.status_code == 200 +# Retrieval @patch("matchbox.server.base.BackendManager.get_backend") @@ -573,15 +417,626 @@ def test_match_404_source(get_backend: Mock): assert response.json()["entity"] == BackendRetrievableType.SOURCE +# Data management + + @patch("matchbox.server.base.BackendManager.get_backend") -def test_get_resolution_graph( - get_backend: MatchboxDBAdapter, resolution_graph: ResolutionGraph -): - """Test the resolution graph report endpoint.""" +def test_get_source(get_backend): + dummy_source = Source( + address=SourceAddress(full_name="foo", 
warehouse_hash=b"bar"), db_pk="pk" + ) mock_backend = Mock() - mock_backend.get_resolution_graph = Mock(return_value=resolution_graph) + mock_backend.get_source = Mock(return_value=dummy_source) get_backend.return_value = mock_backend - response = client.get("/report/resolutions") + response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") assert response.status_code == 200 - assert ResolutionGraph.model_validate(response.json()) + assert Source.model_validate(response.json()) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_source_404(get_backend): + mock_backend = Mock() + mock_backend.get_source = Mock(side_effect=MatchboxSourceNotFoundError) + get_backend.return_value = mock_backend + + response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.SOURCE + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_add_source(get_backend: Mock): + """Test the source addition endpoint.""" + # Setup + mock_backend = Mock() + mock_backend.index = Mock(return_value=None) + get_backend.return_value = mock_backend + + dummy_source = source_factory() + + # Make request + response = client.post("/sources", json=dummy_source.source.model_dump()) + + # Validate response + assert UploadStatus.model_validate(response.json()) + assert response.status_code == 202, response.json() + assert response.json()["status"] == "awaiting_upload" + assert response.json().get("id") is not None + mock_backend.index.assert_not_called() + + +@pytest.mark.asyncio +@patch("matchbox.server.base.BackendManager.get_backend") +async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): + """Test the complete upload process from source creation through processing.""" + # Setup the backend + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + mock_backend.index = Mock(return_value=None) + get_backend.return_value = mock_backend + + # Create test bucket + s3.create_bucket( + Bucket="test-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data + dummy_source = source_factory() + + # Step 1: Add source + response = client.post("/sources", json=dummy_source.source.model_dump()) + assert response.status_code == 202 + upload_id = response.json()["id"] + assert response.json()["status"] == "awaiting_upload" + + # Step 2: Upload file with real background tasks + response = client.post( + f"/upload/{upload_id}", + files={ + "file": ( + "hashes.parquet", + table_to_buffer(dummy_source.data_hashes), + "application/octet-stream", + ), + }, + ) + assert response.status_code == 202 + assert response.json()["status"] == "queued" + + # Step 3: Poll status until complete or timeout + max_attempts = 10 + current_attempt = 0 + while current_attempt < max_attempts: + response = client.get(f"/upload/{upload_id}/status") + assert response.status_code == 200 + + status = response.json()["status"] + if status == "complete": + break + elif status == "failed": + pytest.fail(f"Upload failed: {response.json().get('details')}") + elif status in ["processing", "queued"]: + await asyncio.sleep(0.1) # Small delay between polls + else: + pytest.fail(f"Unexpected status: {status}") + + current_attempt += 1 + + assert current_attempt < max_attempts, ( + "Timed out waiting for processing to complete" + ) + assert status == "complete" + assert 
response.status_code == 200 + + # Verify backend.index was called with correct arguments + mock_backend.index.assert_called_once() + call_args = mock_backend.index.call_args + assert call_args[1]["source"] == dummy_source.source # Check source matches + assert call_args[1]["data_hashes"].equals(dummy_source.data_hashes) # Check data + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_resolution_graph( + get_backend: MatchboxDBAdapter, resolution_graph: ResolutionGraph +): + """Test the resolution graph report endpoint.""" + mock_backend = Mock() + mock_backend.get_resolution_graph = Mock(return_value=resolution_graph) + get_backend.return_value = mock_backend + + response = client.get("/report/resolutions") + assert response.status_code == 200 + assert ResolutionGraph.model_validate(response.json()) + + +# Model management + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_insert_model(get_backend: Mock): + mock_backend = Mock() + get_backend.return_value = mock_backend + + dummy = model_factory(name="test_model") + response = client.post("/models", json=dummy.model.metadata.model_dump()) + + assert response.status_code == 201 + assert response.json() == { + "success": True, + "model_name": "test_model", + "operation": ModelOperationType.INSERT.value, + "details": None, + } + mock_backend.insert_model.assert_called_once_with(dummy.model.metadata) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_insert_model_error(get_backend: Mock): + mock_backend = Mock() + mock_backend.insert_model = Mock(side_effect=Exception("Test error")) + get_backend.return_value = mock_backend + + dummy = model_factory() + response = client.post("/models", json=dummy.model.metadata.model_dump()) + + assert response.status_code == 500 + assert response.json()["success"] is False + assert response.json()["details"] == "Test error" + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_model(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory(name="test_model", description="test description") + mock_backend.get_model = Mock(return_value=dummy.model.metadata) + get_backend.return_value = mock_backend + + response = client.get("/models/test_model") + + assert response.status_code == 200 + assert response.json()["name"] == dummy.model.metadata.name + assert response.json()["description"] == dummy.model.metadata.description + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_model_not_found(get_backend: Mock): + mock_backend = Mock() + mock_backend.get_model = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend + + response = client.get("/models/nonexistent") + + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.RESOLUTION + + +@pytest.mark.parametrize("model_type", ["deduper", "linker"]) +@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.api.routes.metadata_store") +@patch("matchbox.server.api.routes.BackgroundTasks.add_task") +def test_model_upload( + mock_add_task: Mock, + metadata_store: Mock, + get_backend: Mock, + s3: S3Client, + model_type: str, +): + """Test uploading different types of files.""" + # Setup + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + get_backend.return_value = mock_backend + s3.create_bucket( + Bucket="test-bucket", + 
CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data with specified model type + dummy = model_factory(model_type=model_type) + + # Setup metadata store + store = MetadataStore() + upload_id = store.cache_model(dummy.model.metadata) + + metadata_store.get.side_effect = store.get + metadata_store.update_status.side_effect = store.update_status + + # Make request + response = client.post( + f"/upload/{upload_id}", + files={ + "file": ( + "data.parquet", + table_to_buffer(dummy.data), + "application/octet-stream", + ), + }, + ) + + # Validate response + assert response.status_code == 202 + assert response.json()["status"] == "queued" + mock_add_task.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_type", ["deduper", "linker"]) +@patch("matchbox.server.base.BackendManager.get_backend") +async def test_complete_model_upload_process( + get_backend: Mock, s3: S3Client, model_type: str +): + """Test the complete upload process for models from creation through processing.""" + # Setup the backend + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + mock_backend.set_model_results = Mock(return_value=None) + get_backend.return_value = mock_backend + + # Create test bucket + s3.create_bucket( + Bucket="test-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data with specified model type + dummy = model_factory(model_type=model_type) + + # Set up the mock to return the actual model metadata and data + mock_backend.get_model = Mock(return_value=dummy.model.metadata) + mock_backend.get_model_results = Mock(return_value=dummy.data) + + # Step 1: Create model + response = client.post("/models", json=dummy.model.metadata.model_dump()) + assert response.status_code == 201 + assert response.json()["success"] is True + assert response.json()["model_name"] == dummy.model.metadata.name + + # Step 2: Initialize results upload + response = client.post(f"/models/{dummy.model.metadata.name}/results") + assert response.status_code == 202 + upload_id = response.json()["id"] + assert response.json()["status"] == "awaiting_upload" + + # Step 3: Upload results file with real background tasks + response = client.post( + f"/upload/{upload_id}", + files={ + "file": ( + "results.parquet", + table_to_buffer(dummy.data), + "application/octet-stream", + ), + }, + ) + assert response.status_code == 202 + assert response.json()["status"] == "queued" + + # Step 4: Poll status until complete or timeout + max_attempts = 10 + current_attempt = 0 + status = None + + while current_attempt < max_attempts: + response = client.get(f"/upload/{upload_id}/status") + assert response.status_code == 200 + + status = response.json()["status"] + if status == "complete": + break + elif status == "failed": + pytest.fail(f"Upload failed: {response.json().get('details')}") + elif status in ["processing", "queued"]: + await asyncio.sleep(0.1) # Small delay between polls + else: + pytest.fail(f"Unexpected status: {status}") + + current_attempt += 1 + + assert current_attempt < max_attempts, ( + "Timed out waiting for processing to complete" + ) + assert status == "complete" + assert response.status_code == 200 + + # Step 5: Verify results were stored correctly + mock_backend.set_model_results.assert_called_once() + call_args = mock_backend.set_model_results.call_args + assert ( + call_args[1]["model"] == dummy.model.metadata.name + ) # 
Check model name matches + assert call_args[1]["results"].equals(dummy.data) # Check results data matches + + # Step 6: Verify we can retrieve the results + response = client.get(f"/models/{dummy.model.metadata.name}/results") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/octet-stream" + + # Step 7: Additional model-specific verifications + if model_type == "linker": + # For linker models, verify left and right resolutions are set + assert dummy.model.metadata.left_resolution is not None + assert dummy.model.metadata.right_resolution is not None + else: + # For deduper models, verify only left resolution is set + assert dummy.model.metadata.left_resolution is not None + assert dummy.model.metadata.right_resolution is None + + # Verify the model truth can be set and retrieved + truth_value = 0.85 + mock_backend.get_model_truth = Mock(return_value=truth_value) + + response = client.patch( + f"/models/{dummy.model.metadata.name}/truth", json=truth_value + ) + assert response.status_code == 200 + + response = client.get(f"/models/{dummy.model.metadata.name}/truth") + assert response.status_code == 200 + assert response.json() == truth_value + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_results(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory() + mock_backend.get_model = Mock(return_value=dummy.model.metadata) + get_backend.return_value = mock_backend + + response = client.post(f"/models/{dummy.model.metadata.name}/results") + + assert response.status_code == 202 + assert response.json()["status"] == "awaiting_upload" + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_results_model_not_found(get_backend: Mock): + """Test setting results for a non-existent model.""" + mock_backend = Mock() + mock_backend.get_model = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend + + response = client.post("/models/nonexistent-model/results") + + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.RESOLUTION + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_results(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory() + mock_backend.get_model_results = Mock(return_value=dummy.data) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy.model.metadata.name}/results") + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/octet-stream" + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_truth(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory() + get_backend.return_value = mock_backend + + response = client.patch(f"/models/{dummy.model.metadata.name}/truth", json=0.95) + + assert response.status_code == 200 + assert response.json()["success"] is True + mock_backend.set_model_truth.assert_called_once_with( + model=dummy.model.metadata.name, truth=0.95 + ) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_truth_invalid_value(get_backend: Mock): + """Test setting an invalid truth value (outside 0-1 range).""" + mock_backend = Mock() + dummy = model_factory() + get_backend.return_value = mock_backend + + # Test value > 1 + response = client.patch(f"/models/{dummy.model.metadata.name}/truth", json=1.5) + assert response.status_code == 422 + + # Test value < 0 + response = 
client.patch(f"/models/{dummy.model.metadata.name}/truth", json=-0.5) + assert response.status_code == 422 + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_truth(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory() + mock_backend.get_model_truth = Mock(return_value=0.95) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy.model.metadata.name}/truth") + + assert response.status_code == 200 + assert response.json() == 0.95 + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_ancestors(get_backend: Mock): + mock_backend = Mock() + dummy = model_factory() + mock_ancestors = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.97), + ] + mock_backend.get_model_ancestors = Mock(return_value=mock_ancestors) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy.model.metadata.name}/ancestors") + + assert response.status_code == 200 + assert len(response.json()) == 2 + assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_ancestors_cache(get_backend: Mock): + """Test retrieving the ancestors cache for a model.""" + mock_backend = Mock() + dummy = model_factory() + mock_ancestors = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.8), + ] + mock_backend.get_model_ancestors_cache = Mock(return_value=mock_ancestors) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy.model.metadata.name}/ancestors_cache") + + assert response.status_code == 200 + assert len(response.json()) == 2 + assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_ancestors_cache(get_backend: Mock): + """Test setting the ancestors cache for a model.""" + mock_backend = Mock() + dummy = model_factory() + get_backend.return_value = mock_backend + + ancestors_data = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.8), + ] + + response = client.patch( + f"/models/{dummy.model.metadata.name}/ancestors_cache", + json=[a.model_dump() for a in ancestors_data], + ) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["operation"] == ModelOperationType.UPDATE_ANCESTOR_CACHE + mock_backend.set_model_ancestors_cache.assert_called_once_with( + model=dummy.model.metadata.name, ancestors_cache=ancestors_data + ) + + +@pytest.mark.parametrize( + "endpoint", + ["results", "truth", "ancestors", "ancestors_cache"], +) +@patch("matchbox.server.base.BackendManager.get_backend") +def test_model_get_endpoints_404( + get_backend: Mock, + endpoint: str, +) -> None: + """Test 404 responses for model GET endpoints when model doesn't exist.""" + # Setup backend mock + mock_backend = Mock() + mock_method = getattr(mock_backend, f"get_model_{endpoint}") + mock_method.side_effect = MatchboxResolutionNotFoundError() + get_backend.return_value = mock_backend + + # Make request + response = client.get(f"/models/nonexistent-model/{endpoint}") + + # Verify response + assert response.status_code == 404 + error = NotFoundError.model_validate(response.json()) + assert error.entity == BackendRetrievableType.RESOLUTION + + +@pytest.mark.parametrize( + ("endpoint", "payload"), + [ + ("truth", 0.95), 
+        (
+            "ancestors_cache",
+            [
+                ModelAncestor(name="parent_model", truth=0.7).model_dump(),
+                ModelAncestor(name="grandparent_model", truth=0.8).model_dump(),
+            ],
+        ),
+    ],
+)
+@patch("matchbox.server.base.BackendManager.get_backend")
+def test_model_patch_endpoints_404(
+    get_backend: Mock,
+    endpoint: str,
+    payload: float | list[dict[str, Any]],
+) -> None:
+    """Test 404 responses for model PATCH endpoints when model doesn't exist."""
+    # Setup backend mock
+    mock_backend = Mock()
+    mock_method = getattr(mock_backend, f"set_model_{endpoint}")
+    mock_method.side_effect = MatchboxResolutionNotFoundError()
+    get_backend.return_value = mock_backend
+
+    # Make request
+    response = client.patch(f"/models/nonexistent-model/{endpoint}", json=payload)
+
+    # Verify response
+    assert response.status_code == 404
+    error = NotFoundError.model_validate(response.json())
+    assert error.entity == BackendRetrievableType.RESOLUTION
+
+
+@patch("matchbox.server.base.BackendManager.get_backend")
+def test_delete_model(get_backend: Mock):
+    mock_backend = Mock()
+    get_backend.return_value = mock_backend
+
+    dummy = model_factory()
+    response = client.delete(
+        f"/models/{dummy.model.metadata.name}", params={"certain": True}
+    )
+
+    assert response.status_code == 200
+    assert response.json() == {
+        "success": True,
+        "model_name": dummy.model.metadata.name,
+        "operation": ModelOperationType.DELETE,
+        "details": None,
+    }
+
+
+@patch("matchbox.server.base.BackendManager.get_backend")
+def test_delete_model_needs_confirmation(get_backend: Mock):
+    mock_backend = Mock()
+    mock_backend.delete_model = Mock(
+        side_effect=MatchboxDeletionNotConfirmed(children=["dedupe1", "dedupe2"])
+    )
+    get_backend.return_value = mock_backend
+
+    dummy = model_factory()
+    response = client.delete(f"/models/{dummy.model.metadata.name}")
+
+    assert response.status_code == 409
+    assert response.json()["success"] is False
+    message = response.json()["details"]
+    assert "dedupe1" in message and "dedupe2" in message
+
+
+@pytest.mark.parametrize(
+    "certain",
+    [True, False],
+)
+@patch("matchbox.server.base.BackendManager.get_backend")
+def test_delete_model_404(get_backend: Mock, certain: bool) -> None:
+    """Test 404 response when trying to delete a non-existent model."""
+    # Setup backend mock
+    mock_backend = Mock()
+    mock_backend.delete_model.side_effect = MatchboxResolutionNotFoundError()
+    get_backend.return_value = mock_backend
+
+    # Make request
+    response = client.delete("/models/nonexistent-model", params={"certain": certain})
+
+    # Verify response
+    assert response.status_code == 404
+    error = NotFoundError.model_validate(response.json())
+    assert error.entity == BackendRetrievableType.RESOLUTION
diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py
index 28b08ed..8033afc 100644
--- a/test/server/test_adapter.py
+++ b/test/server/test_adapter.py
@@ -8,7 +8,7 @@
 from sqlalchemy import Engine
 
 from matchbox.client.helpers.selector import match, query, select
-from matchbox.common.dtos import ModelMetadata, ModelType
+from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType
 from matchbox.common.exceptions import (
     MatchboxDataNotFound,
     MatchboxResolutionNotFoundError,
@@ -354,15 +354,13 @@ def test_model_ancestors(self):
         linker_name = "deterministic_naive_test.crn_naive_test.duns"
         linker_ancestors = self.backend.get_model_ancestors(model=linker_name)
 
-        assert isinstance(linker_ancestors, dict)
+        assert all(isinstance(ancestor, ModelAncestor) for ancestor in linker_ancestors)
+
+        # Not all ancestors have truth values, but one must
         truth_found = False
-        for model, truth in linker_ancestors.items():
-            if isinstance(truth, float):
-                # Not all ancestors have truth values, but one must
+        for ancestor in linker_ancestors:
+            if isinstance(ancestor.truth, float):
                 truth_found = True
-                assert isinstance(model, str)
-                assert isinstance(truth, (float, type(None)))
 
         assert truth_found
 
@@ -375,7 +373,10 @@ def test_model_ancestors_cache(self):
         pre_ancestors_cache = self.backend.get_model_ancestors_cache(model=linker_name)
 
         # Set
-        updated_ancestors_cache = {k: 0.5 for k in pre_ancestors_cache.keys()}
+        updated_ancestors_cache = [
+            ModelAncestor(name=ancestor.name, truth=0.5)
+            for ancestor in pre_ancestors_cache
+        ]
         self.backend.set_model_ancestors_cache(
             model=linker_name, ancestors_cache=updated_ancestors_cache
         )
diff --git a/uv.lock b/uv.lock
index 2355eab..058e32c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -258,7 +258,7 @@ name = "click"
 version = "8.1.8"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 }
 wheels = [
@@ -849,7 +849,7 @@ name = "ipykernel"
 version = "6.29.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "appnope", marker = "platform_system == 'Darwin'" },
+    { name = "appnope", marker = "sys_platform == 'darwin'" },
     { name = "comm" },
     { name = "debugpy" },
     { name = "ipython" },
@@ -1327,7 +1327,7 @@
 version = "1.6.1"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "click" },
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
     { name = "ghp-import" },
     { name = "jinja2" },
     { name = "markdown" },
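The uv.lock hunks above swap the platform_system environment marker for sys_platform. As background (a general property of PEP 508 markers, not stated in this diff): the two variables are filled from different stdlib calls, which is why the comparison strings change from 'Windows' and 'Darwin' to 'win32' and 'darwin':

import platform
import sys

# PEP 508 marker variables map to stdlib values:
#   sys_platform    -> sys.platform       e.g. 'win32', 'darwin', 'linux'
#   platform_system -> platform.system()  e.g. 'Windows', 'Darwin', 'Linux'
print(sys.platform)
print(platform.system())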