From 746d95c520989e2aa87264006989ade4b324cbd6 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 12 Feb 2025 12:29:38 +0000 Subject: [PATCH 01/19] Adds API endpoints for the model, including rationalising AncestorCache DTO --- .vscode/launch.json | 2 +- src/matchbox/client/models/models.py | 9 +- src/matchbox/common/dtos.py | 72 +++++- src/matchbox/server/api/cache.py | 35 ++- src/matchbox/server/api/routes.py | 291 +++++++++++++++++++--- src/matchbox/server/base.py | 8 +- src/matchbox/server/postgresql/adapter.py | 39 +-- test/server/test_adapter.py | 10 +- uv.lock | 6 +- 9 files changed, 390 insertions(+), 82 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index a355886b..0db66f65 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -1,5 +1,5 @@ { - "version": "0.2.0", + "version": "0.2.1", "configurations": [ { "name": "Matchbox: Debug", diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py index ddef734d..f62884b2 100644 --- a/src/matchbox/client/models/models.py +++ b/src/matchbox/client/models/models.py @@ -5,7 +5,7 @@ from matchbox.client.models.dedupers.base import Deduper from matchbox.client.models.linkers.base import Linker from matchbox.client.results import Results -from matchbox.common.dtos import ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType from matchbox.common.exceptions import MatchboxResolutionNotFoundError from matchbox.server import MatchboxDBAdapter, inject_backend @@ -64,13 +64,13 @@ def truth(self, backend: MatchboxDBAdapter, truth: float) -> None: @inject_backend def ancestors(self, backend: MatchboxDBAdapter) -> dict[str, float]: """Retrieve the ancestors of the model.""" - return backend.get_model_ancestors(model=self.metadata.name) + return backend.get_model_ancestors(model=self.metadata.name).ancestors @property @inject_backend def ancestors_cache(self, backend: MatchboxDBAdapter) -> dict[str, float]: """Retrieve 
the ancestors cache of the model.""" - return backend.get_model_ancestors_cache(model=self.metadata.name) + return backend.get_model_ancestors_cache(model=self.metadata.name).ancestors @ancestors_cache.setter @inject_backend @@ -79,7 +79,8 @@ def ancestors_cache( ) -> None: """Set the ancestors cache of the model.""" backend.set_model_ancestors_cache( - model=self.metadata.name, ancestors_cache=ancestors_cache + model=self.metadata.name, + ancestors_cache=ModelAncestors(ancestors=ancestors_cache), ) def run(self) -> Results: diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index 905e2531..2283a22e 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -1,7 +1,8 @@ +from datetime import datetime from enum import StrEnum from typing import Literal -from pydantic import BaseModel +from pydantic import BaseModel, Field from matchbox.common.arrow import SCHEMA_INDEX, SCHEMA_RESULTS @@ -45,6 +46,15 @@ class ModelType(StrEnum): DEDUPER = "deduper" +class ModelOperationType(StrEnum): + """Enumeration of supported model operations.""" + + INSERT = "insert" + UPDATE_TRUTH = "update_truth" + UPDATE_ANCESTOR_CACHE = "update_ancestor_cache" + DELETE = "delete" + + class ModelMetadata(BaseModel): """Metadata for a model.""" @@ -55,6 +65,66 @@ class ModelMetadata(BaseModel): right_resolution: str | None = None # Only used for linker models +class ModelAncestors(BaseModel): + """A model's ancestors and their truth values.""" + + ancestors: dict[str, float] = Field( + ..., + description="Mapping of model names to their truth thresholds", + examples=[{"model1": 0.75, "model2": 0.85}], + ) + + +class ModelOperationStatus(BaseModel): + """Status response for any model operation.""" + + success: bool + model_name: str + operation: ModelOperationType + details: str | None = None + timestamp: datetime = Field(default_factory=datetime.now) + + @classmethod + def status_500_examples(cls) -> dict: + return { + "content": { + "application/json": { 
+ "examples": { + "unhandled": { + "summary": ( + "Unhandled exception encountered while updating the " + "model's truth value." + ), + "value": cls( + success=False, + model_name="example_model", + operation=ModelOperationType.UPDATE_TRUTH, + ).model_dump(), + }, + "confirm_delete": { + "summary": "Delete operation requires confirmation. ", + "value": cls( + success=False, + model_name="example_model", + operation=ModelOperationType.DELETE, + details=( + "This operation will delete the resolutions " + "deduper_1, deduper_2, linker_1, " + "as well as all probabilities they have created. " + "\n\n" + "It won't delete validation associated with these " + "probabilities. \n\n" + "If you're sure you want to continue, rerun with " + "certain=True" + ), + ).model_dump(), + }, + }, + } + } + } + + class HealthCheck(BaseModel): """Response model to validate and return when performing a health check.""" diff --git a/src/matchbox/server/api/cache.py b/src/matchbox/server/api/cache.py index 0e3734b3..b8aad7f9 100644 --- a/src/matchbox/server/api/cache.py +++ b/src/matchbox/server/api/cache.py @@ -6,10 +6,7 @@ import pyarrow as pa from pydantic import BaseModel, ConfigDict -from matchbox.common.dtos import ( - BackendUploadType, - UploadStatus, -) +from matchbox.common.dtos import BackendUploadType, ModelMetadata, UploadStatus from matchbox.common.sources import Source from matchbox.server.api.arrow import s3_to_recordbatch from matchbox.server.base import MatchboxDBAdapter @@ -18,7 +15,7 @@ class MetadataCacheEntry(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - metadata: Source + metadata: Source | ModelMetadata upload_type: BackendUploadType update_timestamp: datetime status: UploadStatus @@ -68,9 +65,20 @@ def cache_source(self, metadata: Source) -> str: ) return cache_id - def cache_model(self, metadata: object) -> str: + def cache_model(self, metadata: ModelMetadata) -> str: """Cache model results metadata and return ID.""" - raise 
NotImplementedError + self._cleanup_if_needed() + cache_id = str(uuid.uuid4()) + + self._store[cache_id] = MetadataCacheEntry( + metadata=metadata, + upload_type=BackendUploadType.RESULTS, + update_timestamp=datetime.now(), + status=UploadStatus( + id=cache_id, status="awaiting_upload", entity=BackendUploadType.RESULTS + ), + ) + return cache_id def get(self, cache_id: str) -> MetadataCacheEntry | None: """Retrieve metadata by ID if not expired. Updates timestamp on access.""" @@ -158,9 +166,10 @@ async def process_upload( """Background task to process uploaded file.""" metadata_store.update_status(upload_id, "processing") client = backend.settings.datastore.get_client() + upload = metadata_store.get(upload_id) try: - data_hashes = pa.Table.from_batches( + data = pa.Table.from_batches( [ batch async for batch in s3_to_recordbatch( @@ -168,10 +177,14 @@ async def process_upload( ) ] ) + async with heartbeat(metadata_store, upload_id): - backend.index( - source=metadata_store.get(upload_id).metadata, data_hashes=data_hashes - ) + if upload.upload_type == BackendUploadType.INDEX: + backend.index(source=upload.metadata, data_hashes=data) + elif upload.upload_type == BackendUploadType.RESULTS: + backend.set_model_results(model=upload.metadata.name, results=data) + else: + raise ValueError(f"Unknown upload type: {upload.upload_type}") metadata_store.update_status(upload_id, "complete") diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 09e67929..4f9f3f2a 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -13,7 +13,10 @@ BackendUploadType, CountResult, HealthCheck, - ModelResultsType, + ModelAncestors, + ModelMetadata, + ModelOperationStatus, + ModelOperationType, NotFoundError, UploadStatus, ) @@ -51,7 +54,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: app = FastAPI( title="matchbox API", - version="0.2.0", + version="0.2.1", lifespan=lifespan, ) @@ -77,6 +80,8 @@ async def 
count_backend_items( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], entity: BackendCountableType | None = None, ) -> CountResult: + """Count the number of various entities in the backend.""" + def get_count(e: BackendCountableType) -> int: return getattr(backend, str(e)).count() @@ -106,6 +111,7 @@ async def get_source( warehouse_hash_b64: str, full_name: str, ) -> Source: + """Get a source from the backend.""" address = SourceAddress(full_name=full_name, warehouse_hash=warehouse_hash_b64) try: return backend.get_source(address) @@ -119,8 +125,8 @@ async def get_source( @app.post("/sources") -async def add_source(source: Source): - """Add a source to the backend.""" +async def add_source(source: Source) -> UploadStatus: + """Create an upload and insert task for indexed source data.""" upload_id = metadata_store.cache_source(metadata=source) return metadata_store.get(cache_id=upload_id).status @@ -239,59 +245,272 @@ async def list_models(): raise HTTPException(status_code=501, detail="Not implemented") -@app.get("/resolution/{name}") -async def get_resolution(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.post("/models") +async def insert_model( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], model: ModelMetadata +) -> ModelOperationStatus: + """Insert a model into the backend.""" + try: + backend.insert_model(model) + return ModelOperationStatus( + success=True, + model_name=model.name, + operation=ModelOperationType.INSERT, + ) + except Exception as e: + return ModelOperationStatus( + success=False, + model_name=model.name, + operation=ModelOperationType.INSERT, + details=str(e), + ) -@app.post("/models/{name}") -async def add_model(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.get( + "/models/{name}", + responses={404: {"model": NotFoundError}}, +) +async def get_model( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> ModelMetadata: 
+ """Get a model from the backend.""" + try: + return backend.get_model(model=name) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e -@app.delete("/models/{name}") -async def delete_model(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.delete( + "/models/{name}", + responses={ + 404: {"model": NotFoundError}, + 500: { + "model": ModelOperationStatus, + **ModelOperationStatus.status_500_examples(), + }, + }, +) +async def delete_model( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + name: str, + certain: Annotated[ + bool, Query(description="Confirm deletion of the model") + ] = False, +) -> ModelOperationStatus: + """Delete a model from the backend.""" + try: + backend.delete_model(model=name, certain=certain) + return ModelOperationStatus( + success=True, + model_name=name, + operation=ModelOperationType.DELETE, + ) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + except ValueError as e: + raise HTTPException( + status_code=500, + detail=ModelOperationStatus( + success=False, + model_name=name, + operation=ModelOperationType.DELETE, + details=str(e), + ).model_dump(), + ) from e -@app.get("/models/{name}/results") -async def get_results(name: str, result_type: ModelResultsType | None): - raise HTTPException(status_code=501, detail="Not implemented") +@app.get( + "/models/{name}/results", + responses={404: {"model": NotFoundError}}, +) +async def get_results( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> ParquetResponse: + """Download results for a model as a parquet file.""" + try: + res = backend.get_model_results(model=name) + except MatchboxResolutionNotFoundError as e: + raise 
HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + buffer = table_to_buffer(res) + return ParquetResponse(buffer.getvalue()) -@app.post("/models/{name}/results") -async def set_results(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.post( + "/models/{name}/results", + responses={404: {"model": NotFoundError}}, +) +async def set_results( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> UploadStatus: + """Create an upload task for model results.""" + try: + metadata = backend.get_model(model=name) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e -@app.get("/models/{name}/truth") -async def get_truth(name: str): - raise HTTPException(status_code=501, detail="Not implemented") + upload_id = metadata_store.cache_model(metadata=metadata) + return metadata_store.get(cache_id=upload_id).status -@app.post("/models/{name}/truth") -async def set_truth(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.get( + "/models/{name}/truth", + responses={404: {"model": NotFoundError}}, +) +async def get_truth( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> float: + """Get truth data for a model.""" + try: + return backend.get_model_truth(model=name) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e -@app.get("/models/{name}/ancestors") -async def get_ancestors(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.post( + "/models/{name}/truth", + responses={ + 404: {"model": NotFoundError}, + 500: { + "model": ModelOperationStatus, + 
**ModelOperationStatus.status_500_examples(), + }, + }, +) +async def set_truth( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str, truth: float +) -> ModelOperationStatus: + """Set truth data for a model.""" + try: + backend.set_model_truth(model=name, truth=truth) + return ModelOperationStatus( + success=True, + model_name=name, + operation=ModelOperationType.UPDATE_TRUTH, + ) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + except Exception as e: + raise HTTPException( + status_code=500, + detail=ModelOperationStatus( + success=False, + model_name=name, + operation=ModelOperationType.UPDATE_TRUTH, + details=str(e), + ).model_dump(), + ) from e -@app.get("/models/{name}/ancestors_cache") -async def get_ancestors_cache(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.get( + "/models/{name}/ancestors", + responses={404: {"model": NotFoundError}}, +) +async def get_ancestors( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> ModelAncestors: + try: + return backend.get_model_ancestors(model=name) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e -@app.post("/models/{name}/ancestors_cache") -async def set_ancestors_cache(name: str): - raise HTTPException(status_code=501, detail="Not implemented") +@app.get( + "/models/{name}/ancestors_cache", + responses={404: {"model": NotFoundError}}, +) +async def get_ancestors_cache( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> ModelAncestors: + try: + return backend.get_model_ancestors_cache(model=name) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + 
detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + + +@app.post( + "/models/{name}/ancestors_cache", + responses={ + 404: {"model": NotFoundError}, + 500: { + "model": ModelOperationStatus, + **ModelOperationStatus.status_500_examples(), + }, + }, +) +async def set_ancestors_cache( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + name: str, + ancestors: ModelAncestors, +): + try: + backend.set_model_ancestors_cache(model=name, ancestors_cache=ancestors) + return ModelOperationStatus( + success=True, + model_name=name, + operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, + ) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + except Exception as e: + raise HTTPException( + status_code=500, + detail=ModelOperationStatus( + success=False, + model_name=name, + operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, + details=str(e), + ).model_dump(), + ) from e @app.get( "/query", - response_class=ParquetResponse, responses={404: {"model": NotFoundError}}, ) async def query( @@ -301,7 +520,7 @@ async def query( resolution_name: str | None = None, threshold: int | None = None, limit: int | None = None, -): +) -> ParquetResponse: source_address = SourceAddress( full_name=full_name, warehouse_hash=warehouse_hash_b64 ) diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index 7fddd024..5ab65af5 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -20,7 +20,7 @@ from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict -from matchbox.common.dtos import ModelMetadata +from matchbox.common.dtos import ModelAncestors, ModelMetadata from matchbox.common.graph import ResolutionGraph from matchbox.common.sources import Match, Source, SourceAddress @@ -293,15 +293,15 @@ def 
set_model_truth(self, model: str, truth: float) -> None: ... def get_model_truth(self, model: str) -> float: ... @abstractmethod - def get_model_ancestors(self, model: str) -> dict[str, float]: ... + def get_model_ancestors(self, model: str) -> ModelAncestors: ... @abstractmethod def set_model_ancestors_cache( - self, model: str, ancestors_cache: dict[str, float] + self, model: str, ancestors_cache: ModelAncestors ) -> None: ... @abstractmethod - def get_model_ancestors_cache(self, model: str) -> dict[str, float]: ... + def get_model_ancestors_cache(self, model: str) -> ModelAncestors: ... @abstractmethod def delete_model(self, model: str, certain: bool) -> None: ... diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 935ccec1..dfadd8ba 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -5,7 +5,7 @@ from sqlalchemy import and_, bindparam, delete, func, or_, select from sqlalchemy.orm import Session -from matchbox.common.dtos import ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType from matchbox.common.exceptions import ( MatchboxDataNotFound, MatchboxResolutionNotFoundError, @@ -37,13 +37,10 @@ T = TypeVar("T") P = ParamSpec("P") + if TYPE_CHECKING: - from pandas import DataFrame as PandasDataFrame - from polars import DataFrame as PolarsDataFrame from pyarrow import Table as ArrowTable else: - PandasDataFrame = Any - PolarsDataFrame = Any ArrowTable = Any @@ -434,40 +431,43 @@ def get_model_truth(self, model: str) -> float: resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) return resolution.truth - def get_model_ancestors(self, model: str) -> dict[str, float]: + def get_model_ancestors(self, model: str) -> ModelAncestors: """Gets the current truth values of all ancestors. - Returns a dict mapping model names to their current truth thresholds. 
+ Returns a ModelAncestors mapping model names to their current truth thresholds. Unlike ancestors_cache which returns cached values, this property returns the current truth values of all ancestor models. """ resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) - return { - resolution.name: resolution.truth for resolution in resolution.ancestors - } + return ModelAncestors( + ancestors={ + resolution.name: resolution.truth for resolution in resolution.ancestors + } + ) def set_model_ancestors_cache( self, model: str, - ancestors_cache: dict[str, float], + ancestors_cache: ModelAncestors, ) -> None: """Updates the cached ancestor thresholds. Args: - ancestors_cache: Dictionary mapping model names to their truth thresholds + ancestors_cache: ModelAncestors mapping model names to their truth + thresholds """ resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) with Session(MBDB.get_engine()) as session: session.add(resolution) - model_names = list(ancestors_cache.keys()) + model_names = list(ancestors_cache.ancestors.keys()) name_to_id = dict( session.query(Resolutions.name, Resolutions.resolution_id) .filter(Resolutions.name.in_(model_names)) .all() ) - for model_name, truth_value in ancestors_cache.items(): + for model_name, truth_value in ancestors_cache.ancestors.items(): parent_id = name_to_id.get(model_name) if parent_id is None: raise ValueError(f"Model '{model_name}' not found in database") @@ -481,7 +481,7 @@ def set_model_ancestors_cache( session.commit() - def get_model_ancestors_cache(self, model: str) -> dict[str, float]: + def get_model_ancestors_cache(self, model: str) -> ModelAncestors: """Gets the cached ancestor thresholds, converting hashes to model names. Returns a dictionary mapping model names to their truth thresholds. 
@@ -499,9 +499,12 @@ def get_model_ancestors_cache(self, model: str) -> dict[str, float]: .where(ResolutionFrom.truth_cache.isnot(None)) ) - return { - name: truth_cache for name, truth_cache in session.execute(query).all() - } + return ModelAncestors( + ancestors={ + name: truth_cache + for name, truth_cache in session.execute(query).all() + } + ) def delete_model(self, model: str, certain: bool = False) -> None: """Delete a model from the database. diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 28b08ed4..32e5a59e 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine from matchbox.client.helpers.selector import match, query, select -from matchbox.common.dtos import ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType from matchbox.common.exceptions import ( MatchboxDataNotFound, MatchboxResolutionNotFoundError, @@ -354,10 +354,10 @@ def test_model_ancestors(self): linker_name = "deterministic_naive_test.crn_naive_test.duns" linker_ancestors = self.backend.get_model_ancestors(model=linker_name) - assert isinstance(linker_ancestors, dict) + assert isinstance(linker_ancestors, ModelAncestors) truth_found = False - for model, truth in linker_ancestors.items(): + for model, truth in linker_ancestors.ancestors.items(): if isinstance(truth, float): # Not all ancestors have truth values, but one must truth_found = True @@ -375,7 +375,9 @@ def test_model_ancestors_cache(self): pre_ancestors_cache = self.backend.get_model_ancestors_cache(model=linker_name) # Set - updated_ancestors_cache = {k: 0.5 for k in pre_ancestors_cache.keys()} + updated_ancestors_cache = ModelAncestors( + ancestors={k: 0.5 for k in pre_ancestors_cache.ancestors.keys()} + ) self.backend.set_model_ancestors_cache( model=linker_name, ancestors_cache=updated_ancestors_cache ) diff --git a/uv.lock b/uv.lock index 2355eabb..058e32c1 100644 --- a/uv.lock +++ 
b/uv.lock @@ -258,7 +258,7 @@ name = "click" version = "8.1.8" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/b9/2e/0090cbf739cee7d23781ad4b89a9894a41538e4fcf4c31dcdd705b78eb8b/click-8.1.8.tar.gz", hash = "sha256:ed53c9d8990d83c2a27deae68e4ee337473f6330c040a31d4225c9574d16096a", size = 226593 } wheels = [ @@ -849,7 +849,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1327,7 +1327,7 @@ version = "1.6.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, { name = "ghp-import" }, { name = "jinja2" }, { name = "markdown" }, From 875189cae91fcff1f12119ba6c0018af6efac2db Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 12 Feb 2025 14:18:27 +0000 Subject: [PATCH 02/19] Changed ModelAncestors to ModelAncestor --- src/matchbox/client/models/models.py | 16 ++++++-- src/matchbox/common/dtos.py | 11 +++-- src/matchbox/server/api/routes.py | 8 ++-- src/matchbox/server/base.py | 8 ++-- src/matchbox/server/postgresql/adapter.py | 49 +++++++++++------------ test/server/test_adapter.py | 19 +++++---- 6 files changed, 58 insertions(+), 53 deletions(-) diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py index f62884b2..f192b5b6 100644 --- a/src/matchbox/client/models/models.py +++ b/src/matchbox/client/models/models.py @@ -5,7 +5,7 @@ from matchbox.client.models.dedupers.base import Deduper from 
matchbox.client.models.linkers.base import Linker from matchbox.client.results import Results -from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import MatchboxResolutionNotFoundError from matchbox.server import MatchboxDBAdapter, inject_backend @@ -64,13 +64,19 @@ def truth(self, backend: MatchboxDBAdapter, truth: float) -> None: @inject_backend def ancestors(self, backend: MatchboxDBAdapter) -> dict[str, float]: """Retrieve the ancestors of the model.""" - return backend.get_model_ancestors(model=self.metadata.name).ancestors + return { + ancestor.name: ancestor.truth + for ancestor in backend.get_model_ancestors(model=self.metadata.name) + } @property @inject_backend def ancestors_cache(self, backend: MatchboxDBAdapter) -> dict[str, float]: """Retrieve the ancestors cache of the model.""" - return backend.get_model_ancestors_cache(model=self.metadata.name).ancestors + return { + ancestor.name: ancestor.truth + for ancestor in backend.get_model_ancestors_cache(model=self.metadata.name) + } @ancestors_cache.setter @inject_backend @@ -80,7 +86,9 @@ def ancestors_cache( """Set the ancestors cache of the model.""" backend.set_model_ancestors_cache( model=self.metadata.name, - ancestors_cache=ModelAncestors(ancestors=ancestors_cache), + ancestors_cache=[ + ModelAncestor(name=k, truth=v) for k, v in ancestors_cache + ], ) def run(self) -> Results: diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index 2283a22e..0eab9427 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -65,13 +65,12 @@ class ModelMetadata(BaseModel): right_resolution: str | None = None # Only used for linker models -class ModelAncestors(BaseModel): - """A model's ancestors and their truth values.""" +class ModelAncestor(BaseModel): + """A model's ancestor and its truth value.""" - ancestors: dict[str, float] = Field( - ..., - 
description="Mapping of model names to their truth thresholds", - examples=[{"model1": 0.75, "model2": 0.85}], + name: str = Field(..., description="Name of the ancestor model") + truth: float | None = Field( + default=None, description="Truth threshold value", ge=0.0, le=1.0 ) diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 4f9f3f2a..cde0dd54 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -13,7 +13,7 @@ BackendUploadType, CountResult, HealthCheck, - ModelAncestors, + ModelAncestor, ModelMetadata, ModelOperationStatus, ModelOperationType, @@ -438,7 +438,7 @@ async def set_truth( ) async def get_ancestors( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str -) -> ModelAncestors: +) -> list[ModelAncestor]: try: return backend.get_model_ancestors(model=name) except MatchboxResolutionNotFoundError as e: @@ -456,7 +456,7 @@ async def get_ancestors( ) async def get_ancestors_cache( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str -) -> ModelAncestors: +) -> list[ModelAncestor]: try: return backend.get_model_ancestors_cache(model=name) except MatchboxResolutionNotFoundError as e: @@ -481,7 +481,7 @@ async def get_ancestors_cache( async def set_ancestors_cache( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str, - ancestors: ModelAncestors, + ancestors: list[ModelAncestor], ): try: backend.set_model_ancestors_cache(model=name, ancestors_cache=ancestors) diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index 5ab65af5..b1064bd5 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -20,7 +20,7 @@ from pydantic import Field, SecretStr from pydantic_settings import BaseSettings, SettingsConfigDict -from matchbox.common.dtos import ModelAncestors, ModelMetadata +from matchbox.common.dtos import ModelAncestor, ModelMetadata from matchbox.common.graph import ResolutionGraph from 
matchbox.common.sources import Match, Source, SourceAddress @@ -293,15 +293,15 @@ def set_model_truth(self, model: str, truth: float) -> None: ... def get_model_truth(self, model: str) -> float: ... @abstractmethod - def get_model_ancestors(self, model: str) -> ModelAncestors: ... + def get_model_ancestors(self, model: str) -> list[ModelAncestor]: ... @abstractmethod def set_model_ancestors_cache( - self, model: str, ancestors_cache: ModelAncestors + self, model: str, ancestors_cache: list[ModelAncestor] ) -> None: ... @abstractmethod - def get_model_ancestors_cache(self, model: str) -> ModelAncestors: ... + def get_model_ancestors_cache(self, model: str) -> list[ModelAncestor]: ... @abstractmethod def delete_model(self, model: str, certain: bool) -> None: ... diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index dfadd8ba..edd3ff3b 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -5,7 +5,7 @@ from sqlalchemy import and_, bindparam, delete, func, or_, select from sqlalchemy.orm import Session -from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import ( MatchboxDataNotFound, MatchboxResolutionNotFoundError, @@ -431,60 +431,61 @@ def get_model_truth(self, model: str) -> float: resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) return resolution.truth - def get_model_ancestors(self, model: str) -> ModelAncestors: + def get_model_ancestors(self, model: str) -> list[ModelAncestor]: """Gets the current truth values of all ancestors. - Returns a ModelAncestors mapping model names to their current truth thresholds. + Returns a list of ModelAncestor objects mapping model names to their current + truth thresholds. 
Unlike ancestors_cache which returns cached values, this property returns the current truth values of all ancestor models. """ resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) - return ModelAncestors( - ancestors={ - resolution.name: resolution.truth for resolution in resolution.ancestors - } - ) + return [ + ModelAncestor(name=resolution.name, truth=resolution.truth) + for resolution in resolution.ancestors + ] def set_model_ancestors_cache( self, model: str, - ancestors_cache: ModelAncestors, + ancestors_cache: list[ModelAncestor], ) -> None: """Updates the cached ancestor thresholds. Args: - ancestors_cache: ModelAncestors mapping model names to their truth - thresholds + ancestors_cache: List of ModelAncestor objects mapping model names to + their truth thresholds """ resolution = resolve_model_name(model=model, engine=MBDB.get_engine()) with Session(MBDB.get_engine()) as session: session.add(resolution) - model_names = list(ancestors_cache.ancestors.keys()) + ancestor_names = [ancestor.name for ancestor in ancestors_cache] name_to_id = dict( session.query(Resolutions.name, Resolutions.resolution_id) - .filter(Resolutions.name.in_(model_names)) + .filter(Resolutions.name.in_(ancestor_names)) .all() ) - for model_name, truth_value in ancestors_cache.ancestors.items(): - parent_id = name_to_id.get(model_name) + for ancestor in ancestors_cache: + parent_id = name_to_id.get(ancestor.name) if parent_id is None: - raise ValueError(f"Model '{model_name}' not found in database") + raise ValueError(f"Model '{ancestor.name}' not found in database") session.execute( ResolutionFrom.__table__.update() .where(ResolutionFrom.parent == parent_id) .where(ResolutionFrom.child == resolution.resolution_id) - .values(truth_cache=truth_value) + .values(truth_cache=ancestor.truth) ) session.commit() - def get_model_ancestors_cache(self, model: str) -> ModelAncestors: + def get_model_ancestors_cache(self, model: str) -> list[ModelAncestor]: """Gets the cached 
ancestor thresholds, converting hashes to model names. - Returns a dictionary mapping model names to their truth thresholds. + Returns a list of ModelAncestor objects mapping model names to their cached + truth thresholds. This is required because each point of truth needs to be stable, so we choose when to update it, caching the ancestor's values in the model itself. @@ -499,12 +500,10 @@ def get_model_ancestors_cache(self, model: str) -> ModelAncestors: .where(ResolutionFrom.truth_cache.isnot(None)) ) - return ModelAncestors( - ancestors={ - name: truth_cache - for name, truth_cache in session.execute(query).all() - } - ) + return [ + ModelAncestor(name=name, truth=truth) + for name, truth in session.execute(query).all() + ] def delete_model(self, model: str, certain: bool = False) -> None: """Delete a model from the database. diff --git a/test/server/test_adapter.py b/test/server/test_adapter.py index 32e5a59e..8033afc4 100644 --- a/test/server/test_adapter.py +++ b/test/server/test_adapter.py @@ -8,7 +8,7 @@ from sqlalchemy import Engine from matchbox.client.helpers.selector import match, query, select -from matchbox.common.dtos import ModelAncestors, ModelMetadata, ModelType +from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import ( MatchboxDataNotFound, MatchboxResolutionNotFoundError, @@ -354,15 +354,13 @@ def test_model_ancestors(self): linker_name = "deterministic_naive_test.crn_naive_test.duns" linker_ancestors = self.backend.get_model_ancestors(model=linker_name) - assert isinstance(linker_ancestors, ModelAncestors) + assert [isinstance(ancestor, ModelAncestor) for ancestor in linker_ancestors] + # Not all ancestors have truth values, but one must truth_found = False - for model, truth in linker_ancestors.ancestors.items(): - if isinstance(truth, float): - # Not all ancestors have truth values, but one must + for ancestor in linker_ancestors: + if isinstance(ancestor.truth, float): truth_found = 
True - assert isinstance(model, str) - assert isinstance(truth, (float, type(None))) assert truth_found @@ -375,9 +373,10 @@ def test_model_ancestors_cache(self): pre_ancestors_cache = self.backend.get_model_ancestors_cache(model=linker_name) # Set - updated_ancestors_cache = ModelAncestors( - ancestors={k: 0.5 for k in pre_ancestors_cache.ancestors.keys()} - ) + updated_ancestors_cache = [ + ModelAncestor(name=ancestor.name, truth=0.5) + for ancestor in pre_ancestors_cache + ] self.backend.set_model_ancestors_cache( model=linker_name, ancestors_cache=updated_ancestors_cache ) From 33a7c9bc8ee857e1cd4d7b5ba28845af0ed403e5 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 12 Feb 2025 14:34:48 +0000 Subject: [PATCH 03/19] Updated model deletion confirmation to use a custom error and raise 409 HTTP code --- src/matchbox/common/dtos.py | 35 +++++++++++++++-------- src/matchbox/common/exceptions.py | 16 +++++++++++ src/matchbox/server/api/routes.py | 34 +++++++++++++++------- src/matchbox/server/postgresql/adapter.py | 12 ++------ 4 files changed, 65 insertions(+), 32 deletions(-) diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index 0eab9427..a2623fad 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -84,22 +84,11 @@ class ModelOperationStatus(BaseModel): timestamp: datetime = Field(default_factory=datetime.now) @classmethod - def status_500_examples(cls) -> dict: + def status_409_examples(cls) -> dict: return { "content": { "application/json": { "examples": { - "unhandled": { - "summary": ( - "Unhandled exception encountered while updating the " - "model's truth value." - ), - "value": cls( - success=False, - model_name="example_model", - operation=ModelOperationType.UPDATE_TRUTH, - ).model_dump(), - }, "confirm_delete": { "summary": "Delete operation requires confirmation. 
", "value": cls( @@ -123,6 +112,28 @@ def status_500_examples(cls) -> dict: } } + @classmethod + def status_500_examples(cls) -> dict: + return { + "content": { + "application/json": { + "examples": { + "unhandled": { + "summary": ( + "Unhandled exception encountered while updating the " + "model's truth value." + ), + "value": cls( + success=False, + model_name="example_model", + operation=ModelOperationType.UPDATE_TRUTH, + ).model_dump(), + }, + }, + } + } + } + class HealthCheck(BaseModel): """Response model to validate and return when performing a health check.""" diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py index 2f11bde5..0e7c7f5d 100644 --- a/src/matchbox/common/exceptions.py +++ b/src/matchbox/common/exceptions.py @@ -136,3 +136,19 @@ def __init__(self, message: str | None = None): class MatchboxConnectionError(Exception): """Connection to Matchbox's backend database failed.""" + + +class MatchboxConfirmDelete(Exception): + """Deletion must be confirmed: if certain, rerun with certain=True.""" + + def __init__(self, children: list[str]): + children_names = ", ".join(children) + message = ( + f"This operation will delete the resolutions {children_names}, " + "as well as all probabilities they have created. \n\n" + "It won't delete validation associated with these " + "probabilities. \n\n" + "If you're sure you want to continue, rerun with certain=True. 
" + ) + + super().__init__(message) diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index cde0dd54..44af1ad8 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -21,6 +21,7 @@ UploadStatus, ) from matchbox.common.exceptions import ( + MatchboxConfirmDelete, MatchboxResolutionNotFoundError, MatchboxServerFileError, MatchboxSourceNotFoundError, @@ -245,7 +246,15 @@ async def list_models(): raise HTTPException(status_code=501, detail="Not implemented") -@app.post("/models") +@app.post( + "/models", + responses={ + 500: { + "model": ModelOperationStatus, + **ModelOperationStatus.status_500_examples(), + }, + }, +) async def insert_model( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], model: ModelMetadata ) -> ModelOperationStatus: @@ -258,12 +267,15 @@ async def insert_model( operation=ModelOperationType.INSERT, ) except Exception as e: - return ModelOperationStatus( - success=False, - model_name=model.name, - operation=ModelOperationType.INSERT, - details=str(e), - ) + raise HTTPException( + status_code=500, + detail=ModelOperationStatus( + success=False, + model_name=model.name, + operation=ModelOperationType.INSERT, + details=str(e), + ).model_dump(), + ) from e @app.get( @@ -289,9 +301,9 @@ async def get_model( "/models/{name}", responses={ 404: {"model": NotFoundError}, - 500: { + 409: { "model": ModelOperationStatus, - **ModelOperationStatus.status_500_examples(), + **ModelOperationStatus.status_409_examples(), }, }, ) @@ -317,9 +329,9 @@ async def delete_model( details=str(e), entity=BackendRetrievableType.RESOLUTION ).model_dump(), ) from e - except ValueError as e: + except MatchboxConfirmDelete as e: raise HTTPException( - status_code=500, + status_code=409, detail=ModelOperationStatus( success=False, model_name=name, diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index edd3ff3b..25c24fbd 100644 --- 
a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -7,6 +7,7 @@ from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import ( + MatchboxConfirmDelete, MatchboxDataNotFound, MatchboxResolutionNotFoundError, MatchboxSourceNotFoundError, @@ -527,12 +528,5 @@ def delete_model(self, model: str, certain: bool = False) -> None: session.delete(resolution) session.commit() else: - childen = resolution.descendants - children_names = ", ".join([r.name for r in childen]) - raise ValueError( - f"This operation will delete the resolutions {children_names}, " - "as well as all probabilities they have created. \n\n" - "It won't delete validation associated with these " - "probabilities. \n\n" - "If you're sure you want to continue, rerun with certain=True" - ) + childen = [r.name for r in resolution.descendants] + raise MatchboxConfirmDelete(childen) From 65ecc111e3f788cdeafc978c3bf5fc7316f0c1ef Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Wed, 12 Feb 2025 16:56:28 +0000 Subject: [PATCH 04/19] Working API-level unit tests for model and results endpoints --- src/matchbox/common/dtos.py | 2 - src/matchbox/server/api/routes.py | 18 +- test/server/api/test_routes.py | 380 +++++++++++++++++++++++++++--- 3 files changed, 361 insertions(+), 39 deletions(-) diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index a2623fad..5652e778 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -1,4 +1,3 @@ -from datetime import datetime from enum import StrEnum from typing import Literal @@ -81,7 +80,6 @@ class ModelOperationStatus(BaseModel): model_name: str operation: ModelOperationType details: str | None = None - timestamp: datetime = Field(default_factory=datetime.now) @classmethod def status_409_examples(cls) -> dict: diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 44af1ad8..95c9ebe4 100644 --- 
a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -2,7 +2,15 @@ from typing import TYPE_CHECKING, Annotated, Any, AsyncGenerator from dotenv import find_dotenv, load_dotenv -from fastapi import BackgroundTasks, Depends, FastAPI, HTTPException, Query, UploadFile +from fastapi import ( + BackgroundTasks, + Body, + Depends, + FastAPI, + HTTPException, + Query, + UploadFile, +) from fastapi.responses import JSONResponse, Response from starlette.exceptions import HTTPException as StarletteHTTPException @@ -404,7 +412,7 @@ async def get_truth( ) from e -@app.post( +@app.patch( "/models/{name}/truth", responses={ 404: {"model": NotFoundError}, @@ -415,7 +423,9 @@ async def get_truth( }, ) async def set_truth( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str, truth: float + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + name: str, + truth: Annotated[float, Body()], ) -> ModelOperationStatus: """Set truth data for a model.""" try: @@ -480,7 +490,7 @@ async def get_ancestors_cache( ) from e -@app.post( +@app.patch( "/models/{name}/ancestors_cache", responses={ 404: {"model": NotFoundError}, diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index 606c919c..fb8bc539 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -9,11 +9,18 @@ from fastapi.testclient import TestClient from matchbox.common.arrow import SCHEMA_MB_IDS, table_to_buffer -from matchbox.common.dtos import BackendRetrievableType, UploadStatus +from matchbox.common.dtos import ( + BackendRetrievableType, + ModelAncestor, + ModelOperationType, + UploadStatus, +) from matchbox.common.exceptions import ( + MatchboxConfirmDelete, MatchboxResolutionNotFoundError, MatchboxSourceNotFoundError, ) +from matchbox.common.factories.models import model_factory from matchbox.common.factories.sources import source_factory from matchbox.common.graph import ResolutionGraph from matchbox.common.hash import 
hash_to_base64 @@ -82,6 +89,9 @@ def test_count_backend_item(get_backend: MatchboxDBAdapter): # assert response.status_code == 200 +# Source endpoints + + @patch("matchbox.server.base.BackendManager.get_backend") def test_get_source(get_backend): dummy_source = Source( @@ -301,7 +311,7 @@ def test_status_check_not_found(metadata_store: Mock): @pytest.mark.asyncio @patch("matchbox.server.base.BackendManager.get_backend") -async def test_complete_upload_process(get_backend: Mock, s3: S3Client): +async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): """Test the complete upload process from source creation through processing.""" # Setup the backend mock_backend = Mock() @@ -370,49 +380,353 @@ async def test_complete_upload_process(get_backend: Mock, s3: S3Client): assert call_args[1]["data_hashes"].equals(dummy_source.data_hashes) # Check data -# def test_list_models(): +# Model endpoints + +# @patch("matchbox.server.base.BackendManager.get_backend") +# def test_list_models(get_backend: Mock): +# mock_backend = Mock() +# dummy_models = [ +# model_factory(name="model1", description="test model 1").model, +# model_factory(name="model2", description="test model 2").model +# ] +# mock_backend.list_models = Mock(return_value=dummy_models) +# get_backend.return_value = mock_backend + # response = client.get("/models") # assert response.status_code == 200 -# def test_get_resolution(): -# response = client.get("/models/test_resolution") -# assert response.status_code == 200 -# def test_add_model(): -# response = client.post("/models") -# assert response.status_code == 200 +@patch("matchbox.server.base.BackendManager.get_backend") +def test_insert_model(get_backend: Mock): + mock_backend = Mock() + get_backend.return_value = mock_backend -# def test_delete_model(): -# response = client.delete("/models/test_model") -# assert response.status_code == 200 + dummy_model = model_factory(name="test_model") + response = client.post("/models", 
json=dummy_model.model.model_dump()) -# def test_get_results(): -# response = client.get("/models/test_model/results") -# assert response.status_code == 200 + assert response.status_code == 200 + assert response.json() == { + "success": True, + "model_name": "test_model", + "operation": ModelOperationType.INSERT.value, + "details": None, + } + mock_backend.insert_model.assert_called_once_with(dummy_model.model) -# def test_set_results(): -# response = client.post("/models/test_model/results") -# assert response.status_code == 200 -# def test_get_truth(): -# response = client.get("/models/test_model/truth") -# assert response.status_code == 200 +@patch("matchbox.server.base.BackendManager.get_backend") +def test_insert_model_error(get_backend: Mock): + mock_backend = Mock() + mock_backend.insert_model = Mock(side_effect=Exception("Test error")) + get_backend.return_value = mock_backend -# def test_set_truth(): -# response = client.post("/models/test_model/truth") -# assert response.status_code == 200 + dummy_model = model_factory() + response = client.post("/models", json=dummy_model.model.model_dump()) -# def test_get_ancestors(): -# response = client.get("/models/test_model/ancestors") -# assert response.status_code == 200 + assert response.status_code == 500 + assert response.json()["success"] is False + assert response.json()["details"] == "Test error" -# def test_get_ancestors_cache(): -# response = client.get("/models/test_model/ancestors_cache") -# assert response.status_code == 200 -# def test_set_ancestors_cache(): -# response = client.post("/models/test_model/ancestors_cache") -# assert response.status_code == 200 +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_model(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory(name="test_model", description="test description") + mock_backend.get_model = Mock(return_value=dummy_model.model) + get_backend.return_value = mock_backend + + response = 
client.get("/models/test_model") + + assert response.status_code == 200 + assert response.json()["name"] == dummy_model.model.name + assert response.json()["description"] == dummy_model.model.description + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_model_not_found(get_backend: Mock): + mock_backend = Mock() + mock_backend.get_model = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend + + response = client.get("/models/nonexistent") + + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.RESOLUTION + + +@pytest.mark.parametrize("model_type", ["deduper", "linker"]) +@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.api.routes.metadata_store") +@patch("matchbox.server.api.routes.BackgroundTasks.add_task") +def test_model_upload( + mock_add_task: Mock, + metadata_store: Mock, + get_backend: Mock, + s3: S3Client, + model_type: str, +): + """Test uploading different types of files.""" + # Setup + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + get_backend.return_value = mock_backend + s3.create_bucket( + Bucket="test-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data with specified model type + dummy_model = model_factory(model_type=model_type) + + # Setup metadata store + store = MetadataStore() + upload_id = store.cache_model(dummy_model.model) + + metadata_store.get.side_effect = store.get + metadata_store.update_status.side_effect = store.update_status + + # Make request + response = client.post( + f"/upload/{upload_id}", + files={ + "file": ( + "data.parquet", + table_to_buffer(dummy_model.data), + "application/octet-stream", + ), + }, + ) + + # Validate response + assert response.status_code == 200 + assert response.json()["status"] == "queued" + 
mock_add_task.assert_called_once() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_type", ["deduper", "linker"]) +@patch("matchbox.server.base.BackendManager.get_backend") +async def test_complete_model_upload_process( + get_backend: Mock, s3: S3Client, model_type: str +): + """Test the complete upload process for models from creation through processing.""" + # Setup the backend + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + mock_backend.set_model_results = Mock(return_value=None) + get_backend.return_value = mock_backend + + # Create test bucket + s3.create_bucket( + Bucket="test-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data with specified model type + dummy_model = model_factory(model_type=model_type) + + # Set up the mock to return the actual model metadata and data + mock_backend.get_model = Mock(return_value=dummy_model.model) + mock_backend.get_model_results = Mock(return_value=dummy_model.data) + + # Step 1: Create model + response = client.post("/models", json=dummy_model.model.model_dump()) + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["model_name"] == dummy_model.model.name + + # Step 2: Initialize results upload + response = client.post(f"/models/{dummy_model.model.name}/results") + assert response.status_code == 200 + upload_id = response.json()["id"] + assert response.json()["status"] == "awaiting_upload" + + # Step 3: Upload results file with real background tasks + response = client.post( + f"/upload/{upload_id}", + files={ + "file": ( + "results.parquet", + table_to_buffer(dummy_model.data), + "application/octet-stream", + ), + }, + ) + assert response.status_code == 200 + assert response.json()["status"] == "queued" + + # Step 4: Poll status until complete or timeout + max_attempts = 10 + current_attempt = 0 + status 
= None + + while current_attempt < max_attempts: + response = client.get(f"/upload/{upload_id}/status") + assert response.status_code == 200 + + status = response.json()["status"] + if status == "complete": + break + elif status == "failed": + pytest.fail(f"Upload failed: {response.json().get('details')}") + elif status in ["processing", "queued"]: + await asyncio.sleep(0.1) # Small delay between polls + else: + pytest.fail(f"Unexpected status: {status}") + + current_attempt += 1 + + assert current_attempt < max_attempts, ( + "Timed out waiting for processing to complete" + ) + assert status == "complete" + + # Step 5: Verify results were stored correctly + mock_backend.set_model_results.assert_called_once() + call_args = mock_backend.set_model_results.call_args + assert call_args[1]["model"] == dummy_model.model.name # Check model name matches + assert call_args[1]["results"].equals( + dummy_model.data + ) # Check results data matches + + # Step 6: Verify we can retrieve the results + response = client.get(f"/models/{dummy_model.model.name}/results") + assert response.status_code == 200 + assert response.headers["content-type"] == "application/octet-stream" + + # Step 7: Additional model-specific verifications + if model_type == "linker": + # For linker models, verify left and right resolutions are set + assert dummy_model.model.left_resolution is not None + assert dummy_model.model.right_resolution is not None + else: + # For deduper models, verify only left resolution is set + assert dummy_model.model.left_resolution is not None + assert dummy_model.model.right_resolution is None + + # Verify the model truth can be set and retrieved + truth_value = 0.85 + mock_backend.get_model_truth = Mock(return_value=truth_value) + + response = client.patch(f"/models/{dummy_model.model.name}/truth", json=truth_value) + assert response.status_code == 200 + + response = client.get(f"/models/{dummy_model.model.name}/truth") + assert response.status_code == 200 + assert 
response.json() == truth_value + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_delete_model(get_backend: Mock): + mock_backend = Mock() + get_backend.return_value = mock_backend + + dummy_model = model_factory() + response = client.delete( + f"/models/{dummy_model.model.name}", params={"certain": True} + ) + + assert response.status_code == 200 + assert response.json() == { + "success": True, + "model_name": dummy_model.model.name, + "operation": ModelOperationType.DELETE, + "details": None, + } + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_delete_model_needs_confirmation(get_backend: Mock): + mock_backend = Mock() + mock_backend.delete_model = Mock( + side_effect=MatchboxConfirmDelete(["dedupe1", "dedupe2"]) + ) + get_backend.return_value = mock_backend + + dummy_model = model_factory() + response = client.delete(f"/models/{dummy_model.model.name}") + + assert response.status_code == 409 + assert response.json()["success"] is False + message = response.json()["details"] + assert "dedupe1" in message and "dedupe2" in message + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_results(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory() + mock_backend.get_model_results = Mock(return_value=dummy_model.data) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy_model.model.name}/results") + + assert response.status_code == 200 + assert response.headers["content-type"] == "application/octet-stream" + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_results(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory() + mock_backend.get_model = Mock(return_value=dummy_model.model) + get_backend.return_value = mock_backend + + response = client.post(f"/models/{dummy_model.model.name}/results") + + assert response.status_code == 200 + assert response.json()["status"] == "awaiting_upload" + + 
+@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_truth(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory() + mock_backend.get_model_truth = Mock(return_value=0.95) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy_model.model.name}/truth") + + assert response.status_code == 200 + assert response.json() == 0.95 + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_truth(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory() + get_backend.return_value = mock_backend + + response = client.patch(f"/models/{dummy_model.model.name}/truth", json=0.95) + + assert response.status_code == 200 + assert response.json()["success"] is True + mock_backend.set_model_truth.assert_called_once_with( + model=dummy_model.model.name, truth=0.95 + ) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_ancestors(get_backend: Mock): + mock_backend = Mock() + dummy_model = model_factory() + mock_ancestors = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.97), + ] + mock_backend.get_model_ancestors = Mock(return_value=mock_ancestors) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy_model.model.name}/ancestors") + + assert response.status_code == 200 + assert len(response.json()) == 2 + assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors + + +# Query and match endpoints @patch("matchbox.server.base.BackendManager.get_backend") From a4eec8cd5cdf923577bfd2a4af91b423e0ea9197 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 07:42:36 +0000 Subject: [PATCH 05/19] Reordered adapter, API and handler to use identical grouping and ordering of their functions --- src/matchbox/client/_handler.py | 101 +++--- src/matchbox/server/api/routes.py | 383 +++++++++++----------- src/matchbox/server/base.py | 7 +- 
src/matchbox/server/postgresql/adapter.py | 6 +- 4 files changed, 261 insertions(+), 236 deletions(-) diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index 739df5b1..dc4f76a4 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -85,52 +85,7 @@ def handle_http_code(res: httpx.Response) -> httpx.Response: raise MatchboxUnhandledServerResponse(res.content) -def get_resolution_graph() -> ResolutionGraph: - """Get the resolution graph from Matchbox.""" - res = handle_http_code(httpx.get(url("/report/resolutions"))) - return ResolutionGraph.model_validate(res.json()) - - -def get_source(address: SourceAddress) -> Source: - warehouse_hash_b64 = hash_to_base64(address.warehouse_hash) - res = handle_http_code( - httpx.get(url(f"/sources/{warehouse_hash_b64}/{address.full_name}")) - ) - return Source.model_validate(res.json()) - - -def index(source: Source, data_hashes: Table) -> UploadStatus: - """Index a Source in Matchbox.""" - buffer = table_to_buffer(table=data_hashes) - - # Upload metadata - metadata_res = handle_http_code( - httpx.post(url("/sources"), json=source.model_dump()) - ) - upload = UploadStatus.model_validate(metadata_res.json()) - - # Upload data - upload_res = handle_http_code( - httpx.post( - url(f"/upload/{upload.id}"), - files={ - "file": (f"{upload.id}.parquet", buffer, "application/octet-stream") - }, - ) - ) - - # Poll until complete with retry/timeout configuration - status = UploadStatus.model_validate(upload_res.json()) - while status.status not in ["complete", "failed"]: - status_res = handle_http_code(httpx.get(url(f"/upload/{upload.id}/status"))) - status = UploadStatus.model_validate(status_res.json()) - - if status.status == "failed": - raise MatchboxServerFileError(status.details) - - time.sleep(2) - - return status +# Retrieval def query( @@ -198,3 +153,57 @@ def match( ) return [Match.model_validate(m) for m in res.json()] + + +# Data management + + +def index(source: Source, 
data_hashes: Table) -> UploadStatus: + """Index a Source in Matchbox.""" + buffer = table_to_buffer(table=data_hashes) + + # Upload metadata + metadata_res = handle_http_code( + httpx.post(url("/sources"), json=source.model_dump()) + ) + upload = UploadStatus.model_validate(metadata_res.json()) + + # Upload data + upload_res = handle_http_code( + httpx.post( + url(f"/upload/{upload.id}"), + files={ + "file": (f"{upload.id}.parquet", buffer, "application/octet-stream") + }, + ) + ) + + # Poll until complete with retry/timeout configuration + status = UploadStatus.model_validate(upload_res.json()) + while status.status not in ["complete", "failed"]: + status_res = handle_http_code(httpx.get(url(f"/upload/{upload.id}/status"))) + status = UploadStatus.model_validate(status_res.json()) + + if status.status == "failed": + raise MatchboxServerFileError(status.details) + + time.sleep(2) + + return status + + +def get_source(address: SourceAddress) -> Source: + warehouse_hash_b64 = hash_to_base64(address.warehouse_hash) + res = handle_http_code( + httpx.get(url(f"/sources/{warehouse_hash_b64}/{address.full_name}")) + ) + return Source.model_validate(res.json()) + + +def get_resolution_graph() -> ResolutionGraph: + """Get the resolution graph from Matchbox.""" + res = handle_http_code(httpx.get(url("/report/resolutions"))) + return ResolutionGraph.model_validate(res.json()) + + +# Model management diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 95c9ebe4..a1b0339a 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -74,6 +74,9 @@ async def http_exception_handler(request, exc): return JSONResponse(content=exc.detail, status_code=exc.status_code) +# General + + def get_backend() -> MatchboxDBAdapter: return BackendManager.get_backend() @@ -101,45 +104,6 @@ def get_count(e: BackendCountableType) -> int: return CountResult(entities=res) -@app.post("/testing/clear") -async def clear_backend(): - raise 
HTTPException(status_code=501, detail="Not implemented") - - -@app.get("/sources") -async def list_sources(): - raise HTTPException(status_code=501, detail="Not implemented") - - -@app.get( - "/sources/{warehouse_hash_b64}/{full_name}", - responses={404: {"model": NotFoundError}}, -) -async def get_source( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], - warehouse_hash_b64: str, - full_name: str, -) -> Source: - """Get a source from the backend.""" - address = SourceAddress(full_name=full_name, warehouse_hash=warehouse_hash_b64) - try: - return backend.get_source(address) - except MatchboxSourceNotFoundError as e: - raise HTTPException( - status_code=404, - detail=NotFoundError( - details=str(e), entity=BackendRetrievableType.SOURCE - ).model_dump(), - ) from e - - -@app.post("/sources") -async def add_source(source: Source) -> UploadStatus: - """Create an upload and insert task for indexed source data.""" - upload_id = metadata_store.cache_source(metadata=source) - return metadata_store.get(cache_id=upload_id).status - - @app.post( "/upload/{upload_id}", responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}}, @@ -249,6 +213,149 @@ async def get_upload_status( return source_cache.status +# Retrieval + + +@app.get( + "/query", + responses={404: {"model": NotFoundError}}, +) +async def query( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + full_name: str, + warehouse_hash_b64: str, + resolution_name: str | None = None, + threshold: int | None = None, + limit: int | None = None, +) -> ParquetResponse: + source_address = SourceAddress( + full_name=full_name, warehouse_hash=warehouse_hash_b64 + ) + try: + res = backend.query( + source_address=source_address, + resolution_name=resolution_name, + threshold=threshold, + limit=limit, + ) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + 
).model_dump(), + ) from e + except MatchboxSourceNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.SOURCE + ).model_dump(), + ) from e + + buffer = table_to_buffer(res) + return ParquetResponse(buffer.getvalue()) + + +@app.get( + "/match", + responses={404: {"model": NotFoundError}}, +) +async def match( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + target_full_names: Annotated[list[str], Query()], + target_warehouse_hashes_b64: Annotated[list[str], Query()], + source_full_name: str, + source_warehouse_hash_b64: str, + source_pk: str, + resolution_name: str, + threshold: int | None = None, +) -> list[Match]: + targets = [ + SourceAddress(full_name=n, warehouse_hash=w) + for n, w in zip(target_full_names, target_warehouse_hashes_b64, strict=True) + ] + source = SourceAddress( + full_name=source_full_name, warehouse_hash=source_warehouse_hash_b64 + ) + try: + res = backend.match( + source_pk=source_pk, + source=source, + targets=targets, + resolution_name=resolution_name, + threshold=threshold, + ) + except MatchboxResolutionNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) from e + except MatchboxSourceNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.SOURCE + ).model_dump(), + ) from e + + return res + + +# Data management + + +@app.get("/sources") +async def list_sources(): + raise HTTPException(status_code=501, detail="Not implemented") + + +@app.post("/sources") +async def add_source(source: Source) -> UploadStatus: + """Create an upload and insert task for indexed source data.""" + upload_id = metadata_store.cache_source(metadata=source) + return metadata_store.get(cache_id=upload_id).status + + +@app.get( + "/sources/{warehouse_hash_b64}/{full_name}", + 
responses={404: {"model": NotFoundError}}, +) +async def get_source( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], + warehouse_hash_b64: str, + full_name: str, +) -> Source: + """Get a source from the backend.""" + address = SourceAddress(full_name=full_name, warehouse_hash=warehouse_hash_b64) + try: + return backend.get_source(address) + except MatchboxSourceNotFoundError as e: + raise HTTPException( + status_code=404, + detail=NotFoundError( + details=str(e), entity=BackendRetrievableType.SOURCE + ).model_dump(), + ) from e + + +@app.get("/report/resolutions") +async def get_resolutions( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], +) -> ResolutionGraph: + return backend.get_resolution_graph() + + +@app.post("/testing/clear") +async def clear_backend(): + raise HTTPException(status_code=501, detail="Not implemented") + + +# Model management + + @app.get("/models") async def list_models(): raise HTTPException(status_code=501, detail="Not implemented") @@ -305,72 +412,6 @@ async def get_model( ) from e -@app.delete( - "/models/{name}", - responses={ - 404: {"model": NotFoundError}, - 409: { - "model": ModelOperationStatus, - **ModelOperationStatus.status_409_examples(), - }, - }, -) -async def delete_model( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], - name: str, - certain: Annotated[ - bool, Query(description="Confirm deletion of the model") - ] = False, -) -> ModelOperationStatus: - """Delete a model from the backend.""" - try: - backend.delete_model(model=name, certain=certain) - return ModelOperationStatus( - success=True, - model_name=name, - operation=ModelOperationType.DELETE, - ) - except MatchboxResolutionNotFoundError as e: - raise HTTPException( - status_code=404, - detail=NotFoundError( - details=str(e), entity=BackendRetrievableType.RESOLUTION - ).model_dump(), - ) from e - except MatchboxConfirmDelete as e: - raise HTTPException( - status_code=409, - detail=ModelOperationStatus( - success=False, - 
model_name=name, - operation=ModelOperationType.DELETE, - details=str(e), - ).model_dump(), - ) from e - - -@app.get( - "/models/{name}/results", - responses={404: {"model": NotFoundError}}, -) -async def get_results( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str -) -> ParquetResponse: - """Download results for a model as a parquet file.""" - try: - res = backend.get_model_results(model=name) - except MatchboxResolutionNotFoundError as e: - raise HTTPException( - status_code=404, - detail=NotFoundError( - details=str(e), entity=BackendRetrievableType.RESOLUTION - ).model_dump(), - ) from e - - buffer = table_to_buffer(res) - return ParquetResponse(buffer.getvalue()) - - @app.post( "/models/{name}/results", responses={404: {"model": NotFoundError}}, @@ -394,15 +435,15 @@ async def set_results( @app.get( - "/models/{name}/truth", + "/models/{name}/results", responses={404: {"model": NotFoundError}}, ) -async def get_truth( +async def get_results( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str -) -> float: - """Get truth data for a model.""" +) -> ParquetResponse: + """Download results for a model as a parquet file.""" try: - return backend.get_model_truth(model=name) + res = backend.get_model_results(model=name) except MatchboxResolutionNotFoundError as e: raise HTTPException( status_code=404, @@ -411,6 +452,9 @@ async def get_truth( ).model_dump(), ) from e + buffer = table_to_buffer(res) + return ParquetResponse(buffer.getvalue()) + @app.patch( "/models/{name}/truth", @@ -455,14 +499,15 @@ async def set_truth( @app.get( - "/models/{name}/ancestors", + "/models/{name}/truth", responses={404: {"model": NotFoundError}}, ) -async def get_ancestors( +async def get_truth( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str -) -> list[ModelAncestor]: +) -> float: + """Get truth data for a model.""" try: - return backend.get_model_ancestors(model=name) + return backend.get_model_truth(model=name) except 
MatchboxResolutionNotFoundError as e: raise HTTPException( status_code=404, @@ -473,14 +518,14 @@ async def get_ancestors( @app.get( - "/models/{name}/ancestors_cache", + "/models/{name}/ancestors", responses={404: {"model": NotFoundError}}, ) -async def get_ancestors_cache( +async def get_ancestors( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str ) -> list[ModelAncestor]: try: - return backend.get_model_ancestors_cache(model=name) + return backend.get_model_ancestors(model=name) except MatchboxResolutionNotFoundError as e: raise HTTPException( status_code=404, @@ -532,27 +577,14 @@ async def set_ancestors_cache( @app.get( - "/query", + "/models/{name}/ancestors_cache", responses={404: {"model": NotFoundError}}, ) -async def query( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], - full_name: str, - warehouse_hash_b64: str, - resolution_name: str | None = None, - threshold: int | None = None, - limit: int | None = None, -) -> ParquetResponse: - source_address = SourceAddress( - full_name=full_name, warehouse_hash=warehouse_hash_b64 - ) +async def get_ancestors_cache( + backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str +) -> list[ModelAncestor]: try: - res = backend.query( - source_address=source_address, - resolution_name=resolution_name, - threshold=threshold, - limit=limit, - ) + return backend.get_model_ancestors_cache(model=name) except MatchboxResolutionNotFoundError as e: raise HTTPException( status_code=404, @@ -560,46 +592,32 @@ async def query( details=str(e), entity=BackendRetrievableType.RESOLUTION ).model_dump(), ) from e - except MatchboxSourceNotFoundError as e: - raise HTTPException( - status_code=404, - detail=NotFoundError( - details=str(e), entity=BackendRetrievableType.SOURCE - ).model_dump(), - ) from e - buffer = table_to_buffer(res) - return ParquetResponse(buffer.getvalue()) - -@app.get( - "/match", - responses={404: {"model": NotFoundError}}, +@app.delete( + "/models/{name}", + 
responses={ + 404: {"model": NotFoundError}, + 409: { + "model": ModelOperationStatus, + **ModelOperationStatus.status_409_examples(), + }, + }, ) -async def match( +async def delete_model( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], - target_full_names: Annotated[list[str], Query()], - target_warehouse_hashes_b64: Annotated[list[str], Query()], - source_full_name: str, - source_warehouse_hash_b64: str, - source_pk: str, - resolution_name: str, - threshold: int | None = None, -) -> list[Match]: - targets = [ - SourceAddress(full_name=n, warehouse_hash=w) - for n, w in zip(target_full_names, target_warehouse_hashes_b64, strict=True) - ] - source = SourceAddress( - full_name=source_full_name, warehouse_hash=source_warehouse_hash_b64 - ) + name: str, + certain: Annotated[ + bool, Query(description="Confirm deletion of the model") + ] = False, +) -> ModelOperationStatus: + """Delete a model from the backend.""" try: - res = backend.match( - source_pk=source_pk, - source=source, - targets=targets, - resolution_name=resolution_name, - threshold=threshold, + backend.delete_model(model=name, certain=certain) + return ModelOperationStatus( + success=True, + model_name=name, + operation=ModelOperationType.DELETE, ) except MatchboxResolutionNotFoundError as e: raise HTTPException( @@ -608,24 +626,13 @@ async def match( details=str(e), entity=BackendRetrievableType.RESOLUTION ).model_dump(), ) from e - except MatchboxSourceNotFoundError as e: + except MatchboxConfirmDelete as e: raise HTTPException( - status_code=404, - detail=NotFoundError( - details=str(e), entity=BackendRetrievableType.SOURCE + status_code=409, + detail=ModelOperationStatus( + success=False, + model_name=name, + operation=ModelOperationType.DELETE, + details=str(e), ).model_dump(), ) from e - - return res - - -@app.get("/validate/hash") -async def validate_hashes(): - raise HTTPException(status_code=501, detail="Not implemented") - - -@app.get("/report/resolutions") -async def 
get_resolutions( - backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], -) -> ResolutionGraph: - return backend.get_resolution_graph() diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index b1064bd5..d7e3b553 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -233,6 +233,8 @@ class MatchboxDBAdapter(ABC): merges: Countable proposes: Countable + # Retrieval + @abstractmethod def query( self, @@ -252,6 +254,8 @@ def match( threshold: int | None = None, ) -> list[Match]: ... + # Data management + @abstractmethod def index(self, source: Source, data_hashes: Table) -> None: ... @@ -273,7 +277,8 @@ def get_resolution_graph(self) -> ResolutionGraph: ... @abstractmethod def clear(self, certain: bool) -> None: ... - # Model methods + # Model management + @abstractmethod def insert_model(self, model: ModelMetadata) -> None: ... diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index 25c24fbd..2d8a7f27 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -126,6 +126,8 @@ def __init__(self, settings: MatchboxPostgresSettings): self.creates = FilteredProbabilities(over_truth=True) self.proposes = FilteredProbabilities() + # Retrieval + def query( self, source_address: SourceAddress, @@ -186,6 +188,8 @@ def match( threshold=threshold, ) + # Data management + def index(self, source: Source, data_hashes: Table) -> None: """Indexes to Matchbox a source dataset in your warehouse. @@ -355,7 +359,7 @@ def clear(self, certain: bool = False) -> None: "If you're sure you want to continue, rerun with certain=True" ) - # Model methods + # Model management def insert_model(self, model: ModelMetadata) -> None: """Writes a model to Matchbox. 
From 9fb82d6c5cc657dc3bc7308bca34aca60e791a3b Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 08:52:19 +0000 Subject: [PATCH 06/19] Cleaned up HTTPX client to automatically handle the URL and error processing --- src/matchbox/__init__.py | 16 ++-- src/matchbox/client/_handler.py | 137 ++++++++++++++++-------------- test/client/test_helpers.py | 101 +++++++++++----------- test/client/test_visualisation.py | 13 +-- 4 files changed, 139 insertions(+), 128 deletions(-) diff --git a/src/matchbox/__init__.py b/src/matchbox/__init__.py index d9e0a31a..86229219 100644 --- a/src/matchbox/__init__.py +++ b/src/matchbox/__init__.py @@ -1,11 +1,13 @@ from dotenv import find_dotenv, load_dotenv -from matchbox.client.helpers.cleaner import process -from matchbox.client.helpers.index import index -from matchbox.client.helpers.selector import match, query -from matchbox.client.models.models import make_model - -__all__ = ("make_model", "process", "query", "match", "index") - dotenv_path = find_dotenv(usecwd=True) load_dotenv(dotenv_path) + +# Environment variables must be loaded first for other imports to work + +from matchbox.client.helpers.cleaner import process # NoQA: E402 +from matchbox.client.helpers.index import index # NoQA: E402 +from matchbox.client.helpers.selector import match, query # NoQA: E402 +from matchbox.client.models.models import make_model # NoQA: E402 + +__all__ = ("make_model", "process", "query", "match", "index") diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index ad42d1d9..d4ca8d91 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -8,7 +8,13 @@ from pyarrow.parquet import read_table from matchbox.common.arrow import SCHEMA_MB_IDS, table_to_buffer -from matchbox.common.dtos import BackendRetrievableType, NotFoundError, UploadStatus +from matchbox.common.dtos import ( + BackendRetrievableType, + ModelMetadata, + ModelOperationStatus, + NotFoundError, + UploadStatus, +) 
from matchbox.common.exceptions import ( MatchboxClientFileError, MatchboxResolutionNotFoundError, @@ -21,23 +27,6 @@ from matchbox.common.hash import hash_to_base64 from matchbox.common.sources import Match, Source, SourceAddress -if timeout := getenv("MB__CLIENT__TIMEOUT"): - CLIENT = httpx.Client(timeout=float(timeout)) -else: - CLIENT = httpx.Client(timeout=None) - - -def url(path: str) -> str: - """Return path prefixed by API root, determined from environment.""" - api_root = getenv("MB__CLIENT__API_ROOT") - if api_root is None: - raise RuntimeError( - "MB__CLIENT__API_ROOT needs to be defined in the environment" - ) - - return api_root + path - - URLEncodeHandledType = str | int | float | bytes @@ -67,6 +56,8 @@ def url_params( def handle_http_code(res: httpx.Response) -> httpx.Response: """Handle HTTP status codes and raise appropriate exceptions.""" + res.read() + if res.status_code == 200: return res @@ -92,6 +83,25 @@ def handle_http_code(res: httpx.Response) -> httpx.Response: raise MatchboxUnhandledServerResponse(res.content) +def create_client() -> httpx.Client: + """Create an HTTPX client with proper configuration.""" + api_root = getenv("MB__CLIENT__API_ROOT") + timeout = getenv("MB__CLIENT__TIMEOUT") + if api_root is None: + raise RuntimeError( + "MB__CLIENT__API_ROOT needs to be defined in the environment" + ) + if timeout is not None: + timeout = float(timeout) + + return httpx.Client( + base_url=api_root, timeout=timeout, event_hooks={"response": [handle_http_code]} + ) + + +CLIENT = create_client() + + # Retrieval @@ -101,20 +111,18 @@ def query( threshold: int | None = None, limit: int | None = None, ) -> BytesIO: - res = handle_http_code( - CLIENT.get( - url("/query"), - params=url_params( - { - "full_name": source_address.full_name, - # Converted to b64 by `url_params()` - "warehouse_hash_b64": source_address.warehouse_hash, - "resolution_name": resolution_name, - "threshold": threshold, - "limit": limit, - } - ), - ) + res = CLIENT.get( + 
"/query", + params=url_params( + { + "full_name": source_address.full_name, + # Converted to b64 by `url_params()` + "warehouse_hash_b64": source_address.warehouse_hash, + "resolution_name": resolution_name, + "threshold": threshold, + "limit": limit, + } + ), ) buffer = BytesIO(res.content) @@ -140,23 +148,21 @@ def match( target_full_names = [t.full_name for t in targets] target_warehouse_hashes = [t.warehouse_hash for t in targets] - res = handle_http_code( - CLIENT.get( - url("/match"), - params=url_params( - { - "target_full_names": target_full_names, - # Converted to b64 by `url_params()` - "target_warehouse_hashes_b64": target_warehouse_hashes, - "source_full_name": source.full_name, - # Converted to b64 by `url_params()` - "source_warehouse_hash_b64": source.warehouse_hash, - "source_pk": source_pk, - "resolution_name": resolution_name, - "threshold": threshold, - } - ), - ) + res = CLIENT.get( + "/match", + params=url_params( + { + "target_full_names": target_full_names, + # Converted to b64 by `url_params()` + "target_warehouse_hashes_b64": target_warehouse_hashes, + "source_full_name": source.full_name, + # Converted to b64 by `url_params()` + "source_warehouse_hash_b64": source.warehouse_hash, + "source_pk": source_pk, + "resolution_name": resolution_name, + "threshold": threshold, + } + ), ) return [Match.model_validate(m) for m in res.json()] @@ -170,25 +176,20 @@ def index(source: Source, data_hashes: Table) -> UploadStatus: buffer = table_to_buffer(table=data_hashes) # Upload metadata - metadata_res = handle_http_code( - CLIENT.post(url("/sources"), json=source.model_dump()) - ) + metadata_res = CLIENT.post("/sources", json=source.model_dump()) + upload = UploadStatus.model_validate(metadata_res.json()) # Upload data - upload_res = handle_http_code( - CLIENT.post( - url(f"/upload/{upload.id}"), - files={ - "file": (f"{upload.id}.parquet", buffer, "application/octet-stream") - }, - ) + upload_res = CLIENT.post( + f"/upload/{upload.id}", + 
files={"file": (f"{upload.id}.parquet", buffer, "application/octet-stream")}, ) # Poll until complete with retry/timeout configuration status = UploadStatus.model_validate(upload_res.json()) while status.status not in ["complete", "failed"]: - status_res = handle_http_code(CLIENT.get(url(f"/upload/{upload.id}/status"))) + status_res = CLIENT.get(f"/upload/{upload.id}/status") status = UploadStatus.model_validate(status_res.json()) if status.status == "failed": @@ -201,13 +202,21 @@ def index(source: Source, data_hashes: Table) -> UploadStatus: def get_source(address: SourceAddress) -> Source: warehouse_hash_b64 = hash_to_base64(address.warehouse_hash) - res = handle_http_code( - CLIENT.get(url(f"/sources/{warehouse_hash_b64}/{address.full_name}")) - ) + res = CLIENT.get(f"/sources/{warehouse_hash_b64}/{address.full_name}") + return Source.model_validate(res.json()) def get_resolution_graph() -> ResolutionGraph: """Get the resolution graph from Matchbox.""" - res = handle_http_code(CLIENT.get(url("/report/resolutions"))) + res = CLIENT.get("/report/resolutions") return ResolutionGraph.model_validate(res.json()) + + +# Model management + + +def insert_model(model: ModelMetadata) -> ModelOperationStatus: + """Insert a model in Matchbox.""" + res = CLIENT.post("/models", json=model.model_dump()) + return ModelOperationStatus.model_validate(res.json()) diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index 5a0b0ee5..d2c7bb73 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -3,14 +3,13 @@ import pyarrow as pa import pytest -import respx from dotenv import find_dotenv, load_dotenv from httpx import Response from pandas import DataFrame +from respx import MockRouter from sqlalchemy import Engine, create_engine from matchbox import index, match, process, query -from matchbox.client._handler import url from matchbox.client.clean import company_name, company_number from matchbox.client.helpers import cleaner, cleaners, 
comparison, select from matchbox.client.helpers.selector import Match, Selector @@ -46,9 +45,7 @@ def test_cleaners(): assert cleaner_name_number is not None -def test_process( - warehouse_data: list[Source], -): +def test_process(warehouse_data: list[Source]): crn = warehouse_data[0].to_arrow() cleaner_name = cleaner( @@ -77,8 +74,8 @@ def test_comparisons(): assert comparison_name_id is not None -@respx.mock -def test_select_mixed_style(warehouse_engine: Engine): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): """We can select select specific columns from some of the sources""" # Set up mocks and test data source1 = Source( @@ -91,11 +88,11 @@ def test_select_mixed_style(warehouse_engine: Engine): db_pk="pk", ) - respx.get( - url(f"/sources/{hash_to_base64(source1.address.warehouse_hash)}/test.foo") + respx_mock.get( + f"/sources/{hash_to_base64(source1.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source1.model_dump_json())) - respx.get( - url(f"/sources/{hash_to_base64(source2.address.warehouse_hash)}/test.bar") + respx_mock.get( + f"/sources/{hash_to_base64(source2.address.warehouse_hash)}/test.bar" ).mock(return_value=Response(200, content=source2.model_dump_json())) df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}]) @@ -129,15 +126,15 @@ def test_select_mixed_style(warehouse_engine: Engine): assert selection[1].source.engine == warehouse_engine -@respx.mock -def test_select_non_indexed_columns(warehouse_engine: Engine): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_select_non_indexed_columns(respx_mock: MockRouter, warehouse_engine: Engine): """Selecting columns not declared to backend generates warning.""" source = Source( address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"), db_pk="pk", ) - respx.get( - 
url(f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo") + respx_mock.get( + f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source.model_dump_json())) df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}]) @@ -154,16 +151,16 @@ def test_select_non_indexed_columns(warehouse_engine: Engine): select({"test.foo": ["a", "b"]}, engine=warehouse_engine) -@respx.mock -def test_select_missing_columns(warehouse_engine: Engine): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_select_missing_columns(respx_mock: MockRouter, warehouse_engine: Engine): """Selecting columns not in the warehouse errors.""" source = Source( address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"), db_pk="pk", ) - respx.get( - url(f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo") + respx_mock.get( + f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source.model_dump_json())) df = DataFrame([{"pk": 0, "a": 1, "b": "2"}, {"pk": 1, "a": 10, "b": "20"}]) @@ -205,12 +202,12 @@ def test_query_no_resolution_fail(): query(sels) -@respx.mock +@pytest.mark.respx(base_url="http://localhost:8000") @patch.object(Source, "to_arrow") -def test_query_no_resolution_ok_various_params(to_arrow: Mock): +def test_query_no_resolution_ok_various_params(to_arrow: Mock, respx_mock: MockRouter): """Tests that we can avoid passing resolution name, with a variety of parameters.""" # Mock API - query_route = respx.get(url("/query")).mock( + query_route = respx_mock.get("/query").mock( return_value=Response( 200, content=table_to_buffer( @@ -273,12 +270,12 @@ def test_query_no_resolution_ok_various_params(to_arrow: Mock): } -@respx.mock +@pytest.mark.respx(base_url="http://localhost:8000") @patch.object(Source, "to_arrow") -def test_query_multiple_sources_with_limits(to_arrow: Mock): +def 
test_query_multiple_sources_with_limits(to_arrow: Mock, respx_mock: MockRouter): """Tests that we can query multiple sources and distribute the limit among them.""" # Mock API - query_route = respx.get(url("/query")).mock( + query_route = respx_mock.get("/query").mock( side_effect=[ Response( 200, @@ -375,10 +372,10 @@ def test_query_multiple_sources_with_limits(to_arrow: Mock): query([sels[0]], [sels[1]], resolution_name="link", limit=7) -@respx.mock -def test_query_404_resolution(): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_query_404_resolution(respx_mock: MockRouter): # Mock API - respx.get(url("/query")).mock( + respx_mock.get("/query").mock( return_value=Response( 404, json=NotFoundError( @@ -407,10 +404,10 @@ def test_query_404_resolution(): query(sels) -@respx.mock -def test_query_404_source(): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_query_404_source(respx_mock: MockRouter): # Mock API - respx.get(url("/query")).mock( + respx_mock.get("/query").mock( return_value=Response( 404, json=NotFoundError( @@ -439,9 +436,9 @@ def test_query_404_source(): query(sels) -@respx.mock +@pytest.mark.respx(base_url="http://localhost:8000") @patch("matchbox.client.helpers.index.Source") -def test_index_success(MockSource: Mock): +def test_index_success(MockSource: Mock, respx_mock: MockRouter): """Test successful indexing flow through the API.""" engine = create_engine("sqlite:///:memory:") @@ -453,7 +450,7 @@ def test_index_success(MockSource: Mock): MockSource.return_value = mock_source_instance # Mock the initial source metadata upload - source_route = respx.post(url("/sources")).mock( + source_route = respx_mock.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -465,7 +462,7 @@ def test_index_success(MockSource: Mock): ) # Mock the data upload - upload_route = respx.post(url("/upload/test-upload-id")).mock( + upload_route = respx_mock.post("/upload/test-upload-id").mock( return_value=Response( 200, 
json=UploadStatus( @@ -491,7 +488,7 @@ def test_index_success(MockSource: Mock): assert b"PAR1" in upload_route.calls.last.request.content -@respx.mock +@pytest.mark.respx(base_url="http://localhost:8000") @patch("matchbox.client.helpers.index.Source") @pytest.mark.parametrize( "columns", @@ -508,7 +505,7 @@ def test_index_success(MockSource: Mock): ], ) def test_index_with_columns( - MockSource: Mock, columns: list[str] | list[dict[str, str]] + MockSource: Mock, respx_mock: MockRouter, columns: list[str] | list[dict[str, str]] ): """Test indexing with different column definition formats.""" engine = create_engine("sqlite:///:memory:") @@ -525,7 +522,7 @@ def test_index_with_columns( MockSource.return_value = mock_source_instance # Mock the API endpoints - source_route = respx.post(url("/sources")).mock( + source_route = respx_mock.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -536,7 +533,7 @@ def test_index_with_columns( ) ) - upload_route = respx.post(url("/upload/test-upload-id")).mock( + upload_route = respx_mock.post("/upload/test-upload-id").mock( return_value=Response( 200, json=UploadStatus( @@ -569,9 +566,9 @@ def test_index_with_columns( mock_source_instance.default_columns.assert_called_once() -@respx.mock +@pytest.mark.respx(base_url="http://localhost:8000") @patch("matchbox.client.helpers.index.Source") -def test_index_upload_failure(MockSource: Mock): +def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): """Test handling of upload failures.""" engine = create_engine("sqlite:///:memory:") @@ -583,7 +580,7 @@ def test_index_upload_failure(MockSource: Mock): MockSource.return_value = mock_source_instance # Mock successful source creation - source_route = respx.post(url("/sources")).mock( + source_route = respx_mock.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -595,7 +592,7 @@ def test_index_upload_failure(MockSource: Mock): ) # Mock failed upload - upload_route = 
respx.post(url("/upload/test-upload-id")).mock( + upload_route = respx_mock.post("/upload/test-upload-id").mock( return_value=Response( 400, json=UploadStatus( @@ -625,8 +622,8 @@ def test_index_upload_failure(MockSource: Mock): assert b"PAR1" in upload_route.calls.last.request.content -@respx.mock -def test_match_ok(): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_match_ok(respx_mock: MockRouter): """The client can perform the right call for matching.""" # Set up mocks mock_match1 = Match( @@ -648,7 +645,7 @@ def test_match_ok(): f"[{mock_match1.model_dump_json()}, {mock_match2.model_dump_json()}]" ) - match_route = respx.get(url("/match")).mock( + match_route = respx_mock.get("/match").mock( return_value=Response(200, content=serialised_matches) ) @@ -717,11 +714,11 @@ def test_match_ok(): ) -@respx.mock -def test_match_404_resolution(): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_match_404_resolution(respx_mock: MockRouter): """The client can handle a resolution not found error.""" # Set up mocks - respx.get(url("/match")).mock( + respx_mock.get("/match").mock( return_value=Response( 404, json=NotFoundError( @@ -766,11 +763,11 @@ def test_match_404_resolution(): ) -@respx.mock -def test_match_404_source(): +@pytest.mark.respx(base_url="http://localhost:8000") +def test_match_404_source(respx_mock: MockRouter): """The client can handle a source not found error.""" # Set up mocks - respx.get(url("/match")).mock( + respx_mock.get("/match").mock( return_value=Response( 404, json=NotFoundError( diff --git a/test/client/test_visualisation.py b/test/client/test_visualisation.py index a9f2076e..1891f8e1 100644 --- a/test/client/test_visualisation.py +++ b/test/client/test_visualisation.py @@ -1,14 +1,17 @@ -import respx +import pytest from httpx import Response from matplotlib.figure import Figure +from respx import MockRouter -from matchbox.client._handler import url from matchbox.client.visualisation import draw_resolution_graph 
+from matchbox.common.graph import ResolutionGraph -@respx.mock -def test_draw_resolution_graph(resolution_graph): - respx.get(url("/report/resolutions")).mock( +@pytest.mark.respx(base_url="http://localhost:8000") +def test_draw_resolution_graph( + respx_mock: MockRouter, resolution_graph: ResolutionGraph +): + respx_mock.get("/report/resolutions").mock( return_value=Response(200, content=resolution_graph.model_dump_json()), ) From cf2a36fc6937b79678a0554a5c377cbf53218b15 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 12:08:25 +0000 Subject: [PATCH 07/19] Added unit tests for all handler model functions and added some extra converage for routes --- src/matchbox/client/_handler.py | 87 ++++ src/matchbox/client/models/models.py | 57 ++- src/matchbox/common/exceptions.py | 22 +- src/matchbox/common/factories/models.py | 34 ++ src/matchbox/server/api/routes.py | 2 +- src/matchbox/server/postgresql/adapter.py | 4 +- test/client/test_model.py | 499 ++++++++++++++++++++++ test/client/test_visualisation.py | 4 +- test/server/api/test_routes.py | 75 +++- 9 files changed, 739 insertions(+), 45 deletions(-) create mode 100644 test/client/test_model.py diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index d4ca8d91..1801734c 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -10,6 +10,7 @@ from matchbox.common.arrow import SCHEMA_MB_IDS, table_to_buffer from matchbox.common.dtos import ( BackendRetrievableType, + ModelAncestor, ModelMetadata, ModelOperationStatus, NotFoundError, @@ -17,6 +18,7 @@ ) from matchbox.common.exceptions import ( MatchboxClientFileError, + MatchboxConfirmDelete, MatchboxResolutionNotFoundError, MatchboxServerFileError, MatchboxSourceNotFoundError, @@ -77,6 +79,10 @@ def handle_http_code(res: httpx.Response) -> httpx.Response: else: raise RuntimeError(f"Unexpected 404 error: {error.details}") + if res.status_code == 409: + error = 
ModelOperationStatus.model_validate(res.json()) + raise MatchboxConfirmDelete(message=error.details) + if res.status_code == 422: raise MatchboxUnparsedClientRequest(res.content) @@ -220,3 +226,84 @@ def insert_model(model: ModelMetadata) -> ModelOperationStatus: """Insert a model in Matchbox.""" res = CLIENT.post("/models", json=model.model_dump()) return ModelOperationStatus.model_validate(res.json()) + + +def get_model(name: str) -> ModelMetadata: + res = CLIENT.get(f"/models/{name}") + return ModelMetadata.model_validate(res.json()) + + +def add_model_results(name: str, results: Table) -> UploadStatus: + """Upload model results in Matchbox.""" + buffer = table_to_buffer(table=results) + + # Initialise upload + metadata_res = CLIENT.post(f"/models/{name}/results") + + upload = UploadStatus.model_validate(metadata_res.json()) + + # Upload data + upload_res = CLIENT.post( + f"/upload/{upload.id}", + files={"file": (f"{upload.id}.parquet", buffer, "application/octet-stream")}, + ) + + # Poll until complete with retry/timeout configuration + status = UploadStatus.model_validate(upload_res.json()) + while status.status not in ["complete", "failed"]: + status_res = CLIENT.get(f"/upload/{upload.id}/status") + status = UploadStatus.model_validate(status_res.json()) + + if status.status == "failed": + raise MatchboxServerFileError(status.details) + + time.sleep(2) + + return status + + +def get_model_results(name: str) -> Table: + """Get model results from Matchbox.""" + res = CLIENT.get(f"/models/{name}/results") + buffer = BytesIO(res.content) + return read_table(buffer) + + +def set_model_truth(name: str, truth: float) -> ModelOperationStatus: + """Set the truth threshold for a model in Matchbox.""" + res = CLIENT.post(f"/models/{name}/truth", json=truth) + return ModelOperationStatus.model_validate(res.json()) + + +def get_model_truth(name: str) -> float: + """Get the truth threshold for a model in Matchbox.""" + res = CLIENT.get(f"/models/{name}/truth") + return 
res.json() + + +def get_model_ancestors(name: str) -> list[ModelAncestor]: + """Get the ancestors of a model in Matchbox.""" + res = CLIENT.get(f"/models/{name}/ancestors") + return [ModelAncestor.model_validate(m) for m in res.json()] + + +def set_model_ancestors_cache( + name: str, ancestors: list[ModelAncestor] +) -> ModelOperationStatus: + """Set the ancestors cache for a model in Matchbox.""" + res = CLIENT.post( + f"/models/{name}/ancestors_cache", json=[a.model_dump() for a in ancestors] + ) + return ModelOperationStatus.model_validate(res.json()) + + +def get_model_ancestors_cache(name: str) -> list[ModelAncestor]: + """Get the ancestors cache for a model in Matchbox.""" + res = CLIENT.get(f"/models/{name}/ancestors_cache") + return [ModelAncestor.model_validate(m) for m in res.json()] + + +def delete_model(name: str, certain: bool | None = None) -> ModelOperationStatus: + """Delete a model in Matchbox.""" + res = CLIENT.delete(f"/models/{name}", params={"certain": certain}) + return ModelOperationStatus.model_validate(res.json()) diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py index f192b5b6..3dd3a782 100644 --- a/src/matchbox/client/models/models.py +++ b/src/matchbox/client/models/models.py @@ -2,12 +2,12 @@ from pandas import DataFrame +from matchbox.client import _handler from matchbox.client.models.dedupers.base import Deduper from matchbox.client.models.linkers.base import Linker from matchbox.client.results import Results from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import MatchboxResolutionNotFoundError -from matchbox.server import MatchboxDBAdapter, inject_backend P = ParamSpec("P") R = TypeVar("R") @@ -28,69 +28,64 @@ def __init__( self.left_data = left_data self.right_data = right_data - @inject_backend - def insert_model(self, backend: MatchboxDBAdapter) -> None: + def insert_model(self) -> None: """Insert the model into the backend database.""" 
- backend.insert_model(model=self.metadata) + _handler.insert_model(model=self.metadata) @property - @inject_backend - def results(self, backend: MatchboxDBAdapter) -> Results: + def results(self) -> Results: """Retrieve results associated with the model from the database.""" - results = backend.get_model_results(model=self.metadata.name) + results = _handler.get_model_results(name=self.metadata.name) return Results(probabilities=results, metadata=self.metadata) @results.setter - @inject_backend - def results(self, backend: MatchboxDBAdapter, results: Results) -> None: + def results(self, results: Results) -> None: """Write results associated with the model to the database.""" - backend.set_model_results( - model=self.metadata.name, results=results.probabilities + _handler.add_model_results( + name=self.metadata.name, results=results.probabilities ) @property - @inject_backend - def truth(self, backend: MatchboxDBAdapter) -> float: + def truth(self) -> float: """Retrieve the truth threshold for the model.""" - return backend.get_model_truth(model=self.metadata.name) + return _handler.get_model_truth(name=self.metadata.name) @truth.setter - @inject_backend - def truth(self, backend: MatchboxDBAdapter, truth: float) -> None: + def truth(self, truth: float) -> None: """Set the truth threshold for the model.""" - backend.set_model_truth(model=self.metadata.name, truth=truth) + _handler.set_model_truth(name=self.metadata.name, truth=truth) @property - @inject_backend - def ancestors(self, backend: MatchboxDBAdapter) -> dict[str, float]: + def ancestors(self) -> dict[str, float]: """Retrieve the ancestors of the model.""" return { ancestor.name: ancestor.truth - for ancestor in backend.get_model_ancestors(model=self.metadata.name) + for ancestor in _handler.get_model_ancestors(name=self.metadata.name) } @property - @inject_backend - def ancestors_cache(self, backend: MatchboxDBAdapter) -> dict[str, float]: + def ancestors_cache(self) -> dict[str, float]: """Retrieve the 
ancestors cache of the model.""" return { ancestor.name: ancestor.truth - for ancestor in backend.get_model_ancestors_cache(model=self.metadata.name) + for ancestor in _handler.get_model_ancestors_cache(name=self.metadata.name) } @ancestors_cache.setter - @inject_backend - def ancestors_cache( - self, backend: MatchboxDBAdapter, ancestors_cache: dict[str, float] - ) -> None: + def ancestors_cache(self, ancestors_cache: dict[str, float]) -> None: """Set the ancestors cache of the model.""" - backend.set_model_ancestors_cache( - model=self.metadata.name, - ancestors_cache=[ - ModelAncestor(name=k, truth=v) for k, v in ancestors_cache + _handler.set_model_ancestors_cache( + name=self.metadata.name, + ancestors=[ + ModelAncestor(name=k, truth=v) for k, v in ancestors_cache.items() ], ) + def delete(self, certain: bool = False) -> bool: + """Delete the model from the database.""" + result = _handler.delete_model(name=self.metadata.name, certain=certain) + return result.success + def run(self) -> Results: """Execute the model pipeline and return results.""" if self.metadata.type == ModelType.LINKER: diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py index 0e7c7f5d..0558c591 100644 --- a/src/matchbox/common/exceptions.py +++ b/src/matchbox/common/exceptions.py @@ -141,14 +141,18 @@ class MatchboxConnectionError(Exception): class MatchboxConfirmDelete(Exception): """Deletion must be confirmed: if certain, rerun with certain=True.""" - def __init__(self, children: list[str]): - children_names = ", ".join(children) - message = ( - f"This operation will delete the resolutions {children_names}, " - "as well as all probabilities they have created. \n\n" - "It won't delete validation associated with these " - "probabilities. \n\n" - "If you're sure you want to continue, rerun with certain=True. 
" - ) + def __init__(self, message: str | None = None, children: list[str] | None = None): + if message is None: + message = "Deletion must be confirmed: if certain, rerun with certain=True." + + if children is not None: + children_names = ", ".join(children) + message = ( + f"This operation will delete the resolutions {children_names}, " + "as well as all probabilities they have created. \n\n" + "It won't delete validation associated with these " + "probabilities. \n\n" + "If you're sure you want to continue, rerun with certain=True. " + ) super().__init__(message) diff --git a/src/matchbox/common/factories/models.py b/src/matchbox/common/factories/models.py index 3a51ec52..41c0c5f0 100644 --- a/src/matchbox/common/factories/models.py +++ b/src/matchbox/common/factories/models.py @@ -1,13 +1,19 @@ from collections import Counter from textwrap import dedent from typing import Any, Literal +from unittest.mock import Mock, PropertyMock, create_autospec import numpy as np import pyarrow as pa import rustworkx as rx from faker import Faker +from pandas import DataFrame from pydantic import BaseModel, ConfigDict +from matchbox.client.models.dedupers.base import Deduper +from matchbox.client.models.linkers.base import Linker +from matchbox.client.models.models import Model +from matchbox.client.results import Results from matchbox.common.dtos import ModelMetadata, ModelType from matchbox.common.transform import graph_results @@ -315,6 +321,34 @@ class ModelDummy(BaseModel): data: pa.Table metrics: ModelMetrics + def to_mock(self) -> Mock: + """Create a mock Model object that mimics this dummy model's behavior.""" + mock_model = create_autospec(Model) + + # Set basic attributes + mock_model.metadata = self.model + mock_model.left_data = DataFrame() # Default empty DataFrame + mock_model.right_data = ( + DataFrame() if self.model.type == ModelType.LINKER else None + ) + + # Mock results property + mock_results = Results(probabilities=self.data, metadata=self.model) + 
type(mock_model).results = PropertyMock(return_value=mock_results)
+
+        # Mock run method
+        mock_model.run.return_value = mock_results
+
+        # Mock the model instance based on type
+        if self.model.type == ModelType.LINKER:
+            mock_model.model_instance = create_autospec(Linker)
+            mock_model.model_instance.link.return_value = self.data
+        else:
+            mock_model.model_instance = create_autospec(Deduper)
+            mock_model.model_instance.dedupe.return_value = self.data
+
+        return mock_model
+

 def model_factory(
     name: str | None = None,
diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py
index a1b0339a..d5068d8b 100644
--- a/src/matchbox/server/api/routes.py
+++ b/src/matchbox/server/api/routes.py
@@ -469,7 +469,7 @@ async def get_results(
 async def set_truth(
     backend: Annotated[MatchboxDBAdapter, Depends(get_backend)],
     name: str,
-    truth: Annotated[float, Body()],
+    truth: Annotated[float, Body(ge=0.0, le=1.0)],
 ) -> ModelOperationStatus:
     """Set truth data for a model."""
     try:
diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py
index 2d8a7f27..e6cc83cc 100644
--- a/src/matchbox/server/postgresql/adapter.py
+++ b/src/matchbox/server/postgresql/adapter.py
@@ -532,5 +532,5 @@ def delete_model(self, model: str, certain: bool = False) -> None:
                 session.delete(resolution)
                 session.commit()
             else:
-                childen = [r.name for r in resolution.descendants]
-                raise MatchboxConfirmDelete(childen)
+                children = [r.name for r in resolution.descendants]
+                raise MatchboxConfirmDelete(children=children)
diff --git a/test/client/test_model.py b/test/client/test_model.py
new file mode 100644
index 00000000..d7222a67
--- /dev/null
+++ b/test/client/test_model.py
@@ -0,0 +1,499 @@
+import json
+from os import getenv
+from unittest.mock import Mock
+
+import pytest
+from httpx import Response
+from pandas import DataFrame
+from respx.router import MockRouter
+
+from matchbox.client.models.models import Model
+from matchbox.client.results
import Results +from matchbox.common.arrow import SCHEMA_RESULTS, table_to_buffer +from matchbox.common.dtos import ( + BackendRetrievableType, + BackendUploadType, + ModelAncestor, + ModelOperationStatus, + ModelOperationType, + NotFoundError, + UploadStatus, +) +from matchbox.common.exceptions import ( + MatchboxConfirmDelete, + MatchboxResolutionNotFoundError, + MatchboxServerFileError, + MatchboxUnhandledServerResponse, + MatchboxUnparsedClientRequest, +) +from matchbox.common.factories.models import model_factory + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_insert_model(respx_mock: MockRouter): + """Test inserting a model via the API.""" + # Create test model using factory + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the POST /models endpoint + route = respx_mock.post("/models").mock( + return_value=Response( + 200, + json=ModelOperationStatus( + success=True, + model_name=dummy.model.name, + operation=ModelOperationType.INSERT, + ).model_dump(), + ) + ) + + # Call insert_model + model.insert_model() + + # Verify the API call + assert route.called + assert route.calls.last.request.content.decode() == dummy.model.model_dump_json() + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_insert_model_error(respx_mock: MockRouter): + """Test handling of model insertion errors.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the POST /models endpoint with an error response + route = respx_mock.post("/models").mock( + return_value=Response( + 500, + json=ModelOperationStatus( + success=False, + model_name=dummy.model.name, + operation=ModelOperationType.INSERT, + details="Internal server error", + ).model_dump(), + ) + ) + + # Call insert_model and verify it 
raises an exception + with pytest.raises(MatchboxUnhandledServerResponse, match="Internal server error"): + model.insert_model() + + assert route.called + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_results_getter(respx_mock: MockRouter): + """Test getting model results via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the GET /models/{name}/results endpoint + route = respx_mock.get(f"/models/{dummy.model.name}/results").mock( + return_value=Response(200, content=table_to_buffer(dummy.data).read()) + ) + + # Get results + results = model.results + + # Verify the API call + assert route.called + assert isinstance(results, Results) + assert results.probabilities.schema.equals(SCHEMA_RESULTS) + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_results_getter_not_found(respx_mock: MockRouter): + """Test getting model results when they don't exist.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the GET endpoint with a 404 response + route = respx_mock.get(f"/models/{dummy.model.name}/results").mock( + return_value=Response( + 404, + json=NotFoundError( + details="Results not found", entity=BackendRetrievableType.RESOLUTION + ).model_dump(), + ) + ) + + # Verify that accessing results raises an exception + with pytest.raises(MatchboxResolutionNotFoundError, match="Results not found"): + _ = model.results + + assert route.called + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_results_setter(respx_mock: MockRouter): + """Test setting model results via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + 
) + + # Mock the endpoints needed for results upload + init_route = respx_mock.post(f"/models/{dummy.model.name}/results").mock( + return_value=Response( + 200, + json=UploadStatus( + id="test-upload-id", + status="awaiting_upload", + entity=BackendUploadType.RESULTS, + ).model_dump(), + ) + ) + + upload_route = respx_mock.post("/upload/test-upload-id").mock( + return_value=Response( + 200, + json=UploadStatus( + id="test-upload-id", + status="processing", + entity=BackendUploadType.RESULTS, + ).model_dump(), + ) + ) + + status_route = respx_mock.get("/upload/test-upload-id/status").mock( + return_value=Response( + 200, + json=UploadStatus( + id="test-upload-id", + status="complete", + entity=BackendUploadType.RESULTS, + ).model_dump(), + ) + ) + + # Set results + test_results = Results(probabilities=dummy.data, metadata=dummy.model) + model.results = test_results + + # Verify API calls + assert init_route.called + assert upload_route.called + assert status_route.called + assert ( + b"PAR1" in upload_route.calls.last.request.content + ) # Check for parquet file signature + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_results_setter_upload_failure(respx_mock: MockRouter): + """Test handling of upload failures when setting results.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the initial POST endpoint + init_route = respx_mock.post(f"/models/{dummy.model.name}/results").mock( + return_value=Response( + 200, + json=UploadStatus( + id="test-upload-id", + status="awaiting_upload", + entity=BackendUploadType.RESULTS, + ).model_dump(), + ) + ) + + # Mock the upload endpoint with a failure + upload_route = respx_mock.post("/upload/test-upload-id").mock( + return_value=Response( + 400, + json=UploadStatus( + id="test-upload-id", + status="failed", + entity=BackendUploadType.RESULTS, + details="Invalid data 
format", + ).model_dump(), + ) + ) + + # Attempt to set results and verify it raises an exception + test_results = Results(probabilities=dummy.data, metadata=dummy.model) + with pytest.raises(MatchboxServerFileError, match="Invalid data format"): + model.results = test_results + + assert init_route.called + assert upload_route.called + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_truth_getter(respx_mock: MockRouter): + """Test getting model truth threshold via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the GET /models/{name}/truth endpoint + route = respx_mock.get(f"/models/{dummy.model.name}/truth").mock( + return_value=Response(200, json=0.9) + ) + + # Get truth + truth = model.truth + + # Verify the API call + assert route.called + assert truth == 0.9 + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_truth_setter(respx_mock: MockRouter): + """Test setting model truth threshold via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the POST /models/{name}/truth endpoint + route = respx_mock.post(f"/models/{dummy.model.name}/truth").mock( + return_value=Response( + 200, + json=ModelOperationStatus( + success=True, + model_name=dummy.model.name, + operation=ModelOperationType.UPDATE_TRUTH, + ).model_dump(), + ) + ) + + # Set truth + model.truth = 0.9 + + # Verify the API call + assert route.called + assert float(route.calls.last.request.read()) == 0.9 + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_truth_setter_validation_error(respx_mock: MockRouter): + """Test setting invalid truth values.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + 
left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the POST endpoint with a validation error + route = respx_mock.post(f"/models/{dummy.model.name}/truth").mock( + return_value=Response(422) + ) + + # Attempt to set an invalid truth value + with pytest.raises(MatchboxUnparsedClientRequest): + model.truth = 1.5 + + assert route.called + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_ancestors_getter(respx_mock: MockRouter): + """Test getting model ancestors via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + ancestors_data = [ + ModelAncestor(name="model1", truth=0.9).model_dump(), + ModelAncestor(name="model2", truth=0.8).model_dump(), + ] + + # Mock the GET /models/{name}/ancestors endpoint + route = respx_mock.get(f"/models/{dummy.model.name}/ancestors").mock( + return_value=Response(200, json=ancestors_data) + ) + + # Get ancestors + ancestors = model.ancestors + + # Verify the API call + assert route.called + assert ancestors == {"model1": 0.9, "model2": 0.8} + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_ancestors_cache_operations(respx_mock: MockRouter): + """Test getting and setting model ancestors cache via the API.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the GET endpoint + get_route = respx_mock.get(f"/models/{dummy.model.name}/ancestors_cache").mock( + return_value=Response( + 200, json=[ModelAncestor(name="model1", truth=0.9).model_dump()] + ) + ) + + # Mock the POST endpoint + set_route = respx_mock.post(f"/models/{dummy.model.name}/ancestors_cache").mock( + return_value=Response( + 200, + json=ModelOperationStatus( + success=True, + model_name=dummy.model.name, + operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, + 
).model_dump(), + ) + ) + + # Get ancestors cache + cache = model.ancestors_cache + assert get_route.called + assert cache == {"model1": 0.9} + + # Set ancestors cache + model.ancestors_cache = {"model2": 0.8} + assert set_route.called + assert json.loads(set_route.calls.last.request.content.decode()) == [ + ModelAncestor(name="model2", truth=0.8).model_dump() + ] + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_ancestors_cache_set_error(respx_mock: MockRouter): + """Test error handling when setting ancestors cache.""" + dummy = model_factory(model_type="linker") + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the POST endpoint with an error + route = respx_mock.post(f"/models/{dummy.model.name}/ancestors_cache").mock( + return_value=Response( + 500, + json=ModelOperationStatus( + success=False, + model_name=dummy.model.name, + operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, + details="Database error", + ).model_dump(), + ) + ) + + # Attempt to set ancestors cache + with pytest.raises(MatchboxUnhandledServerResponse, match="Database error"): + model.ancestors_cache = {"model1": 0.9} + + assert route.called + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_delete_model(respx_mock: MockRouter): + """Test successfully deleting a model.""" + # Create test model using factory + dummy = model_factory() + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the DELETE endpoint with success response + route = respx_mock.delete( + f"/models/{dummy.model.name}", params={"certain": True} + ).mock( + return_value=Response( + 200, + json=ModelOperationStatus( + success=True, + model_name=dummy.model.name, + operation=ModelOperationType.DELETE, + ).model_dump(), + ) + ) + + # Delete the model + response = model.delete(certain=True) + + # Verify the response and API 
call + assert response + assert route.called + assert route.calls.last.request.url.params["certain"] == "true" + + +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) +def test_delete_model_needs_confirmation(respx_mock: MockRouter): + """Test attempting to delete a model without confirmation returns 409.""" + # Create test model using factory + dummy = model_factory() + model = Model( + metadata=dummy.model, + model_instance=Mock(), + left_data=DataFrame(), + right_data=DataFrame(), + ) + + # Mock the DELETE endpoint with 409 confirmation required response + error_details = "Cannot delete model with dependent models: dedupe1, dedupe2" + route = respx_mock.delete(f"/models/{dummy.model.name}").mock( + return_value=Response( + 409, + json=ModelOperationStatus( + success=False, + model_name=dummy.model.name, + operation=ModelOperationType.DELETE, + details=error_details, + ).model_dump(), + ) + ) + + # Attempt to delete without certain=True + with pytest.raises(MatchboxConfirmDelete): + model.delete() + + # Verify the response and API call + assert route.called + assert route.calls.last.request.url.params["certain"] == "false" diff --git a/test/client/test_visualisation.py b/test/client/test_visualisation.py index 1891f8e1..f9dc511c 100644 --- a/test/client/test_visualisation.py +++ b/test/client/test_visualisation.py @@ -1,3 +1,5 @@ +from os import getenv + import pytest from httpx import Response from matplotlib.figure import Figure @@ -7,7 +9,7 @@ from matchbox.common.graph import ResolutionGraph -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_draw_resolution_graph( respx_mock: MockRouter, resolution_graph: ResolutionGraph ): diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index fb8bc539..31318407 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -641,7 +641,7 @@ def test_delete_model(get_backend: Mock): def 
test_delete_model_needs_confirmation(get_backend: Mock): mock_backend = Mock() mock_backend.delete_model = Mock( - side_effect=MatchboxConfirmDelete(["dedupe1", "dedupe2"]) + side_effect=MatchboxConfirmDelete(children=["dedupe1", "dedupe2"]) ) get_backend.return_value = mock_backend @@ -680,6 +680,19 @@ def test_set_results(get_backend: Mock): assert response.json()["status"] == "awaiting_upload" +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_results_model_not_found(get_backend: Mock): + """Test setting results for a non-existent model.""" + mock_backend = Mock() + mock_backend.get_model = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend + + response = client.post("/models/nonexistent-model/results") + + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.RESOLUTION + + @patch("matchbox.server.base.BackendManager.get_backend") def test_get_truth(get_backend: Mock): mock_backend = Mock() @@ -708,6 +721,22 @@ def test_set_truth(get_backend: Mock): ) +@patch("matchbox.server.base.BackendManager.get_backend") +def test_set_truth_invalid_value(get_backend: Mock): + """Test setting an invalid truth value (outside 0-1 range).""" + mock_backend = Mock() + dummy_model = model_factory() + get_backend.return_value = mock_backend + + # Test value > 1 + response = client.patch(f"/models/{dummy_model.model.name}/truth", json=1.5) + assert response.status_code == 422 + + # Test value < 0 + response = client.patch(f"/models/{dummy_model.model.name}/truth", json=-0.5) + assert response.status_code == 422 + + @patch("matchbox.server.base.BackendManager.get_backend") def test_get_ancestors(get_backend: Mock): mock_backend = Mock() @@ -726,6 +755,50 @@ def test_get_ancestors(get_backend: Mock): assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors +@patch("matchbox.server.base.BackendManager.get_backend") +def 
test_set_ancestors_cache(get_backend: Mock): + """Test setting the ancestors cache for a model.""" + mock_backend = Mock() + dummy_model = model_factory() + get_backend.return_value = mock_backend + + ancestors_data = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.8), + ] + + response = client.patch( + f"/models/{dummy_model.model.name}/ancestors_cache", + json=[a.model_dump() for a in ancestors_data], + ) + + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["operation"] == ModelOperationType.UPDATE_ANCESTOR_CACHE + mock_backend.set_model_ancestors_cache.assert_called_once_with( + model=dummy_model.model.name, ancestors_cache=ancestors_data + ) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_ancestors_cache(get_backend: Mock): + """Test retrieving the ancestors cache for a model.""" + mock_backend = Mock() + dummy_model = model_factory() + mock_ancestors = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.8), + ] + mock_backend.get_model_ancestors_cache = Mock(return_value=mock_ancestors) + get_backend.return_value = mock_backend + + response = client.get(f"/models/{dummy_model.model.name}/ancestors_cache") + + assert response.status_code == 200 + assert len(response.json()) == 2 + assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors + + # Query and match endpoints From 7b44292681851c072fc6aa3591de1e6b8027ecc8 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 12:50:18 +0000 Subject: [PATCH 08/19] Removed inject_backend from all client-facing functions --- src/matchbox/client/_handler.py | 2 +- src/matchbox/client/models/models.py | 7 +++--- src/matchbox/client/results.py | 4 +--- src/matchbox/server/__init__.py | 3 +-- src/matchbox/server/base.py | 33 ---------------------------- src/matchbox/server/postgresql/db.py | 2 ++ 
test/client/test_dedupers.py | 2 +- test/client/test_linkers.py | 2 +- test/client/test_model.py | 8 +++---- test/fixtures/db.py | 4 ++-- 10 files changed, 17 insertions(+), 50 deletions(-) diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index 1801734c..e57054f3 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -271,7 +271,7 @@ def get_model_results(name: str) -> Table: def set_model_truth(name: str, truth: float) -> ModelOperationStatus: """Set the truth threshold for a model in Matchbox.""" - res = CLIENT.post(f"/models/{name}/truth", json=truth) + res = CLIENT.patch(f"/models/{name}/truth", json=truth) return ModelOperationStatus.model_validate(res.json()) diff --git a/src/matchbox/client/models/models.py b/src/matchbox/client/models/models.py index 3dd3a782..48312f50 100644 --- a/src/matchbox/client/models/models.py +++ b/src/matchbox/client/models/models.py @@ -41,9 +41,10 @@ def results(self) -> Results: @results.setter def results(self, results: Results) -> None: """Write results associated with the model to the database.""" - _handler.add_model_results( - name=self.metadata.name, results=results.probabilities - ) + if results.probabilities.shape[0] > 0: + _handler.add_model_results( + name=self.metadata.name, results=results.probabilities + ) @property def truth(self) -> float: diff --git a/src/matchbox/client/results.py b/src/matchbox/client/results.py index 776e64d8..c98c484e 100644 --- a/src/matchbox/client/results.py +++ b/src/matchbox/client/results.py @@ -11,7 +11,6 @@ from matchbox.common.dtos import ModelMetadata from matchbox.common.hash import IntMap from matchbox.common.transform import to_clusters -from matchbox.server.base import MatchboxDBAdapter, inject_backend if TYPE_CHECKING: from matchbox.client.models.models import Model @@ -210,8 +209,7 @@ def inspect_clusters( right_merge_col="child", ) - @inject_backend - def to_matchbox(self, backend: MatchboxDBAdapter) -> None: + def 
to_matchbox(self) -> None: """Writes the results to the Matchbox database.""" self.model.insert_model() self.model.results = self diff --git a/src/matchbox/server/__init__.py b/src/matchbox/server/__init__.py index 88bf6d0f..3b208ccb 100644 --- a/src/matchbox/server/__init__.py +++ b/src/matchbox/server/__init__.py @@ -3,9 +3,8 @@ MatchboxDBAdapter, MatchboxSettings, initialise_matchbox, - inject_backend, ) -__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "inject_backend"] +__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "initialise_matchbox"] initialise_matchbox() diff --git a/src/matchbox/server/base.py b/src/matchbox/server/base.py index d7e3b553..d3ae7e4c 100644 --- a/src/matchbox/server/base.py +++ b/src/matchbox/server/base.py @@ -1,16 +1,12 @@ -import inspect from abc import ABC, abstractmethod from enum import StrEnum -from functools import wraps from pathlib import Path from typing import ( TYPE_CHECKING, Any, - Callable, ParamSpec, Protocol, TypeVar, - cast, ) import boto3 @@ -173,35 +169,6 @@ def initialise_matchbox() -> None: initialise_backend(settings) -def inject_backend(func: Callable[..., R]) -> Callable[..., R]: - """Decorator to inject the Matchbox backend into functions. - - Used to allow user-facing functions to access the backend without needing to - pass it in. The backend is defined by the MB__BACKEND_TYPE environment variable. - - Can be used for both functions and methods. - - If the user specifies a backend, it will be used instead of the injection. 
- """ - - @wraps(func) - def _inject_backend( - *args: P.args, backend: "MatchboxDBAdapter | None" = None, **kwargs: P.kwargs - ) -> R: - if backend is None: - backend = BackendManager.get_backend() - - sig = inspect.signature(func) - params = list(sig.parameters.values()) - - if params and params[0].name in ("self", "cls"): - return cast(R, func(args[0], backend, *args[1:], **kwargs)) - else: - return cast(R, func(backend, *args, **kwargs)) - - return cast(Callable[..., R], _inject_backend) - - class Countable(Protocol): """A protocol for objects that can be counted.""" diff --git a/src/matchbox/server/postgresql/db.py b/src/matchbox/server/postgresql/db.py index 043403fb..4e1ae7c0 100644 --- a/src/matchbox/server/postgresql/db.py +++ b/src/matchbox/server/postgresql/db.py @@ -93,6 +93,8 @@ def clear_database(self): ) conn.commit() + self.engine.dispose() + self.create_database() diff --git a/test/client/test_dedupers.py b/test/client/test_dedupers.py index 2046e580..e283944f 100644 --- a/test/client/test_dedupers.py +++ b/test/client/test_dedupers.py @@ -108,7 +108,7 @@ def test_dedupers( # 4. Probabilities and clusters are inserted correctly - results.to_matchbox(backend=matchbox_postgres) + results.to_matchbox() retrieved_results = matchbox_postgres.get_model_results(model=deduper_name) assert retrieved_results.shape[0] == fx_data.tgt_prob_n diff --git a/test/client/test_linkers.py b/test/client/test_linkers.py index 914b9a3e..9204207a 100644 --- a/test/client/test_linkers.py +++ b/test/client/test_linkers.py @@ -174,7 +174,7 @@ def unique_non_null(s): # 4. 
Probabilities and clusters are inserted correctly - results.to_matchbox(backend=matchbox_postgres) + results.to_matchbox() retrieved_results = matchbox_postgres.get_model_results(model=linker_name) assert retrieved_results.shape[0] == fx_data.tgt_prob_n diff --git a/test/client/test_model.py b/test/client/test_model.py index d7222a67..299967ab 100644 --- a/test/client/test_model.py +++ b/test/client/test_model.py @@ -283,8 +283,8 @@ def test_truth_setter(respx_mock: MockRouter): right_data=DataFrame(), ) - # Mock the POST /models/{name}/truth endpoint - route = respx_mock.post(f"/models/{dummy.model.name}/truth").mock( + # Mock the PATCH /models/{name}/truth endpoint + route = respx_mock.patch(f"/models/{dummy.model.name}/truth").mock( return_value=Response( 200, json=ModelOperationStatus( @@ -314,8 +314,8 @@ def test_truth_setter_validation_error(respx_mock: MockRouter): right_data=DataFrame(), ) - # Mock the POST endpoint with a validation error - route = respx_mock.post(f"/models/{dummy.model.name}/truth").mock( + # Mock the PATCH endpoint with a validation error + route = respx_mock.patch(f"/models/{dummy.model.name}/truth").mock( return_value=Response(422) ) diff --git a/test/fixtures/db.py b/test/fixtures/db.py index 65a951fd..b4a23faf 100644 --- a/test/fixtures/db.py +++ b/test/fixtures/db.py @@ -103,7 +103,7 @@ def _db_add_dedupe_models_and_data( ) results = model.run() - results.to_matchbox(backend=backend) + results.to_matchbox() model.truth = 0.0 return _db_add_dedupe_models_and_data @@ -173,7 +173,7 @@ def _db_add_link_models_and_data( ) results = model.run() - results.to_matchbox(backend=backend) + results.to_matchbox() model.truth = 0.0 return _db_add_link_models_and_data From 898fd2c428e6217391239a8d98450247fd9d4bd5 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 13:04:28 +0000 Subject: [PATCH 09/19] Updated helper unit tests to use the environment variable --- src/matchbox/server/__init__.py | 2 +- test/client/test_helpers.py | 
27 ++++++++++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/src/matchbox/server/__init__.py b/src/matchbox/server/__init__.py index 3b208ccb..51b24618 100644 --- a/src/matchbox/server/__init__.py +++ b/src/matchbox/server/__init__.py @@ -5,6 +5,6 @@ initialise_matchbox, ) -__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings", "initialise_matchbox"] +__all__ = ["app", "MatchboxDBAdapter", "MatchboxSettings"] initialise_matchbox() diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index d2c7bb73..4abcfc2c 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -1,4 +1,5 @@ import logging +from os import getenv from unittest.mock import Mock, patch import pyarrow as pa @@ -74,7 +75,7 @@ def test_comparisons(): assert comparison_name_id is not None -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): """We can select select specific columns from some of the sources""" # Set up mocks and test data @@ -126,7 +127,7 @@ def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): assert selection[1].source.engine == warehouse_engine -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_select_non_indexed_columns(respx_mock: MockRouter, warehouse_engine: Engine): """Selecting columns not declared to backend generates warning.""" source = Source( @@ -151,7 +152,7 @@ def test_select_non_indexed_columns(respx_mock: MockRouter, warehouse_engine: En select({"test.foo": ["a", "b"]}, engine=warehouse_engine) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_select_missing_columns(respx_mock: MockRouter, warehouse_engine: Engine): """Selecting columns not in the warehouse errors.""" source = Source( @@ 
-202,7 +203,7 @@ def test_query_no_resolution_fail(): query(sels) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch.object(Source, "to_arrow") def test_query_no_resolution_ok_various_params(to_arrow: Mock, respx_mock: MockRouter): """Tests that we can avoid passing resolution name, with a variety of parameters.""" @@ -270,7 +271,7 @@ def test_query_no_resolution_ok_various_params(to_arrow: Mock, respx_mock: MockR } -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch.object(Source, "to_arrow") def test_query_multiple_sources_with_limits(to_arrow: Mock, respx_mock: MockRouter): """Tests that we can query multiple sources and distribute the limit among them.""" @@ -372,7 +373,7 @@ def test_query_multiple_sources_with_limits(to_arrow: Mock, respx_mock: MockRout query([sels[0]], [sels[1]], resolution_name="link", limit=7) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_query_404_resolution(respx_mock: MockRouter): # Mock API respx_mock.get("/query").mock( @@ -404,7 +405,7 @@ def test_query_404_resolution(respx_mock: MockRouter): query(sels) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_query_404_source(respx_mock: MockRouter): # Mock API respx_mock.get("/query").mock( @@ -436,7 +437,7 @@ def test_query_404_source(respx_mock: MockRouter): query(sels) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") def test_index_success(MockSource: Mock, respx_mock: MockRouter): """Test successful indexing flow through the API.""" @@ -488,7 +489,7 @@ def test_index_success(MockSource: Mock, respx_mock: MockRouter): assert b"PAR1" in upload_route.calls.last.request.content 
-@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") @pytest.mark.parametrize( "columns", @@ -566,7 +567,7 @@ def test_index_with_columns( mock_source_instance.default_columns.assert_called_once() -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): """Test handling of upload failures.""" @@ -622,7 +623,7 @@ def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): assert b"PAR1" in upload_route.calls.last.request.content -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_match_ok(respx_mock: MockRouter): """The client can perform the right call for matching.""" # Set up mocks @@ -714,7 +715,7 @@ def test_match_ok(respx_mock: MockRouter): ) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_match_404_resolution(respx_mock: MockRouter): """The client can handle a resolution not found error.""" # Set up mocks @@ -763,7 +764,7 @@ def test_match_404_resolution(respx_mock: MockRouter): ) -@pytest.mark.respx(base_url="http://localhost:8000") +@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_match_404_source(respx_mock: MockRouter): """The client can handle a source not found error.""" # Set up mocks From c69e4bbac254962fb0e55b0e8ada3287555007e3 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 16:07:47 +0000 Subject: [PATCH 10/19] Removed add 501 not implemented endpoints --- src/matchbox/server/api/routes.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index d5068d8b..22d25231 100644 --- 
a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -307,11 +307,6 @@ async def match( # Data management -@app.get("/sources") -async def list_sources(): - raise HTTPException(status_code=501, detail="Not implemented") - - @app.post("/sources") async def add_source(source: Source) -> UploadStatus: """Create an upload and insert task for indexed source data.""" @@ -348,19 +343,9 @@ async def get_resolutions( return backend.get_resolution_graph() -@app.post("/testing/clear") -async def clear_backend(): - raise HTTPException(status_code=501, detail="Not implemented") - - # Model management -@app.get("/models") -async def list_models(): - raise HTTPException(status_code=501, detail="Not implemented") - - @app.post( "/models", responses={ From 330d9986c4d884af5034796e55c34ce568e7569e Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Thu, 13 Feb 2025 16:57:24 +0000 Subject: [PATCH 11/19] Removed dead unit test --- test/server/api/test_routes.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index 31318407..030b9b01 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -382,19 +382,6 @@ async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): # Model endpoints -# @patch("matchbox.server.base.BackendManager.get_backend") -# def test_list_models(get_backend: Mock): -# mock_backend = Mock() -# dummy_models = [ -# model_factory(name="model1", description="test model 1").model, -# model_factory(name="model2", description="test model 2").model -# ] -# mock_backend.list_models = Mock(return_value=dummy_models) -# get_backend.return_value = mock_backend - -# response = client.get("/models") -# assert response.status_code == 200 - @patch("matchbox.server.base.BackendManager.get_backend") def test_insert_model(get_backend: Mock): From d7bc678969fc0657240ab6069ebe862de77f15dd Mon Sep 17 00:00:00 2001 From: Will Langdale 
Date: Fri, 14 Feb 2025 10:53:59 +0000 Subject: [PATCH 12/19] Renamed deletion exception, refactored model_factory, and cached calls to both source_ and model_factory --- src/matchbox/client/_handler.py | 6 +- src/matchbox/common/exceptions.py | 2 +- src/matchbox/common/factories/models.py | 17 ++- src/matchbox/common/factories/sources.py | 54 +++++-- src/matchbox/server/api/routes.py | 4 +- src/matchbox/server/postgresql/adapter.py | 4 +- test/client/test_model.py | 176 ++++++---------------- test/common/test_factories.py | 58 +++---- test/server/api/test_routes.py | 129 ++++++++-------- 9 files changed, 208 insertions(+), 242 deletions(-) diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index e57054f3..e9835411 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -18,7 +18,7 @@ ) from matchbox.common.exceptions import ( MatchboxClientFileError, - MatchboxConfirmDelete, + MatchboxDeletionNotConfirmed, MatchboxResolutionNotFoundError, MatchboxServerFileError, MatchboxSourceNotFoundError, @@ -81,7 +81,7 @@ def handle_http_code(res: httpx.Response) -> httpx.Response: if res.status_code == 409: error = ModelOperationStatus.model_validate(res.json()) - raise MatchboxConfirmDelete(message=error.details) + raise MatchboxDeletionNotConfirmed(message=error.details) if res.status_code == 422: raise MatchboxUnparsedClientRequest(res.content) @@ -303,7 +303,7 @@ def get_model_ancestors_cache(name: str) -> list[ModelAncestor]: return [ModelAncestor.model_validate(m) for m in res.json()] -def delete_model(name: str, certain: bool | None = None) -> ModelOperationStatus: +def delete_model(name: str, certain: bool | None = False) -> ModelOperationStatus: """Delete a model in Matchbox.""" res = CLIENT.delete(f"/models/{name}", params={"certain": certain}) return ModelOperationStatus.model_validate(res.json()) diff --git a/src/matchbox/common/exceptions.py b/src/matchbox/common/exceptions.py index 0558c591..8b7a6d8a 
100644 --- a/src/matchbox/common/exceptions.py +++ b/src/matchbox/common/exceptions.py @@ -138,7 +138,7 @@ class MatchboxConnectionError(Exception): """Connection to Matchbox's backend database failed.""" -class MatchboxConfirmDelete(Exception): +class MatchboxDeletionNotConfirmed(Exception): """Deletion must be confirmed: if certain, rerun with certain=True.""" def __init__(self, message: str | None = None, children: list[str] | None = None): diff --git a/src/matchbox/common/factories/models.py b/src/matchbox/common/factories/models.py index 41c0c5f0..87a13491 100644 --- a/src/matchbox/common/factories/models.py +++ b/src/matchbox/common/factories/models.py @@ -1,4 +1,5 @@ from collections import Counter +from functools import cache from textwrap import dedent from typing import Any, Literal from unittest.mock import Mock, PropertyMock, create_autospec @@ -317,7 +318,7 @@ class ModelDummy(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - model: ModelMetadata + model: Model data: pa.Table metrics: ModelMetrics @@ -350,6 +351,7 @@ def to_mock(self) -> Mock: return mock_model +@cache def model_factory( name: str | None = None, description: str | None = None, @@ -376,7 +378,7 @@ def model_factory( model_type = ModelType(model_type.lower() if model_type else "deduper") - model = ModelMetadata( + metadata = ModelMetadata( name=name or generator.word(), description=description or generator.sentence(), type=model_type, @@ -384,14 +386,21 @@ def model_factory( right_resolution=generator.word() if model_type == ModelType.LINKER else None, ) - if model.type == ModelType.LINKER: + if metadata.type == ModelType.LINKER: left_values = list(range(n_true_entities)) right_values = list(range(n_true_entities, n_true_entities * 2)) - elif model.type == ModelType.DEDUPER: + elif metadata.type == ModelType.DEDUPER: values_count = n_true_entities * 2 # So there's something to dedupe left_values = list(range(values_count)) right_values = None + model = Model( + 
metadata=metadata, + model_instance=Mock(), + left_data=DataFrame({"id": left_values}), + right_data=DataFrame({"id": right_values}) if right_values else None, + ) + probabilities = generate_dummy_probabilities( left_values=left_values, right_values=right_values, diff --git a/src/matchbox/common/factories/sources.py b/src/matchbox/common/factories/sources.py index 4df0c4a6..f0646707 100644 --- a/src/matchbox/common/factories/sources.py +++ b/src/matchbox/common/factories/sources.py @@ -1,5 +1,7 @@ from abc import ABC, abstractmethod +from functools import cache, wraps from math import comb +from typing import Callable, ParamSpec, TypeVar from unittest.mock import Mock, create_autospec from uuid import uuid4 @@ -12,10 +14,39 @@ from matchbox.common.arrow import SCHEMA_INDEX from matchbox.common.sources import Source, SourceAddress, SourceColumn +P = ParamSpec("P") +R = TypeVar("R") + + +def make_features_hashable(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + # Handle features in first positional arg + if args and args[0] is not None: + if isinstance(args[0][0], dict): + args = (tuple(FeatureConfig(**d) for d in args[0]),) + args[1:] + else: + args = (tuple(args[0]),) + args[1:] + + # Handle features in kwargs + if "features" in kwargs and kwargs["features"] is not None: + if isinstance(kwargs["features"][0], dict): + kwargs["features"] = tuple( + FeatureConfig(**d) for d in kwargs["features"] + ) + else: + kwargs["features"] = tuple(kwargs["features"]) + + return func(*args, **kwargs) + + return wrapper + class VariationRule(BaseModel, ABC): """Abstract base class for variation rules.""" + model_config = ConfigDict(frozen=True) + @abstractmethod def apply(self, value: str) -> str: """Apply the variation to a value.""" @@ -71,10 +102,12 @@ def type(self) -> str: class FeatureConfig(BaseModel): """Configuration for generating a feature with variations.""" + model_config = ConfigDict(frozen=True) + 
name: str base_generator: str - parameters: dict = Field(default_factory=dict) - variations: list[VariationRule] = Field(default_factory=list) + parameters: tuple = Field(default_factory=tuple) + variations: tuple[VariationRule, ...] = Field(default_factory=tuple) class SourceMetrics(BaseModel): @@ -138,7 +171,7 @@ class SourceDummy(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) source: Source - features: list[FeatureConfig] + features: tuple[FeatureConfig, ...] data: pa.Table data_hashes: pa.Table metrics: SourceMetrics @@ -166,7 +199,7 @@ def __init__(self, seed: int = 42): self.faker.seed_instance(seed) def generate_data( - self, n_true_entities: int, features: list[FeatureConfig], repetition: int + self, n_true_entities: int, features: tuple[FeatureConfig], repetition: int ) -> tuple[pa.Table, pa.Table, SourceMetrics]: """Generate raw data as PyArrow tables. @@ -186,7 +219,7 @@ def generate_data( for _ in range(n_true_entities): # Generate base values -- the raw row base_values = { - f.name: getattr(self.faker, f.base_generator)(**f.parameters) + f.name: getattr(self.faker, f.base_generator)(**dict(f.parameters)) for f in features } @@ -235,6 +268,8 @@ def generate_data( return pa.Table.from_pandas(df), data_hashes, metrics +@make_features_hashable +@cache def source_factory( features: list[FeatureConfig] | list[dict] | None = None, full_name: str | None = None, @@ -259,7 +294,7 @@ def source_factory( generator = SourceDataGenerator(seed) if features is None: - features = [ + features = ( FeatureConfig( name="company_name", base_generator="company", @@ -267,9 +302,9 @@ def source_factory( FeatureConfig( name="crn", base_generator="bothify", - parameters={"text": "???-###-???-###"}, + parameters=(("text", "???-###-???-###"),), ), - ] + ) if full_name is None: full_name = generator.faker.word() @@ -277,9 +312,6 @@ def source_factory( if engine is None: engine = create_engine("sqlite:///:memory:") - if features and isinstance(features[0], 
dict): - features = [FeatureConfig.model_validate(feature) for feature in features] - data, data_hashes, metrics = generator.generate_data( n_true_entities=n_true_entities, features=features, repetition=repetition ) diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 22d25231..4d2fbf48 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -29,7 +29,7 @@ UploadStatus, ) from matchbox.common.exceptions import ( - MatchboxConfirmDelete, + MatchboxDeletionNotConfirmed, MatchboxResolutionNotFoundError, MatchboxServerFileError, MatchboxSourceNotFoundError, @@ -611,7 +611,7 @@ async def delete_model( details=str(e), entity=BackendRetrievableType.RESOLUTION ).model_dump(), ) from e - except MatchboxConfirmDelete as e: + except MatchboxDeletionNotConfirmed as e: raise HTTPException( status_code=409, detail=ModelOperationStatus( diff --git a/src/matchbox/server/postgresql/adapter.py b/src/matchbox/server/postgresql/adapter.py index e6cc83cc..afeb9410 100644 --- a/src/matchbox/server/postgresql/adapter.py +++ b/src/matchbox/server/postgresql/adapter.py @@ -7,8 +7,8 @@ from matchbox.common.dtos import ModelAncestor, ModelMetadata, ModelType from matchbox.common.exceptions import ( - MatchboxConfirmDelete, MatchboxDataNotFound, + MatchboxDeletionNotConfirmed, MatchboxResolutionNotFoundError, MatchboxSourceNotFoundError, ) @@ -533,4 +533,4 @@ def delete_model(self, model: str, certain: bool = False) -> None: session.commit() else: children = [r.name for r in resolution.descendants] - raise MatchboxConfirmDelete(childen=children) + raise MatchboxDeletionNotConfirmed(childen=children) diff --git a/test/client/test_model.py b/test/client/test_model.py index 299967ab..b785bca9 100644 --- a/test/client/test_model.py +++ b/test/client/test_model.py @@ -1,13 +1,10 @@ import json from os import getenv -from unittest.mock import Mock import pytest from httpx import Response -from pandas import DataFrame from 
respx.router import MockRouter -from matchbox.client.models.models import Model from matchbox.client.results import Results from matchbox.common.arrow import SCHEMA_RESULTS, table_to_buffer from matchbox.common.dtos import ( @@ -20,7 +17,7 @@ UploadStatus, ) from matchbox.common.exceptions import ( - MatchboxConfirmDelete, + MatchboxDeletionNotConfirmed, MatchboxResolutionNotFoundError, MatchboxServerFileError, MatchboxUnhandledServerResponse, @@ -34,12 +31,6 @@ def test_insert_model(respx_mock: MockRouter): """Test inserting a model via the API.""" # Create test model using factory dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the POST /models endpoint route = respx_mock.post("/models").mock( @@ -47,30 +38,27 @@ def test_insert_model(respx_mock: MockRouter): 200, json=ModelOperationStatus( success=True, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.INSERT, ).model_dump(), ) ) # Call insert_model - model.insert_model() + dummy.model.insert_model() # Verify the API call assert route.called - assert route.calls.last.request.content.decode() == dummy.model.model_dump_json() + assert ( + route.calls.last.request.content.decode() + == dummy.model.metadata.model_dump_json() + ) @pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_insert_model_error(respx_mock: MockRouter): """Test handling of model insertion errors.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the POST /models endpoint with an error response route = respx_mock.post("/models").mock( @@ -78,7 +66,7 @@ def test_insert_model_error(respx_mock: MockRouter): 500, json=ModelOperationStatus( success=False, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, 
operation=ModelOperationType.INSERT, details="Internal server error", ).model_dump(), @@ -87,7 +75,7 @@ def test_insert_model_error(respx_mock: MockRouter): # Call insert_model and verify it raises an exception with pytest.raises(MatchboxUnhandledServerResponse, match="Internal server error"): - model.insert_model() + dummy.model.insert_model() assert route.called @@ -96,20 +84,14 @@ def test_insert_model_error(respx_mock: MockRouter): def test_results_getter(respx_mock: MockRouter): """Test getting model results via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the GET /models/{name}/results endpoint - route = respx_mock.get(f"/models/{dummy.model.name}/results").mock( + route = respx_mock.get(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response(200, content=table_to_buffer(dummy.data).read()) ) # Get results - results = model.results + results = dummy.model.results # Verify the API call assert route.called @@ -121,15 +103,9 @@ def test_results_getter(respx_mock: MockRouter): def test_results_getter_not_found(respx_mock: MockRouter): """Test getting model results when they don't exist.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the GET endpoint with a 404 response - route = respx_mock.get(f"/models/{dummy.model.name}/results").mock( + route = respx_mock.get(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 404, json=NotFoundError( @@ -140,7 +116,7 @@ def test_results_getter_not_found(respx_mock: MockRouter): # Verify that accessing results raises an exception with pytest.raises(MatchboxResolutionNotFoundError, match="Results not found"): - _ = model.results + _ = dummy.model.results assert route.called @@ -149,15 +125,9 @@ def 
test_results_getter_not_found(respx_mock: MockRouter): def test_results_setter(respx_mock: MockRouter): """Test setting model results via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the endpoints needed for results upload - init_route = respx_mock.post(f"/models/{dummy.model.name}/results").mock( + init_route = respx_mock.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 200, json=UploadStatus( @@ -191,8 +161,8 @@ def test_results_setter(respx_mock: MockRouter): ) # Set results - test_results = Results(probabilities=dummy.data, metadata=dummy.model) - model.results = test_results + test_results = Results(probabilities=dummy.data, metadata=dummy.model.metadata) + dummy.model.results = test_results # Verify API calls assert init_route.called @@ -207,15 +177,9 @@ def test_results_setter(respx_mock: MockRouter): def test_results_setter_upload_failure(respx_mock: MockRouter): """Test handling of upload failures when setting results.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the initial POST endpoint - init_route = respx_mock.post(f"/models/{dummy.model.name}/results").mock( + init_route = respx_mock.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 200, json=UploadStatus( @@ -240,9 +204,9 @@ def test_results_setter_upload_failure(respx_mock: MockRouter): ) # Attempt to set results and verify it raises an exception - test_results = Results(probabilities=dummy.data, metadata=dummy.model) + test_results = Results(probabilities=dummy.data, metadata=dummy.model.metadata) with pytest.raises(MatchboxServerFileError, match="Invalid data format"): - model.results = test_results + dummy.model.results = test_results assert init_route.called assert 
upload_route.called @@ -252,20 +216,14 @@ def test_results_setter_upload_failure(respx_mock: MockRouter): def test_truth_getter(respx_mock: MockRouter): """Test getting model truth threshold via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the GET /models/{name}/truth endpoint - route = respx_mock.get(f"/models/{dummy.model.name}/truth").mock( + route = respx_mock.get(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response(200, json=0.9) ) # Get truth - truth = model.truth + truth = dummy.model.truth # Verify the API call assert route.called @@ -276,27 +234,21 @@ def test_truth_getter(respx_mock: MockRouter): def test_truth_setter(respx_mock: MockRouter): """Test setting model truth threshold via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the PATCH /models/{name}/truth endpoint - route = respx_mock.patch(f"/models/{dummy.model.name}/truth").mock( + route = respx_mock.patch(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response( 200, json=ModelOperationStatus( success=True, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.UPDATE_TRUTH, ).model_dump(), ) ) # Set truth - model.truth = 0.9 + dummy.model.truth = 0.9 # Verify the API call assert route.called @@ -307,21 +259,15 @@ def test_truth_setter(respx_mock: MockRouter): def test_truth_setter_validation_error(respx_mock: MockRouter): """Test setting invalid truth values.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the PATCH endpoint with a validation error - route = respx_mock.patch(f"/models/{dummy.model.name}/truth").mock( + 
route = respx_mock.patch(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response(422) ) # Attempt to set an invalid truth value with pytest.raises(MatchboxUnparsedClientRequest): - model.truth = 1.5 + dummy.model.truth = 1.5 assert route.called @@ -330,12 +276,6 @@ def test_truth_setter_validation_error(respx_mock: MockRouter): def test_ancestors_getter(respx_mock: MockRouter): """Test getting model ancestors via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) ancestors_data = [ ModelAncestor(name="model1", truth=0.9).model_dump(), @@ -343,12 +283,12 @@ def test_ancestors_getter(respx_mock: MockRouter): ] # Mock the GET /models/{name}/ancestors endpoint - route = respx_mock.get(f"/models/{dummy.model.name}/ancestors").mock( + route = respx_mock.get(f"/models/{dummy.model.metadata.name}/ancestors").mock( return_value=Response(200, json=ancestors_data) ) # Get ancestors - ancestors = model.ancestors + ancestors = dummy.model.ancestors # Verify the API call assert route.called @@ -359,39 +299,37 @@ def test_ancestors_getter(respx_mock: MockRouter): def test_ancestors_cache_operations(respx_mock: MockRouter): """Test getting and setting model ancestors cache via the API.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the GET endpoint - get_route = respx_mock.get(f"/models/{dummy.model.name}/ancestors_cache").mock( + get_route = respx_mock.get( + f"/models/{dummy.model.metadata.name}/ancestors_cache" + ).mock( return_value=Response( 200, json=[ModelAncestor(name="model1", truth=0.9).model_dump()] ) ) # Mock the POST endpoint - set_route = respx_mock.post(f"/models/{dummy.model.name}/ancestors_cache").mock( + set_route = respx_mock.post( + f"/models/{dummy.model.metadata.name}/ancestors_cache" + ).mock( 
return_value=Response( 200, json=ModelOperationStatus( success=True, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, ).model_dump(), ) ) # Get ancestors cache - cache = model.ancestors_cache + cache = dummy.model.ancestors_cache assert get_route.called assert cache == {"model1": 0.9} # Set ancestors cache - model.ancestors_cache = {"model2": 0.8} + dummy.model.ancestors_cache = {"model2": 0.8} assert set_route.called assert json.loads(set_route.calls.last.request.content.decode()) == [ ModelAncestor(name="model2", truth=0.8).model_dump() @@ -402,20 +340,16 @@ def test_ancestors_cache_operations(respx_mock: MockRouter): def test_ancestors_cache_set_error(respx_mock: MockRouter): """Test error handling when setting ancestors cache.""" dummy = model_factory(model_type="linker") - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the POST endpoint with an error - route = respx_mock.post(f"/models/{dummy.model.name}/ancestors_cache").mock( + route = respx_mock.post( + f"/models/{dummy.model.metadata.name}/ancestors_cache" + ).mock( return_value=Response( 500, json=ModelOperationStatus( success=False, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.UPDATE_ANCESTOR_CACHE, details="Database error", ).model_dump(), @@ -424,7 +358,7 @@ def test_ancestors_cache_set_error(respx_mock: MockRouter): # Attempt to set ancestors cache with pytest.raises(MatchboxUnhandledServerResponse, match="Database error"): - model.ancestors_cache = {"model1": 0.9} + dummy.model.ancestors_cache = {"model1": 0.9} assert route.called @@ -434,29 +368,23 @@ def test_delete_model(respx_mock: MockRouter): """Test successfully deleting a model.""" # Create test model using factory dummy = model_factory() - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - 
right_data=DataFrame(), - ) # Mock the DELETE endpoint with success response route = respx_mock.delete( - f"/models/{dummy.model.name}", params={"certain": True} + f"/models/{dummy.model.metadata.name}", params={"certain": True} ).mock( return_value=Response( 200, json=ModelOperationStatus( success=True, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.DELETE, ).model_dump(), ) ) # Delete the model - response = model.delete(certain=True) + response = dummy.model.delete(certain=True) # Verify the response and API call assert response @@ -469,21 +397,15 @@ def test_delete_model_needs_confirmation(respx_mock: MockRouter): """Test attempting to delete a model without confirmation returns 409.""" # Create test model using factory dummy = model_factory() - model = Model( - metadata=dummy.model, - model_instance=Mock(), - left_data=DataFrame(), - right_data=DataFrame(), - ) # Mock the DELETE endpoint with 409 confirmation required response error_details = "Cannot delete model with dependent models: dedupe1, dedupe2" - route = respx_mock.delete(f"/models/{dummy.model.name}").mock( + route = respx_mock.delete(f"/models/{dummy.model.metadata.name}").mock( return_value=Response( 409, json=ModelOperationStatus( success=False, - model_name=dummy.model.name, + model_name=dummy.model.metadata.name, operation=ModelOperationType.DELETE, details=error_details, ).model_dump(), @@ -491,8 +413,8 @@ def test_delete_model_needs_confirmation(respx_mock: MockRouter): ) # Attempt to delete without certain=True - with pytest.raises(MatchboxConfirmDelete): - model.delete() + with pytest.raises(MatchboxDeletionNotConfirmed): + dummy.model.delete() # Verify the response and API call assert route.called diff --git a/test/common/test_factories.py b/test/common/test_factories.py index 96b85c21..68c4d7fc 100644 --- a/test/common/test_factories.py +++ b/test/common/test_factories.py @@ -285,15 +285,15 @@ def 
test_source_factory_metrics_with_multiple_features(): FeatureConfig( name="company_name", base_generator="company", - variations=[ + variations=( SuffixRule(suffix=" Inc"), SuffixRule(suffix=" Ltd"), - ], + ), ), FeatureConfig( name="email", base_generator="email", - variations=[ReplaceRule(old="@", new="+test@")], + variations=(ReplaceRule(old="@", new="+test@"),), ), ] @@ -418,12 +418,12 @@ def test_source_factory_mock_properties(): FeatureConfig( name="company_name", base_generator="company", - variations=[SuffixRule(suffix=" Ltd")], + variations=(SuffixRule(suffix=" Ltd"),), ), FeatureConfig( name="registration_id", base_generator="numerify", - parameters={"text": "######"}, + parameters=(("text", "######"),), ), ] @@ -462,15 +462,15 @@ def test_source_factory_mock_properties(): def test_model_factory_default(): """Test that model_factory generates a dummy model with default parameters.""" - model = model_factory() + dummy = model_factory() - assert model.metrics.n_true_entities == 10 - assert model.model.type == ModelType.DEDUPER - assert model.model.right_resolution is None + assert dummy.metrics.n_true_entities == 10 + assert dummy.model.metadata.type == ModelType.DEDUPER + assert dummy.model.metadata.right_resolution is None # Check that probabilities table was generated correctly - assert len(model.data) > 0 - assert model.data.schema.equals(SCHEMA_RESULTS) + assert len(dummy.data) > 0 + assert dummy.data.schema.equals(SCHEMA_RESULTS) def test_model_factory_with_custom_params(): @@ -480,19 +480,19 @@ def test_model_factory_with_custom_params(): n_true_entities = 5 prob_range = (0.9, 1.0) - model = model_factory( + dummy = model_factory( name=name, description=description, n_true_entities=n_true_entities, prob_range=prob_range, ) - assert model.model.name == name - assert model.model.description == description - assert model.metrics.n_true_entities == n_true_entities + assert dummy.model.metadata.name == name + assert dummy.model.metadata.description == 
description + assert dummy.metrics.n_true_entities == n_true_entities # Check probability range - probs = model.data.column("probability").to_pylist() + probs = dummy.data.column("probability").to_pylist() assert all(90 <= p <= 100 for p in probs) @@ -505,16 +505,16 @@ def test_model_factory_with_custom_params(): ) def test_model_factory_different_types(model_type: str): """Test model_factory handles different model types correctly.""" - model = model_factory(model_type=model_type) + dummy = model_factory(model_type=model_type) - assert model.model.type == model_type + assert dummy.model.metadata.type == model_type if model_type == ModelType.LINKER: - assert model.model.right_resolution is not None + assert dummy.model.metadata.right_resolution is not None # Check that left and right values are in different ranges - left_vals = model.data.column("left_id").to_pylist() - right_vals = model.data.column("right_id").to_pylist() + left_vals = dummy.data.column("left_id").to_pylist() + right_vals = dummy.data.column("right_id").to_pylist() left_min, left_max = min(left_vals), max(left_vals) right_min, right_max = min(right_vals), max(right_vals) assert (left_min < left_max < right_min < right_max) or ( @@ -531,14 +531,14 @@ def test_model_factory_different_types(model_type: str): ) def test_model_factory_seed_behavior(seed1: int, seed2: int, should_be_equal: bool): """Test that model_factory handles seeds correctly for reproducibility.""" - model1 = model_factory(seed=seed1) - model2 = model_factory(seed=seed2) + dummy1 = model_factory(seed=seed1) + dummy2 = model_factory(seed=seed2) if should_be_equal: - assert model1.model.name == model2.model.name - assert model1.model.description == model2.model.description - assert model1.data.equals(model2.data) + assert dummy1.model.metadata.name == dummy2.model.metadata.name + assert dummy1.model.metadata.description == dummy2.model.metadata.description + assert dummy1.data.equals(dummy2.data) else: - assert model1.model.name != 
model2.model.name - assert model1.model.description != model2.model.description - assert not model1.data.equals(model2.data) + assert dummy1.model.metadata.name != dummy2.model.metadata.name + assert dummy1.model.metadata.description != dummy2.model.metadata.description + assert not dummy1.data.equals(dummy2.data) diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index 030b9b01..8c1f615f 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -16,7 +16,7 @@ UploadStatus, ) from matchbox.common.exceptions import ( - MatchboxConfirmDelete, + MatchboxDeletionNotConfirmed, MatchboxResolutionNotFoundError, MatchboxSourceNotFoundError, ) @@ -188,9 +188,9 @@ def test_source_upload( mock_add_task.assert_called_once() # Verify background task was queued -@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") -def test_upload_status_check(metadata_store: Mock, get_backend: Mock): +def test_upload_status_check(metadata_store: Mock, _: Mock): """Test checking status of an upload using the status endpoint.""" # Setup store with a processing entry store = MetadataStore() @@ -210,9 +210,9 @@ def test_upload_status_check(metadata_store: Mock, get_backend: Mock): metadata_store.update_status.assert_not_called() -@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") -def test_upload_already_processing(metadata_store: Mock, get_backend: Mock): +def test_upload_already_processing(metadata_store: Mock, _: Mock): """Test attempting to upload when status is already processing.""" # Setup store with a processing entry store = MetadataStore() @@ -233,8 +233,9 @@ def test_upload_already_processing(metadata_store: Mock, get_backend: Mock): assert 
response.json()["status"] == "processing" +@patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") -def test_upload_already_queued(metadata_store: Mock): +def test_upload_already_queued(metadata_store: Mock, _: Mock): """Test attempting to upload when status is already queued.""" # Setup store with a queued entry store = MetadataStore() @@ -388,8 +389,8 @@ def test_insert_model(get_backend: Mock): mock_backend = Mock() get_backend.return_value = mock_backend - dummy_model = model_factory(name="test_model") - response = client.post("/models", json=dummy_model.model.model_dump()) + dummy = model_factory(name="test_model") + response = client.post("/models", json=dummy.model.metadata.model_dump()) assert response.status_code == 200 assert response.json() == { @@ -398,7 +399,7 @@ def test_insert_model(get_backend: Mock): "operation": ModelOperationType.INSERT.value, "details": None, } - mock_backend.insert_model.assert_called_once_with(dummy_model.model) + mock_backend.insert_model.assert_called_once_with(dummy.model.metadata) @patch("matchbox.server.base.BackendManager.get_backend") @@ -407,8 +408,8 @@ def test_insert_model_error(get_backend: Mock): mock_backend.insert_model = Mock(side_effect=Exception("Test error")) get_backend.return_value = mock_backend - dummy_model = model_factory() - response = client.post("/models", json=dummy_model.model.model_dump()) + dummy = model_factory() + response = client.post("/models", json=dummy.model.metadata.model_dump()) assert response.status_code == 500 assert response.json()["success"] is False @@ -418,15 +419,15 @@ def test_insert_model_error(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") def test_get_model(get_backend: Mock): mock_backend = Mock() - dummy_model = model_factory(name="test_model", description="test description") - mock_backend.get_model = Mock(return_value=dummy_model.model) + dummy = 
model_factory(name="test_model", description="test description") + mock_backend.get_model = Mock(return_value=dummy.model.metadata) get_backend.return_value = mock_backend response = client.get("/models/test_model") assert response.status_code == 200 - assert response.json()["name"] == dummy_model.model.name - assert response.json()["description"] == dummy_model.model.description + assert response.json()["name"] == dummy.model.metadata.name + assert response.json()["description"] == dummy.model.metadata.description @patch("matchbox.server.base.BackendManager.get_backend") @@ -464,11 +465,11 @@ def test_model_upload( ) # Create test data with specified model type - dummy_model = model_factory(model_type=model_type) + dummy = model_factory(model_type=model_type) # Setup metadata store store = MetadataStore() - upload_id = store.cache_model(dummy_model.model) + upload_id = store.cache_model(dummy.model.metadata) metadata_store.get.side_effect = store.get metadata_store.update_status.side_effect = store.update_status @@ -479,7 +480,7 @@ def test_model_upload( files={ "file": ( "data.parquet", - table_to_buffer(dummy_model.data), + table_to_buffer(dummy.data), "application/octet-stream", ), }, @@ -512,20 +513,20 @@ async def test_complete_model_upload_process( ) # Create test data with specified model type - dummy_model = model_factory(model_type=model_type) + dummy = model_factory(model_type=model_type) # Set up the mock to return the actual model metadata and data - mock_backend.get_model = Mock(return_value=dummy_model.model) - mock_backend.get_model_results = Mock(return_value=dummy_model.data) + mock_backend.get_model = Mock(return_value=dummy.model.metadata) + mock_backend.get_model_results = Mock(return_value=dummy.data) # Step 1: Create model - response = client.post("/models", json=dummy_model.model.model_dump()) + response = client.post("/models", json=dummy.model.metadata.model_dump()) assert response.status_code == 200 assert response.json()["success"] is 
True - assert response.json()["model_name"] == dummy_model.model.name + assert response.json()["model_name"] == dummy.model.metadata.name # Step 2: Initialize results upload - response = client.post(f"/models/{dummy_model.model.name}/results") + response = client.post(f"/models/{dummy.model.metadata.name}/results") assert response.status_code == 200 upload_id = response.json()["id"] assert response.json()["status"] == "awaiting_upload" @@ -536,7 +537,7 @@ async def test_complete_model_upload_process( files={ "file": ( "results.parquet", - table_to_buffer(dummy_model.data), + table_to_buffer(dummy.data), "application/octet-stream", ), }, @@ -573,34 +574,36 @@ async def test_complete_model_upload_process( # Step 5: Verify results were stored correctly mock_backend.set_model_results.assert_called_once() call_args = mock_backend.set_model_results.call_args - assert call_args[1]["model"] == dummy_model.model.name # Check model name matches - assert call_args[1]["results"].equals( - dummy_model.data - ) # Check results data matches + assert ( + call_args[1]["model"] == dummy.model.metadata.name + ) # Check model name matches + assert call_args[1]["results"].equals(dummy.data) # Check results data matches # Step 6: Verify we can retrieve the results - response = client.get(f"/models/{dummy_model.model.name}/results") + response = client.get(f"/models/{dummy.model.metadata.name}/results") assert response.status_code == 200 assert response.headers["content-type"] == "application/octet-stream" # Step 7: Additional model-specific verifications if model_type == "linker": # For linker models, verify left and right resolutions are set - assert dummy_model.model.left_resolution is not None - assert dummy_model.model.right_resolution is not None + assert dummy.model.metadata.left_resolution is not None + assert dummy.model.metadata.right_resolution is not None else: # For deduper models, verify only left resolution is set - assert dummy_model.model.left_resolution is not None - 
assert dummy_model.model.right_resolution is None + assert dummy.model.metadata.left_resolution is not None + assert dummy.model.metadata.right_resolution is None # Verify the model truth can be set and retrieved truth_value = 0.85 mock_backend.get_model_truth = Mock(return_value=truth_value) - response = client.patch(f"/models/{dummy_model.model.name}/truth", json=truth_value) + response = client.patch( + f"/models/{dummy.model.metadata.name}/truth", json=truth_value + ) assert response.status_code == 200 - response = client.get(f"/models/{dummy_model.model.name}/truth") + response = client.get(f"/models/{dummy.model.metadata.name}/truth") assert response.status_code == 200 assert response.json() == truth_value @@ -610,15 +613,15 @@ def test_delete_model(get_backend: Mock): mock_backend = Mock() get_backend.return_value = mock_backend - dummy_model = model_factory() + dummy = model_factory() response = client.delete( - f"/models/{dummy_model.model.name}", params={"certain": True} + f"/models/{dummy.model.metadata.name}", params={"certain": True} ) assert response.status_code == 200 assert response.json() == { "success": True, - "model_name": dummy_model.model.name, + "model_name": dummy.model.metadata.name, "operation": ModelOperationType.DELETE, "details": None, } @@ -628,12 +631,12 @@ def test_delete_model(get_backend: Mock): def test_delete_model_needs_confirmation(get_backend: Mock): mock_backend = Mock() mock_backend.delete_model = Mock( - side_effect=MatchboxConfirmDelete(children=["dedupe1", "dedupe2"]) + side_effect=MatchboxDeletionNotConfirmed(children=["dedupe1", "dedupe2"]) ) get_backend.return_value = mock_backend - dummy_model = model_factory() - response = client.delete(f"/models/{dummy_model.model.name}") + dummy = model_factory() + response = client.delete(f"/models/{dummy.model.metadata.name}") assert response.status_code == 409 assert response.json()["success"] is False @@ -644,11 +647,11 @@ def test_delete_model_needs_confirmation(get_backend: 
Mock): @patch("matchbox.server.base.BackendManager.get_backend") def test_get_results(get_backend: Mock): mock_backend = Mock() - dummy_model = model_factory() - mock_backend.get_model_results = Mock(return_value=dummy_model.data) + dummy = model_factory() + mock_backend.get_model_results = Mock(return_value=dummy.data) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy_model.model.name}/results") + response = client.get(f"/models/{dummy.model.metadata.name}/results") assert response.status_code == 200 assert response.headers["content-type"] == "application/octet-stream" @@ -657,11 +660,11 @@ def test_get_results(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") def test_set_results(get_backend: Mock): mock_backend = Mock() - dummy_model = model_factory() - mock_backend.get_model = Mock(return_value=dummy_model.model) + dummy = model_factory() + mock_backend.get_model = Mock(return_value=dummy.model.metadata) get_backend.return_value = mock_backend - response = client.post(f"/models/{dummy_model.model.name}/results") + response = client.post(f"/models/{dummy.model.metadata.name}/results") assert response.status_code == 200 assert response.json()["status"] == "awaiting_upload" @@ -683,11 +686,11 @@ def test_set_results_model_not_found(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") def test_get_truth(get_backend: Mock): mock_backend = Mock() - dummy_model = model_factory() + dummy = model_factory() mock_backend.get_model_truth = Mock(return_value=0.95) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy_model.model.name}/truth") + response = client.get(f"/models/{dummy.model.metadata.name}/truth") assert response.status_code == 200 assert response.json() == 0.95 @@ -696,15 +699,15 @@ def test_get_truth(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") def test_set_truth(get_backend: Mock): mock_backend = Mock() - 
dummy_model = model_factory() + dummy = model_factory() get_backend.return_value = mock_backend - response = client.patch(f"/models/{dummy_model.model.name}/truth", json=0.95) + response = client.patch(f"/models/{dummy.model.metadata.name}/truth", json=0.95) assert response.status_code == 200 assert response.json()["success"] is True mock_backend.set_model_truth.assert_called_once_with( - model=dummy_model.model.name, truth=0.95 + model=dummy.model.metadata.name, truth=0.95 ) @@ -712,22 +715,22 @@ def test_set_truth(get_backend: Mock): def test_set_truth_invalid_value(get_backend: Mock): """Test setting an invalid truth value (outside 0-1 range).""" mock_backend = Mock() - dummy_model = model_factory() + dummy = model_factory() get_backend.return_value = mock_backend # Test value > 1 - response = client.patch(f"/models/{dummy_model.model.name}/truth", json=1.5) + response = client.patch(f"/models/{dummy.model.metadata.name}/truth", json=1.5) assert response.status_code == 422 # Test value < 0 - response = client.patch(f"/models/{dummy_model.model.name}/truth", json=-0.5) + response = client.patch(f"/models/{dummy.model.metadata.name}/truth", json=-0.5) assert response.status_code == 422 @patch("matchbox.server.base.BackendManager.get_backend") def test_get_ancestors(get_backend: Mock): mock_backend = Mock() - dummy_model = model_factory() + dummy = model_factory() mock_ancestors = [ ModelAncestor(name="parent_model", truth=0.7), ModelAncestor(name="grandparent_model", truth=0.97), @@ -735,7 +738,7 @@ def test_get_ancestors(get_backend: Mock): mock_backend.get_model_ancestors = Mock(return_value=mock_ancestors) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy_model.model.name}/ancestors") + response = client.get(f"/models/{dummy.model.metadata.name}/ancestors") assert response.status_code == 200 assert len(response.json()) == 2 @@ -746,7 +749,7 @@ def test_get_ancestors(get_backend: Mock): def test_set_ancestors_cache(get_backend: 
Mock): """Test setting the ancestors cache for a model.""" mock_backend = Mock() - dummy_model = model_factory() + dummy = model_factory() get_backend.return_value = mock_backend ancestors_data = [ @@ -755,7 +758,7 @@ def test_set_ancestors_cache(get_backend: Mock): ] response = client.patch( - f"/models/{dummy_model.model.name}/ancestors_cache", + f"/models/{dummy.model.metadata.name}/ancestors_cache", json=[a.model_dump() for a in ancestors_data], ) @@ -763,7 +766,7 @@ def test_set_ancestors_cache(get_backend: Mock): assert response.json()["success"] is True assert response.json()["operation"] == ModelOperationType.UPDATE_ANCESTOR_CACHE mock_backend.set_model_ancestors_cache.assert_called_once_with( - model=dummy_model.model.name, ancestors_cache=ancestors_data + model=dummy.model.metadata.name, ancestors_cache=ancestors_data ) @@ -771,7 +774,7 @@ def test_set_ancestors_cache(get_backend: Mock): def test_get_ancestors_cache(get_backend: Mock): """Test retrieving the ancestors cache for a model.""" mock_backend = Mock() - dummy_model = model_factory() + dummy = model_factory() mock_ancestors = [ ModelAncestor(name="parent_model", truth=0.7), ModelAncestor(name="grandparent_model", truth=0.8), @@ -779,7 +782,7 @@ def test_get_ancestors_cache(get_backend: Mock): mock_backend.get_model_ancestors_cache = Mock(return_value=mock_ancestors) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy_model.model.name}/ancestors_cache") + response = client.get(f"/models/{dummy.model.metadata.name}/ancestors_cache") assert response.status_code == 200 assert len(response.json()) == 2 From f02cb8709f2bb6fa843d3e53d4d066c5ee4632a6 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 11:04:54 +0000 Subject: [PATCH 13/19] Added environment variable for client polling delay when uploading files --- environments/dev_docker.env | 3 ++- environments/dev_local.env | 3 ++- environments/sample.env | 3 ++- src/matchbox/client/_handler.py | 6 +++--- 4 
files changed, 9 insertions(+), 6 deletions(-) diff --git a/environments/dev_docker.env b/environments/dev_docker.env index 393e54b0..c25588ec 100644 --- a/environments/dev_docker.env +++ b/environments/dev_docker.env @@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE=matchbox MB__POSTGRES__DB_SCHEMA=mb MB__CLIENT__API_ROOT=http://localhost:8000 -MB__CLIENT__TIMEOUT=10 \ No newline at end of file +MB__CLIENT__TIMEOUT=10 +MB__CLIENT__RETRY_DELAY=2 \ No newline at end of file diff --git a/environments/dev_local.env b/environments/dev_local.env index dce50bf8..be87d8a1 100644 --- a/environments/dev_local.env +++ b/environments/dev_local.env @@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE=matchbox MB__POSTGRES__DB_SCHEMA=mb MB__CLIENT__API_ROOT=http://localhost:8000 -MB__CLIENT__TIMEOUT=10 \ No newline at end of file +MB__CLIENT__TIMEOUT=10 +MB__CLIENT__RETRY_DELAY=2 diff --git a/environments/sample.env b/environments/sample.env index 7ee538b0..3dc3a93a 100644 --- a/environments/sample.env +++ b/environments/sample.env @@ -17,4 +17,5 @@ MB__POSTGRES__DATABASE= MB__POSTGRES__DB_SCHEMA= MB__CLIENT__API_ROOT= -MB__CLIENT__TIMEOUT= \ No newline at end of file +MB__CLIENT__TIMEOUT= +MB__CLIENT__RETRY_DELAY= diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index e9835411..17394e55 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -106,7 +106,7 @@ def create_client() -> httpx.Client: CLIENT = create_client() - +DELAY = int(getenv("MB__CLIENT__RETRY_DELAY", 2)) # Retrieval @@ -201,7 +201,7 @@ def index(source: Source, data_hashes: Table) -> UploadStatus: if status.status == "failed": raise MatchboxServerFileError(status.details) - time.sleep(2) + time.sleep(DELAY) return status @@ -257,7 +257,7 @@ def add_model_results(name: str, results: Table) -> UploadStatus: if status.status == "failed": raise MatchboxServerFileError(status.details) - time.sleep(2) + time.sleep(DELAY) return status From 
630659da9f8e6e1700151d27aad62f65163bc17a Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 11:16:24 +0000 Subject: [PATCH 14/19] Moved respx client to a fixture --- test/client/test_helpers.py | 78 +++++++++++++----------------- test/client/test_model.py | 79 +++++++++++++------------------ test/client/test_visualisation.py | 8 +--- test/fixtures/db.py | 14 ++++++ 4 files changed, 82 insertions(+), 97 deletions(-) diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index 4abcfc2c..5ac67475 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -1,5 +1,4 @@ import logging -from os import getenv from unittest.mock import Mock, patch import pyarrow as pa @@ -75,8 +74,7 @@ def test_comparisons(): assert comparison_name_id is not None -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): +def test_select_mixed_style(matchbox_api: MockRouter, warehouse_engine: Engine): """We can select select specific columns from some of the sources""" # Set up mocks and test data source1 = Source( @@ -89,10 +87,10 @@ def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): db_pk="pk", ) - respx_mock.get( + matchbox_api.get( f"/sources/{hash_to_base64(source1.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source1.model_dump_json())) - respx_mock.get( + matchbox_api.get( f"/sources/{hash_to_base64(source2.address.warehouse_hash)}/test.bar" ).mock(return_value=Response(200, content=source2.model_dump_json())) @@ -127,14 +125,13 @@ def test_select_mixed_style(respx_mock: MockRouter, warehouse_engine: Engine): assert selection[1].source.engine == warehouse_engine -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_select_non_indexed_columns(respx_mock: MockRouter, warehouse_engine: Engine): +def test_select_non_indexed_columns(matchbox_api: MockRouter, warehouse_engine: 
Engine): """Selecting columns not declared to backend generates warning.""" source = Source( address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"), db_pk="pk", ) - respx_mock.get( + matchbox_api.get( f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source.model_dump_json())) @@ -152,15 +149,14 @@ def test_select_non_indexed_columns(respx_mock: MockRouter, warehouse_engine: En select({"test.foo": ["a", "b"]}, engine=warehouse_engine) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_select_missing_columns(respx_mock: MockRouter, warehouse_engine: Engine): +def test_select_missing_columns(matchbox_api: MockRouter, warehouse_engine: Engine): """Selecting columns not in the warehouse errors.""" source = Source( address=SourceAddress.compose(engine=warehouse_engine, full_name="test.foo"), db_pk="pk", ) - respx_mock.get( + matchbox_api.get( f"/sources/{hash_to_base64(source.address.warehouse_hash)}/test.foo" ).mock(return_value=Response(200, content=source.model_dump_json())) @@ -203,12 +199,13 @@ def test_query_no_resolution_fail(): query(sels) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch.object(Source, "to_arrow") -def test_query_no_resolution_ok_various_params(to_arrow: Mock, respx_mock: MockRouter): +def test_query_no_resolution_ok_various_params( + to_arrow: Mock, matchbox_api: MockRouter +): """Tests that we can avoid passing resolution name, with a variety of parameters.""" # Mock API - query_route = respx_mock.get("/query").mock( + query_route = matchbox_api.get("/query").mock( return_value=Response( 200, content=table_to_buffer( @@ -271,12 +268,11 @@ def test_query_no_resolution_ok_various_params(to_arrow: Mock, respx_mock: MockR } -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch.object(Source, "to_arrow") -def test_query_multiple_sources_with_limits(to_arrow: Mock, respx_mock: MockRouter): +def 
test_query_multiple_sources_with_limits(to_arrow: Mock, matchbox_api: MockRouter): """Tests that we can query multiple sources and distribute the limit among them.""" # Mock API - query_route = respx_mock.get("/query").mock( + query_route = matchbox_api.get("/query").mock( side_effect=[ Response( 200, @@ -373,10 +369,9 @@ def test_query_multiple_sources_with_limits(to_arrow: Mock, respx_mock: MockRout query([sels[0]], [sels[1]], resolution_name="link", limit=7) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_query_404_resolution(respx_mock: MockRouter): +def test_query_404_resolution(matchbox_api: MockRouter): # Mock API - respx_mock.get("/query").mock( + matchbox_api.get("/query").mock( return_value=Response( 404, json=NotFoundError( @@ -405,10 +400,9 @@ def test_query_404_resolution(respx_mock: MockRouter): query(sels) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_query_404_source(respx_mock: MockRouter): +def test_query_404_source(matchbox_api: MockRouter): # Mock API - respx_mock.get("/query").mock( + matchbox_api.get("/query").mock( return_value=Response( 404, json=NotFoundError( @@ -437,9 +431,8 @@ def test_query_404_source(respx_mock: MockRouter): query(sels) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") -def test_index_success(MockSource: Mock, respx_mock: MockRouter): +def test_index_success(MockSource: Mock, matchbox_api: MockRouter): """Test successful indexing flow through the API.""" engine = create_engine("sqlite:///:memory:") @@ -451,7 +444,7 @@ def test_index_success(MockSource: Mock, respx_mock: MockRouter): MockSource.return_value = mock_source_instance # Mock the initial source metadata upload - source_route = respx_mock.post("/sources").mock( + source_route = matchbox_api.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -463,7 +456,7 @@ def test_index_success(MockSource: Mock, respx_mock: MockRouter): ) # Mock the 
data upload - upload_route = respx_mock.post("/upload/test-upload-id").mock( + upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( 200, json=UploadStatus( @@ -489,7 +482,6 @@ def test_index_success(MockSource: Mock, respx_mock: MockRouter): assert b"PAR1" in upload_route.calls.last.request.content -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") @pytest.mark.parametrize( "columns", @@ -506,7 +498,9 @@ def test_index_success(MockSource: Mock, respx_mock: MockRouter): ], ) def test_index_with_columns( - MockSource: Mock, respx_mock: MockRouter, columns: list[str] | list[dict[str, str]] + MockSource: Mock, + matchbox_api: MockRouter, + columns: list[str] | list[dict[str, str]], ): """Test indexing with different column definition formats.""" engine = create_engine("sqlite:///:memory:") @@ -523,7 +517,7 @@ def test_index_with_columns( MockSource.return_value = mock_source_instance # Mock the API endpoints - source_route = respx_mock.post("/sources").mock( + source_route = matchbox_api.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -534,7 +528,7 @@ def test_index_with_columns( ) ) - upload_route = respx_mock.post("/upload/test-upload-id").mock( + upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( 200, json=UploadStatus( @@ -567,9 +561,8 @@ def test_index_with_columns( mock_source_instance.default_columns.assert_called_once() -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) @patch("matchbox.client.helpers.index.Source") -def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): +def test_index_upload_failure(MockSource: Mock, matchbox_api: MockRouter): """Test handling of upload failures.""" engine = create_engine("sqlite:///:memory:") @@ -581,7 +574,7 @@ def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): MockSource.return_value = mock_source_instance # Mock 
successful source creation - source_route = respx_mock.post("/sources").mock( + source_route = matchbox_api.post("/sources").mock( return_value=Response( 200, json=UploadStatus( @@ -593,7 +586,7 @@ def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): ) # Mock failed upload - upload_route = respx_mock.post("/upload/test-upload-id").mock( + upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( 400, json=UploadStatus( @@ -623,8 +616,7 @@ def test_index_upload_failure(MockSource: Mock, respx_mock: MockRouter): assert b"PAR1" in upload_route.calls.last.request.content -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_match_ok(respx_mock: MockRouter): +def test_match_ok(matchbox_api: MockRouter): """The client can perform the right call for matching.""" # Set up mocks mock_match1 = Match( @@ -646,7 +638,7 @@ def test_match_ok(respx_mock: MockRouter): f"[{mock_match1.model_dump_json()}, {mock_match2.model_dump_json()}]" ) - match_route = respx_mock.get("/match").mock( + match_route = matchbox_api.get("/match").mock( return_value=Response(200, content=serialised_matches) ) @@ -715,11 +707,10 @@ def test_match_ok(respx_mock: MockRouter): ) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_match_404_resolution(respx_mock: MockRouter): +def test_match_404_resolution(matchbox_api: MockRouter): """The client can handle a resolution not found error.""" # Set up mocks - respx_mock.get("/match").mock( + matchbox_api.get("/match").mock( return_value=Response( 404, json=NotFoundError( @@ -764,11 +755,10 @@ def test_match_404_resolution(respx_mock: MockRouter): ) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_match_404_source(respx_mock: MockRouter): +def test_match_404_source(matchbox_api: MockRouter): """The client can handle a source not found error.""" # Set up mocks - respx_mock.get("/match").mock( + matchbox_api.get("/match").mock( return_value=Response( 404, 
json=NotFoundError( diff --git a/test/client/test_model.py b/test/client/test_model.py index b785bca9..8b8ece0c 100644 --- a/test/client/test_model.py +++ b/test/client/test_model.py @@ -1,5 +1,4 @@ import json -from os import getenv import pytest from httpx import Response @@ -26,14 +25,13 @@ from matchbox.common.factories.models import model_factory -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_insert_model(respx_mock: MockRouter): +def test_insert_model(matchbox_api: MockRouter): """Test inserting a model via the API.""" # Create test model using factory dummy = model_factory(model_type="linker") # Mock the POST /models endpoint - route = respx_mock.post("/models").mock( + route = matchbox_api.post("/models").mock( return_value=Response( 200, json=ModelOperationStatus( @@ -55,13 +53,12 @@ def test_insert_model(respx_mock: MockRouter): ) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_insert_model_error(respx_mock: MockRouter): +def test_insert_model_error(matchbox_api: MockRouter): """Test handling of model insertion errors.""" dummy = model_factory(model_type="linker") # Mock the POST /models endpoint with an error response - route = respx_mock.post("/models").mock( + route = matchbox_api.post("/models").mock( return_value=Response( 500, json=ModelOperationStatus( @@ -80,13 +77,12 @@ def test_insert_model_error(respx_mock: MockRouter): assert route.called -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_results_getter(respx_mock: MockRouter): +def test_results_getter(matchbox_api: MockRouter): """Test getting model results via the API.""" dummy = model_factory(model_type="linker") # Mock the GET /models/{name}/results endpoint - route = respx_mock.get(f"/models/{dummy.model.metadata.name}/results").mock( + route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response(200, content=table_to_buffer(dummy.data).read()) ) @@ -99,13 +95,12 @@ def 
test_results_getter(respx_mock: MockRouter): assert results.probabilities.schema.equals(SCHEMA_RESULTS) -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_results_getter_not_found(respx_mock: MockRouter): +def test_results_getter_not_found(matchbox_api: MockRouter): """Test getting model results when they don't exist.""" dummy = model_factory(model_type="linker") # Mock the GET endpoint with a 404 response - route = respx_mock.get(f"/models/{dummy.model.metadata.name}/results").mock( + route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 404, json=NotFoundError( @@ -121,13 +116,12 @@ def test_results_getter_not_found(respx_mock: MockRouter): assert route.called -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_results_setter(respx_mock: MockRouter): +def test_results_setter(matchbox_api: MockRouter): """Test setting model results via the API.""" dummy = model_factory(model_type="linker") # Mock the endpoints needed for results upload - init_route = respx_mock.post(f"/models/{dummy.model.metadata.name}/results").mock( + init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 200, json=UploadStatus( @@ -138,7 +132,7 @@ def test_results_setter(respx_mock: MockRouter): ) ) - upload_route = respx_mock.post("/upload/test-upload-id").mock( + upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( 200, json=UploadStatus( @@ -149,7 +143,7 @@ def test_results_setter(respx_mock: MockRouter): ) ) - status_route = respx_mock.get("/upload/test-upload-id/status").mock( + status_route = matchbox_api.get("/upload/test-upload-id/status").mock( return_value=Response( 200, json=UploadStatus( @@ -173,13 +167,12 @@ def test_results_setter(respx_mock: MockRouter): ) # Check for parquet file signature -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_results_setter_upload_failure(respx_mock: 
MockRouter): +def test_results_setter_upload_failure(matchbox_api: MockRouter): """Test handling of upload failures when setting results.""" dummy = model_factory(model_type="linker") # Mock the initial POST endpoint - init_route = respx_mock.post(f"/models/{dummy.model.metadata.name}/results").mock( + init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( 200, json=UploadStatus( @@ -191,7 +184,7 @@ def test_results_setter_upload_failure(respx_mock: MockRouter): ) # Mock the upload endpoint with a failure - upload_route = respx_mock.post("/upload/test-upload-id").mock( + upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( 400, json=UploadStatus( @@ -212,13 +205,12 @@ def test_results_setter_upload_failure(respx_mock: MockRouter): assert upload_route.called -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_truth_getter(respx_mock: MockRouter): +def test_truth_getter(matchbox_api: MockRouter): """Test getting model truth threshold via the API.""" dummy = model_factory(model_type="linker") # Mock the GET /models/{name}/truth endpoint - route = respx_mock.get(f"/models/{dummy.model.metadata.name}/truth").mock( + route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response(200, json=0.9) ) @@ -230,13 +222,12 @@ def test_truth_getter(respx_mock: MockRouter): assert truth == 0.9 -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_truth_setter(respx_mock: MockRouter): +def test_truth_setter(matchbox_api: MockRouter): """Test setting model truth threshold via the API.""" dummy = model_factory(model_type="linker") # Mock the PATCH /models/{name}/truth endpoint - route = respx_mock.patch(f"/models/{dummy.model.metadata.name}/truth").mock( + route = matchbox_api.patch(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response( 200, json=ModelOperationStatus( @@ -255,13 +246,12 @@ def 
test_truth_setter(respx_mock: MockRouter): assert float(route.calls.last.request.read()) == 0.9 -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_truth_setter_validation_error(respx_mock: MockRouter): +def test_truth_setter_validation_error(matchbox_api: MockRouter): """Test setting invalid truth values.""" dummy = model_factory(model_type="linker") # Mock the PATCH endpoint with a validation error - route = respx_mock.patch(f"/models/{dummy.model.metadata.name}/truth").mock( + route = matchbox_api.patch(f"/models/{dummy.model.metadata.name}/truth").mock( return_value=Response(422) ) @@ -272,8 +262,7 @@ def test_truth_setter_validation_error(respx_mock: MockRouter): assert route.called -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_ancestors_getter(respx_mock: MockRouter): +def test_ancestors_getter(matchbox_api: MockRouter): """Test getting model ancestors via the API.""" dummy = model_factory(model_type="linker") @@ -283,7 +272,7 @@ def test_ancestors_getter(respx_mock: MockRouter): ] # Mock the GET /models/{name}/ancestors endpoint - route = respx_mock.get(f"/models/{dummy.model.metadata.name}/ancestors").mock( + route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/ancestors").mock( return_value=Response(200, json=ancestors_data) ) @@ -295,13 +284,12 @@ def test_ancestors_getter(respx_mock: MockRouter): assert ancestors == {"model1": 0.9, "model2": 0.8} -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_ancestors_cache_operations(respx_mock: MockRouter): +def test_ancestors_cache_operations(matchbox_api: MockRouter): """Test getting and setting model ancestors cache via the API.""" dummy = model_factory(model_type="linker") # Mock the GET endpoint - get_route = respx_mock.get( + get_route = matchbox_api.get( f"/models/{dummy.model.metadata.name}/ancestors_cache" ).mock( return_value=Response( @@ -310,7 +298,7 @@ def test_ancestors_cache_operations(respx_mock: MockRouter): ) # Mock the 
POST endpoint - set_route = respx_mock.post( + set_route = matchbox_api.post( f"/models/{dummy.model.metadata.name}/ancestors_cache" ).mock( return_value=Response( @@ -336,13 +324,12 @@ def test_ancestors_cache_operations(respx_mock: MockRouter): ] -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_ancestors_cache_set_error(respx_mock: MockRouter): +def test_ancestors_cache_set_error(matchbox_api: MockRouter): """Test error handling when setting ancestors cache.""" dummy = model_factory(model_type="linker") # Mock the POST endpoint with an error - route = respx_mock.post( + route = matchbox_api.post( f"/models/{dummy.model.metadata.name}/ancestors_cache" ).mock( return_value=Response( @@ -363,14 +350,13 @@ def test_ancestors_cache_set_error(respx_mock: MockRouter): assert route.called -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_delete_model(respx_mock: MockRouter): +def test_delete_model(matchbox_api: MockRouter): """Test successfully deleting a model.""" # Create test model using factory dummy = model_factory() # Mock the DELETE endpoint with success response - route = respx_mock.delete( + route = matchbox_api.delete( f"/models/{dummy.model.metadata.name}", params={"certain": True} ).mock( return_value=Response( @@ -392,15 +378,14 @@ def test_delete_model(respx_mock: MockRouter): assert route.calls.last.request.url.params["certain"] == "true" -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) -def test_delete_model_needs_confirmation(respx_mock: MockRouter): +def test_delete_model_needs_confirmation(matchbox_api: MockRouter): """Test attempting to delete a model without confirmation returns 409.""" # Create test model using factory dummy = model_factory() # Mock the DELETE endpoint with 409 confirmation required response error_details = "Cannot delete model with dependent models: dedupe1, dedupe2" - route = respx_mock.delete(f"/models/{dummy.model.metadata.name}").mock( + route = 
matchbox_api.delete(f"/models/{dummy.model.metadata.name}").mock( return_value=Response( 409, json=ModelOperationStatus( diff --git a/test/client/test_visualisation.py b/test/client/test_visualisation.py index f9dc511c..2b341193 100644 --- a/test/client/test_visualisation.py +++ b/test/client/test_visualisation.py @@ -1,6 +1,3 @@ -from os import getenv - -import pytest from httpx import Response from matplotlib.figure import Figure from respx import MockRouter @@ -9,11 +6,10 @@ from matchbox.common.graph import ResolutionGraph -@pytest.mark.respx(base_url=getenv("MB__CLIENT__API_ROOT")) def test_draw_resolution_graph( - respx_mock: MockRouter, resolution_graph: ResolutionGraph + matchbox_api: MockRouter, resolution_graph: ResolutionGraph ): - respx_mock.get("/report/resolutions").mock( + matchbox_api.get("/report/resolutions").mock( return_value=Response(200, content=resolution_graph.model_dump_json()), ) diff --git a/test/fixtures/db.py b/test/fixtures/db.py index b4a23faf..3436fcea 100644 --- a/test/fixtures/db.py +++ b/test/fixtures/db.py @@ -1,12 +1,15 @@ import os +from os import getenv from typing import TYPE_CHECKING, Any, Callable, Generator, Literal import boto3 import pytest +import respx from _pytest.fixtures import FixtureRequest from dotenv import find_dotenv, load_dotenv from moto import mock_aws from pandas import DataFrame +from respx import MockRouter from sqlalchemy import Engine, create_engine from sqlalchemy import text as sqltext @@ -384,3 +387,14 @@ def s3(aws_credentials: None) -> Generator[S3Client, None, None]: """Return a mocked S3 client.""" with mock_aws(): yield boto3.client("s3", region_name="eu-west-2") + + +# Mock API + + +@pytest.fixture(scope="function") +def matchbox_api() -> Generator[MockRouter, None, None]: + with respx.mock( + base_url=getenv("MB__CLIENT__API_ROOT"), assert_all_called=True + ) as respx_mock: + yield respx_mock From ba14792794be98fa280dfbc6a0d3418b8279b197 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 
14 Feb 2025 11:27:52 +0000 Subject: [PATCH 15/19] Amended pytest config to deal with some warnings --- pyproject.toml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index e0620bb5..6875312a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -107,3 +107,9 @@ log_cli = false log_cli_level = "INFO" log_cli_format = "%(asctime)s [%(levelname)8s] %(message)s (%(filename)s:%(lineno)s)" log_cli_date_format = "%Y-%m-%d %H:%M:%S" +asyncio_mode = "strict" +asyncio_default_fixture_loop_scope = "function" +filterwarnings = [ + 'ignore:.*__fields.*:pydantic.PydanticDeprecatedSince20:unittest.mock', + 'ignore:.*__fields_set__.*:pydantic.PydanticDeprecatedSince20:unittest.mock' +] From 58414513709864cf763924e30739a2a866486deb Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 12:31:49 +0000 Subject: [PATCH 16/19] Updated API success responses with more granular 20x codes --- src/matchbox/common/dtos.py | 12 + src/matchbox/server/api/routes.py | 17 +- test/server/api/test_routes.py | 814 +++++++++++++++--------------- 3 files changed, 431 insertions(+), 412 deletions(-) diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index 5652e778..2a6ffe62 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -155,6 +155,18 @@ class UploadStatus(BaseModel): details: str | None = None entity: BackendUploadType | None = None + _status_code_mapping = { + "ready": 200, + "complete": 200, + "failed": 400, + "awaiting_upload": 202, + "queued": 202, + "processing": 202, + } + + def get_http_code(self) -> int: + return self._status_code_mapping[self.status] + @classmethod def status_400_examples(cls) -> dict: return { diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 4d2fbf48..3c622ad8 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -10,6 +10,7 @@ HTTPException, Query, UploadFile, + status, ) from fastapi.responses 
import JSONResponse, Response from starlette.exceptions import HTTPException as StarletteHTTPException @@ -107,6 +108,7 @@ def get_count(e: BackendCountableType) -> int: @app.post( "/upload/{upload_id}", responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}}, + status_code=status.HTTP_202_ACCEPTED, ) async def upload_file( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], @@ -183,7 +185,11 @@ async def upload_file( @app.get( "/upload/{upload_id}/status", - responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}}, + responses={ + 200: {"model": UploadStatus}, + 202: {"model": UploadStatus}, + 400: {"model": UploadStatus, **UploadStatus.status_400_examples()}, + }, ) async def get_upload_status( upload_id: str, @@ -210,7 +216,10 @@ async def get_upload_status( ).model_dump(), ) - return source_cache.status + return JSONResponse( + status_code=source_cache.status.get_http_code(), + content=source_cache.status.model_dump(), + ) # Retrieval @@ -307,7 +316,7 @@ async def match( # Data management -@app.post("/sources") +@app.post("/sources", status_code=status.HTTP_202_ACCEPTED) async def add_source(source: Source) -> UploadStatus: """Create an upload and insert task for indexed source data.""" upload_id = metadata_store.cache_source(metadata=source) @@ -354,6 +363,7 @@ async def get_resolutions( **ModelOperationStatus.status_500_examples(), }, }, + status_code=status.HTTP_201_CREATED, ) async def insert_model( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], model: ModelMetadata @@ -400,6 +410,7 @@ async def get_model( @app.post( "/models/{name}/results", responses={404: {"model": NotFoundError}}, + status_code=status.HTTP_202_ACCEPTED, ) async def set_results( backend: Annotated[MatchboxDBAdapter, Depends(get_backend)], name: str diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index 8c1f615f..de14443a 100644 --- a/test/server/api/test_routes.py +++ 
b/test/server/api/test_routes.py @@ -37,6 +37,9 @@ client = TestClient(app) +# General + + def test_healthcheck(): """Test the healthcheck endpoint.""" response = client.get("/health") @@ -80,68 +83,10 @@ def test_count_backend_item(get_backend: MatchboxDBAdapter): assert response.json() == {"entities": {"models": 20}} -# def test_clear_backend(): -# response = client.post("/testing/clear") -# assert response.status_code == 200 - -# def test_list_sources(): -# response = client.get("/sources") -# assert response.status_code == 200 - - -# Source endpoints - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_get_source(get_backend): - dummy_source = Source( - address=SourceAddress(full_name="foo", warehouse_hash=b"bar"), db_pk="pk" - ) - mock_backend = Mock() - mock_backend.get_source = Mock(return_value=dummy_source) - get_backend.return_value = mock_backend - - response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") - assert response.status_code == 200 - assert Source.model_validate(response.json()) - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_get_source_404(get_backend): - mock_backend = Mock() - mock_backend.get_source = Mock(side_effect=MatchboxSourceNotFoundError) - get_backend.return_value = mock_backend - - response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") - assert response.status_code == 404 - assert response.json()["entity"] == BackendRetrievableType.SOURCE - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_add_source(get_backend: Mock): - """Test the source addition endpoint.""" - # Setup - mock_backend = Mock() - mock_backend.index = Mock(return_value=None) - get_backend.return_value = mock_backend - - dummy_source = source_factory() - - # Make request - response = client.post("/sources", json=dummy_source.source.model_dump()) - - # Validate response - assert UploadStatus.model_validate(response.json()) - assert response.status_code == 200, response.json() - assert 
response.json()["status"] == "awaiting_upload" - assert response.json().get("id") is not None - mock_backend.index.assert_not_called() - - @patch("matchbox.server.base.BackendManager.get_backend") @patch("matchbox.server.api.routes.metadata_store") @patch("matchbox.server.api.routes.BackgroundTasks.add_task") -def test_source_upload( +def test_upload( mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client ): """Test uploading a file, happy path.""" @@ -178,7 +123,7 @@ def test_source_upload( # Validate response assert UploadStatus.model_validate(response.json()) - assert response.status_code == 200, response.json() + assert response.status_code == 202, response.json() assert response.json()["status"] == "queued" # Updated to check for queued status # Check both status updates were called in correct order assert metadata_store.update_status.call_args_list == [ @@ -188,6 +133,48 @@ def test_source_upload( mock_add_task.assert_called_once() # Verify background task was queued +@patch("matchbox.server.base.BackendManager.get_backend") +@patch("matchbox.server.api.routes.metadata_store") +@patch("matchbox.server.api.routes.BackgroundTasks.add_task") +def test_upload_wrong_schema( + mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client +): + """Test uploading a file with wrong schema.""" + # Setup + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + get_backend.return_value = mock_backend + + # Create source with results schema instead of index + dummy_source = source_factory() + + # Setup store + store = MetadataStore() + update_id = store.cache_source(dummy_source.source) + metadata_store.get.side_effect = store.get + metadata_store.update_status.side_effect = store.update_status + + # Make request with actual data instead of the hashes -- wrong schema + response = client.post( + f"/upload/{update_id}", + files={ + "file": ( + 
"hashes.parquet", + table_to_buffer(dummy_source.data), + "application/octet-stream", + ), + }, + ) + + # Should fail before background task starts + assert response.status_code == 400 + assert response.json()["status"] == "failed" + assert "schema mismatch" in response.json()["details"].lower() + metadata_store.update_status.assert_called_with(update_id, "failed", details=ANY) + mock_add_task.assert_not_called() # Background task should not be queued + + @patch("matchbox.server.base.BackendManager.get_backend") # Stops real backend call @patch("matchbox.server.api.routes.metadata_store") def test_upload_status_check(metadata_store: Mock, _: Mock): @@ -205,7 +192,7 @@ def test_upload_status_check(metadata_store: Mock, _: Mock): response = client.get(f"/upload/{update_id}/status") # Should return current status - assert response.status_code == 200 + assert response.status_code == 202 assert response.json()["status"] == "processing" metadata_store.update_status.assert_not_called() @@ -256,48 +243,6 @@ def test_upload_already_queued(metadata_store: Mock, _: Mock): assert response.json()["status"] == "queued" -@patch("matchbox.server.base.BackendManager.get_backend") -@patch("matchbox.server.api.routes.metadata_store") -@patch("matchbox.server.api.routes.BackgroundTasks.add_task") -def test_source_upload_wrong_schema( - mock_add_task: Mock, metadata_store: Mock, get_backend: Mock, s3: S3Client -): - """Test uploading a file with wrong schema.""" - # Setup - mock_backend = Mock() - mock_backend.settings.datastore.get_client.return_value = s3 - mock_backend.settings.datastore.cache_bucket_name = "test-bucket" - get_backend.return_value = mock_backend - - # Create source with results schema instead of index - dummy_source = source_factory() - - # Setup store - store = MetadataStore() - update_id = store.cache_source(dummy_source.source) - metadata_store.get.side_effect = store.get - metadata_store.update_status.side_effect = store.update_status - - # Make request with 
actual data instead of the hashes -- wrong schema - response = client.post( - f"/upload/{update_id}", - files={ - "file": ( - "hashes.parquet", - table_to_buffer(dummy_source.data), - "application/octet-stream", - ), - }, - ) - - # Should fail before background task starts - assert response.status_code == 400 - assert response.json()["status"] == "failed" - assert "schema mismatch" in response.json()["details"].lower() - metadata_store.update_status.assert_called_with(update_id, "failed", details=ANY) - mock_add_task.assert_not_called() # Background task should not be queued - - @patch("matchbox.server.api.routes.metadata_store") def test_status_check_not_found(metadata_store: Mock): """Test checking status for non-existent upload ID.""" @@ -310,96 +255,321 @@ def test_status_check_not_found(metadata_store: Mock): assert "not found or expired" in response.json()["details"].lower() -@pytest.mark.asyncio +# Retrieval + + @patch("matchbox.server.base.BackendManager.get_backend") -async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): - """Test the complete upload process from source creation through processing.""" - # Setup the backend +def test_query(get_backend: Mock): + # Mock backend mock_backend = Mock() - mock_backend.settings.datastore.get_client.return_value = s3 - mock_backend.settings.datastore.cache_bucket_name = "test-bucket" - mock_backend.index = Mock(return_value=None) + mock_backend.query = Mock( + return_value=pa.Table.from_pylist( + [ + {"source_pk": "a", "id": 1}, + {"source_pk": "b", "id": 2}, + ], + schema=SCHEMA_MB_IDS, + ) + ) get_backend.return_value = mock_backend - # Create test bucket - s3.create_bucket( - Bucket="test-bucket", - CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + # Hit endpoint + response = client.get( + "/query", + params={ + "full_name": "foo", + "warehouse_hash_b64": hash_to_base64(b"bar"), + }, ) - # Create test data - dummy_source = source_factory() - - # Step 1: Add source - 
response = client.post("/sources", json=dummy_source.source.model_dump()) - assert response.status_code == 200 - upload_id = response.json()["id"] - assert response.json()["status"] == "awaiting_upload" + # Process response + buffer = BytesIO(response.content) + table = pq.read_table(buffer) - # Step 2: Upload file with real background tasks - response = client.post( - f"/upload/{upload_id}", - files={ - "file": ( - "hashes.parquet", - table_to_buffer(dummy_source.data_hashes), - "application/octet-stream", - ), - }, - ) + # Check response assert response.status_code == 200 - assert response.json()["status"] == "queued" - - # Step 3: Poll status until complete or timeout - max_attempts = 10 - current_attempt = 0 - while current_attempt < max_attempts: - response = client.get(f"/upload/{upload_id}/status") - assert response.status_code == 200 + assert table.schema.equals(SCHEMA_MB_IDS) - status = response.json()["status"] - if status == "complete": - break - elif status == "failed": - pytest.fail(f"Upload failed: {response.json().get('details')}") - elif status in ["processing", "queued"]: - await asyncio.sleep(0.1) # Small delay between polls - else: - pytest.fail(f"Unexpected status: {status}") - current_attempt += 1 +@patch("matchbox.server.base.BackendManager.get_backend") +def test_query_404_resolution(get_backend: Mock): + # Mock backend + mock_backend = Mock() + mock_backend.query = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend - assert current_attempt < max_attempts, ( - "Timed out waiting for processing to complete" + # Hit endpoint + response = client.get( + "/query", + params={ + "full_name": "foo", + "warehouse_hash_b64": hash_to_base64(b"bar"), + }, ) - assert status == "complete" - - # Verify backend.index was called with correct arguments - mock_backend.index.assert_called_once() - call_args = mock_backend.index.call_args - assert call_args[1]["source"] == dummy_source.source # Check source matches - 
assert call_args[1]["data_hashes"].equals(dummy_source.data_hashes) # Check data - -# Model endpoints + # Check response + assert response.status_code == 404 @patch("matchbox.server.base.BackendManager.get_backend") -def test_insert_model(get_backend: Mock): +def test_query_404_source(get_backend: Mock): + # Mock backend mock_backend = Mock() + mock_backend.query = Mock(side_effect=MatchboxSourceNotFoundError()) get_backend.return_value = mock_backend - dummy = model_factory(name="test_model") - response = client.post("/models", json=dummy.model.metadata.model_dump()) + # Hit endpoint + response = client.get( + "/query", + params={ + "full_name": "foo", + "warehouse_hash_b64": hash_to_base64(b"bar"), + }, + ) - assert response.status_code == 200 - assert response.json() == { - "success": True, - "model_name": "test_model", - "operation": ModelOperationType.INSERT.value, - "details": None, - } - mock_backend.insert_model.assert_called_once_with(dummy.model.metadata) + # Check response + assert response.status_code == 404 + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_match(get_backend: Mock): + # Mock backend + mock_matches = [ + Match( + cluster=1, + source=SourceAddress(full_name="foo", warehouse_hash=b"foo"), + source_id={"1"}, + target=SourceAddress(full_name="bar", warehouse_hash=b"bar"), + target_id={"a"}, + ) + ] + mock_backend = Mock() + mock_backend.match = Mock(return_value=mock_matches) + get_backend.return_value = mock_backend + + # Hit endpoint + response = client.get( + "/match", + params={ + "target_full_names": ["foo"], + "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], + "source_full_name": "bar", + "source_warehouse_hash_b64": hash_to_base64(b"bar"), + "source_pk": 1, + "resolution_name": "res", + "threshold": 50, + }, + ) + + # Check response + assert response.status_code == 200 + [Match.model_validate(m) for m in response.json()] + + +@patch("matchbox.server.base.BackendManager.get_backend") +def 
test_match_404_resolution(get_backend: Mock): + # Mock backend + mock_backend = Mock() + mock_backend.match = Mock(side_effect=MatchboxResolutionNotFoundError()) + get_backend.return_value = mock_backend + + # Hit endpoint + response = client.get( + "/match", + params={ + "target_full_names": ["foo"], + "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], + "source_full_name": "bar", + "source_warehouse_hash_b64": hash_to_base64(b"bar"), + "source_pk": 1, + "resolution_name": "res", + }, + ) + + # Check response + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.RESOLUTION + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_match_404_source(get_backend: Mock): + # Mock backend + mock_backend = Mock() + mock_backend.match = Mock(side_effect=MatchboxSourceNotFoundError()) + get_backend.return_value = mock_backend + + # Hit endpoint + response = client.get( + "/match", + params={ + "target_full_names": ["foo"], + "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], + "source_full_name": "bar", + "source_warehouse_hash_b64": hash_to_base64(b"bar"), + "source_pk": 1, + "resolution_name": "res", + }, + ) + + # Check response + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.SOURCE + + +# Data management + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_source(get_backend): + dummy_source = Source( + address=SourceAddress(full_name="foo", warehouse_hash=b"bar"), db_pk="pk" + ) + mock_backend = Mock() + mock_backend.get_source = Mock(return_value=dummy_source) + get_backend.return_value = mock_backend + + response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") + assert response.status_code == 200 + assert Source.model_validate(response.json()) + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_source_404(get_backend): + mock_backend = Mock() + mock_backend.get_source = 
Mock(side_effect=MatchboxSourceNotFoundError) + get_backend.return_value = mock_backend + + response = client.get(f"/sources/{hash_to_base64(b'bar')}/foo") + assert response.status_code == 404 + assert response.json()["entity"] == BackendRetrievableType.SOURCE + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_add_source(get_backend: Mock): + """Test the source addition endpoint.""" + # Setup + mock_backend = Mock() + mock_backend.index = Mock(return_value=None) + get_backend.return_value = mock_backend + + dummy_source = source_factory() + + # Make request + response = client.post("/sources", json=dummy_source.source.model_dump()) + + # Validate response + assert UploadStatus.model_validate(response.json()) + assert response.status_code == 202, response.json() + assert response.json()["status"] == "awaiting_upload" + assert response.json().get("id") is not None + mock_backend.index.assert_not_called() + + +@pytest.mark.asyncio +@patch("matchbox.server.base.BackendManager.get_backend") +async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): + """Test the complete upload process from source creation through processing.""" + # Setup the backend + mock_backend = Mock() + mock_backend.settings.datastore.get_client.return_value = s3 + mock_backend.settings.datastore.cache_bucket_name = "test-bucket" + mock_backend.index = Mock(return_value=None) + get_backend.return_value = mock_backend + + # Create test bucket + s3.create_bucket( + Bucket="test-bucket", + CreateBucketConfiguration={"LocationConstraint": "eu-west-2"}, + ) + + # Create test data + dummy_source = source_factory() + + # Step 1: Add source + response = client.post("/sources", json=dummy_source.source.model_dump()) + assert response.status_code == 202 + upload_id = response.json()["id"] + assert response.json()["status"] == "awaiting_upload" + + # Step 2: Upload file with real background tasks + response = client.post( + f"/upload/{upload_id}", + files={ + 
"file": ( + "hashes.parquet", + table_to_buffer(dummy_source.data_hashes), + "application/octet-stream", + ), + }, + ) + assert response.status_code == 202 + assert response.json()["status"] == "queued" + + # Step 3: Poll status until complete or timeout + max_attempts = 10 + current_attempt = 0 + while current_attempt < max_attempts: + response = client.get(f"/upload/{upload_id}/status") + assert response.status_code == 200 or response.status_code == 202 + + status = response.json()["status"] + if status == "complete": + break + elif status == "failed": + pytest.fail(f"Upload failed: {response.json().get('details')}") + elif status in ["processing", "queued"]: + await asyncio.sleep(0.1) # Small delay between polls + else: + pytest.fail(f"Unexpected status: {status}") + + current_attempt += 1 + + assert current_attempt < max_attempts, ( + "Timed out waiting for processing to complete" + ) + assert status == "complete" + assert response.status_code == 200 + + # Verify backend.index was called with correct arguments + mock_backend.index.assert_called_once() + call_args = mock_backend.index.call_args + assert call_args[1]["source"] == dummy_source.source # Check source matches + assert call_args[1]["data_hashes"].equals(dummy_source.data_hashes) # Check data + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_get_resolution_graph( + get_backend: MatchboxDBAdapter, resolution_graph: ResolutionGraph +): + """Test the resolution graph report endpoint.""" + mock_backend = Mock() + mock_backend.get_resolution_graph = Mock(return_value=resolution_graph) + get_backend.return_value = mock_backend + + response = client.get("/report/resolutions") + assert response.status_code == 200 + assert ResolutionGraph.model_validate(response.json()) + + +# Model management + + +@patch("matchbox.server.base.BackendManager.get_backend") +def test_insert_model(get_backend: Mock): + mock_backend = Mock() + get_backend.return_value = mock_backend + + dummy = 
model_factory(name="test_model") + response = client.post("/models", json=dummy.model.metadata.model_dump()) + + assert response.status_code == 201 + assert response.json() == { + "success": True, + "model_name": "test_model", + "operation": ModelOperationType.INSERT.value, + "details": None, + } + mock_backend.insert_model.assert_called_once_with(dummy.model.metadata) @patch("matchbox.server.base.BackendManager.get_backend") @@ -487,7 +657,7 @@ def test_model_upload( ) # Validate response - assert response.status_code == 200 + assert response.status_code == 202 assert response.json()["status"] == "queued" mock_add_task.assert_called_once() @@ -521,13 +691,13 @@ async def test_complete_model_upload_process( # Step 1: Create model response = client.post("/models", json=dummy.model.metadata.model_dump()) - assert response.status_code == 200 + assert response.status_code == 201 assert response.json()["success"] is True assert response.json()["model_name"] == dummy.model.metadata.name # Step 2: Initialize results upload response = client.post(f"/models/{dummy.model.metadata.name}/results") - assert response.status_code == 200 + assert response.status_code == 202 upload_id = response.json()["id"] assert response.json()["status"] == "awaiting_upload" @@ -542,7 +712,7 @@ async def test_complete_model_upload_process( ), }, ) - assert response.status_code == 200 + assert response.status_code == 202 assert response.json()["status"] == "queued" # Step 4: Poll status until complete or timeout @@ -552,7 +722,7 @@ async def test_complete_model_upload_process( while current_attempt < max_attempts: response = client.get(f"/upload/{upload_id}/status") - assert response.status_code == 200 + assert response.status_code == 200 or response.status_code == 202 status = response.json()["status"] if status == "complete": @@ -570,6 +740,7 @@ async def test_complete_model_upload_process( "Timed out waiting for processing to complete" ) assert status == "complete" + assert 
response.status_code == 200 # Step 5: Verify results were stored correctly mock_backend.set_model_results.assert_called_once() @@ -609,65 +780,16 @@ async def test_complete_model_upload_process( @patch("matchbox.server.base.BackendManager.get_backend") -def test_delete_model(get_backend: Mock): +def test_set_results(get_backend: Mock): mock_backend = Mock() + dummy = model_factory() + mock_backend.get_model = Mock(return_value=dummy.model.metadata) get_backend.return_value = mock_backend - dummy = model_factory() - response = client.delete( - f"/models/{dummy.model.metadata.name}", params={"certain": True} - ) + response = client.post(f"/models/{dummy.model.metadata.name}/results") - assert response.status_code == 200 - assert response.json() == { - "success": True, - "model_name": dummy.model.metadata.name, - "operation": ModelOperationType.DELETE, - "details": None, - } - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_delete_model_needs_confirmation(get_backend: Mock): - mock_backend = Mock() - mock_backend.delete_model = Mock( - side_effect=MatchboxDeletionNotConfirmed(children=["dedupe1", "dedupe2"]) - ) - get_backend.return_value = mock_backend - - dummy = model_factory() - response = client.delete(f"/models/{dummy.model.metadata.name}") - - assert response.status_code == 409 - assert response.json()["success"] is False - message = response.json()["details"] - assert "dedupe1" in message and "dedupe2" in message - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_get_results(get_backend: Mock): - mock_backend = Mock() - dummy = model_factory() - mock_backend.get_model_results = Mock(return_value=dummy.data) - get_backend.return_value = mock_backend - - response = client.get(f"/models/{dummy.model.metadata.name}/results") - - assert response.status_code == 200 - assert response.headers["content-type"] == "application/octet-stream" - - -@patch("matchbox.server.base.BackendManager.get_backend") -def 
test_set_results(get_backend: Mock): - mock_backend = Mock() - dummy = model_factory() - mock_backend.get_model = Mock(return_value=dummy.model.metadata) - get_backend.return_value = mock_backend - - response = client.post(f"/models/{dummy.model.metadata.name}/results") - - assert response.status_code == 200 - assert response.json()["status"] == "awaiting_upload" + assert response.status_code == 202 + assert response.json()["status"] == "awaiting_upload" @patch("matchbox.server.base.BackendManager.get_backend") @@ -684,16 +806,16 @@ def test_set_results_model_not_found(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") -def test_get_truth(get_backend: Mock): +def test_get_results(get_backend: Mock): mock_backend = Mock() dummy = model_factory() - mock_backend.get_model_truth = Mock(return_value=0.95) + mock_backend.get_model_results = Mock(return_value=dummy.data) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy.model.metadata.name}/truth") + response = client.get(f"/models/{dummy.model.metadata.name}/results") assert response.status_code == 200 - assert response.json() == 0.95 + assert response.headers["content-type"] == "application/octet-stream" @patch("matchbox.server.base.BackendManager.get_backend") @@ -728,46 +850,34 @@ def test_set_truth_invalid_value(get_backend: Mock): @patch("matchbox.server.base.BackendManager.get_backend") -def test_get_ancestors(get_backend: Mock): +def test_get_truth(get_backend: Mock): mock_backend = Mock() dummy = model_factory() - mock_ancestors = [ - ModelAncestor(name="parent_model", truth=0.7), - ModelAncestor(name="grandparent_model", truth=0.97), - ] - mock_backend.get_model_ancestors = Mock(return_value=mock_ancestors) + mock_backend.get_model_truth = Mock(return_value=0.95) get_backend.return_value = mock_backend - response = client.get(f"/models/{dummy.model.metadata.name}/ancestors") + response = client.get(f"/models/{dummy.model.metadata.name}/truth") assert 
response.status_code == 200 - assert len(response.json()) == 2 - assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors + assert response.json() == 0.95 @patch("matchbox.server.base.BackendManager.get_backend") -def test_set_ancestors_cache(get_backend: Mock): - """Test setting the ancestors cache for a model.""" +def test_get_ancestors(get_backend: Mock): mock_backend = Mock() dummy = model_factory() - get_backend.return_value = mock_backend - - ancestors_data = [ + mock_ancestors = [ ModelAncestor(name="parent_model", truth=0.7), - ModelAncestor(name="grandparent_model", truth=0.8), + ModelAncestor(name="grandparent_model", truth=0.97), ] + mock_backend.get_model_ancestors = Mock(return_value=mock_ancestors) + get_backend.return_value = mock_backend - response = client.patch( - f"/models/{dummy.model.metadata.name}/ancestors_cache", - json=[a.model_dump() for a in ancestors_data], - ) + response = client.get(f"/models/{dummy.model.metadata.name}/ancestors") assert response.status_code == 200 - assert response.json()["success"] is True - assert response.json()["operation"] == ModelOperationType.UPDATE_ANCESTOR_CACHE - mock_backend.set_model_ancestors_cache.assert_called_once_with( - model=dummy.model.metadata.name, ancestors_cache=ancestors_data - ) + assert len(response.json()) == 2 + assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors @patch("matchbox.server.base.BackendManager.get_backend") @@ -789,176 +899,62 @@ def test_get_ancestors_cache(get_backend: Mock): assert [ModelAncestor.model_validate(a) for a in response.json()] == mock_ancestors -# Query and match endpoints - - @patch("matchbox.server.base.BackendManager.get_backend") -def test_query(get_backend: Mock): - # Mock backend +def test_set_ancestors_cache(get_backend: Mock): + """Test setting the ancestors cache for a model.""" mock_backend = Mock() - mock_backend.query = Mock( - return_value=pa.Table.from_pylist( - [ - {"source_pk": "a", 
"id": 1}, - {"source_pk": "b", "id": 2}, - ], - schema=SCHEMA_MB_IDS, - ) - ) + dummy = model_factory() get_backend.return_value = mock_backend - # Hit endpoint - response = client.get( - "/query", - params={ - "full_name": "foo", - "warehouse_hash_b64": hash_to_base64(b"bar"), - }, - ) - - # Process response - buffer = BytesIO(response.content) - table = pq.read_table(buffer) - - # Check response - assert response.status_code == 200 - assert table.schema.equals(SCHEMA_MB_IDS) - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_query_404_resolution(get_backend: Mock): - # Mock backend - mock_backend = Mock() - mock_backend.query = Mock(side_effect=MatchboxResolutionNotFoundError()) - get_backend.return_value = mock_backend + ancestors_data = [ + ModelAncestor(name="parent_model", truth=0.7), + ModelAncestor(name="grandparent_model", truth=0.8), + ] - # Hit endpoint - response = client.get( - "/query", - params={ - "full_name": "foo", - "warehouse_hash_b64": hash_to_base64(b"bar"), - }, + response = client.patch( + f"/models/{dummy.model.metadata.name}/ancestors_cache", + json=[a.model_dump() for a in ancestors_data], ) - # Check response - assert response.status_code == 404 - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_query_404_source(get_backend: Mock): - # Mock backend - mock_backend = Mock() - mock_backend.query = Mock(side_effect=MatchboxSourceNotFoundError()) - get_backend.return_value = mock_backend - - # Hit endpoint - response = client.get( - "/query", - params={ - "full_name": "foo", - "warehouse_hash_b64": hash_to_base64(b"bar"), - }, + assert response.status_code == 200 + assert response.json()["success"] is True + assert response.json()["operation"] == ModelOperationType.UPDATE_ANCESTOR_CACHE + mock_backend.set_model_ancestors_cache.assert_called_once_with( + model=dummy.model.metadata.name, ancestors_cache=ancestors_data ) - # Check response - assert response.status_code == 404 - 
@patch("matchbox.server.base.BackendManager.get_backend") -def test_match(get_backend: Mock): - # Mock backend - mock_matches = [ - Match( - cluster=1, - source=SourceAddress(full_name="foo", warehouse_hash=b"foo"), - source_id={"1"}, - target=SourceAddress(full_name="bar", warehouse_hash=b"bar"), - target_id={"a"}, - ) - ] +def test_delete_model(get_backend: Mock): mock_backend = Mock() - mock_backend.match = Mock(return_value=mock_matches) get_backend.return_value = mock_backend - # Hit endpoint - response = client.get( - "/match", - params={ - "target_full_names": ["foo"], - "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], - "source_full_name": "bar", - "source_warehouse_hash_b64": hash_to_base64(b"bar"), - "source_pk": 1, - "resolution_name": "res", - "threshold": 50, - }, + dummy = model_factory() + response = client.delete( + f"/models/{dummy.model.metadata.name}", params={"certain": True} ) - # Check response assert response.status_code == 200 - [Match.model_validate(m) for m in response.json()] + assert response.json() == { + "success": True, + "model_name": dummy.model.metadata.name, + "operation": ModelOperationType.DELETE, + "details": None, + } @patch("matchbox.server.base.BackendManager.get_backend") -def test_match_404_resolution(get_backend: Mock): - # Mock backend +def test_delete_model_needs_confirmation(get_backend: Mock): mock_backend = Mock() - mock_backend.match = Mock(side_effect=MatchboxResolutionNotFoundError()) - get_backend.return_value = mock_backend - - # Hit endpoint - response = client.get( - "/match", - params={ - "target_full_names": ["foo"], - "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], - "source_full_name": "bar", - "source_warehouse_hash_b64": hash_to_base64(b"bar"), - "source_pk": 1, - "resolution_name": "res", - }, + mock_backend.delete_model = Mock( + side_effect=MatchboxDeletionNotConfirmed(children=["dedupe1", "dedupe2"]) ) - - # Check response - assert response.status_code == 404 - assert 
response.json()["entity"] == BackendRetrievableType.RESOLUTION - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_match_404_source(get_backend: Mock): - # Mock backend - mock_backend = Mock() - mock_backend.match = Mock(side_effect=MatchboxSourceNotFoundError()) get_backend.return_value = mock_backend - # Hit endpoint - response = client.get( - "/match", - params={ - "target_full_names": ["foo"], - "target_warehouse_hashes_b64": [hash_to_base64(b"foo")], - "source_full_name": "bar", - "source_warehouse_hash_b64": hash_to_base64(b"bar"), - "source_pk": 1, - "resolution_name": "res", - }, - ) - - # Check response - assert response.status_code == 404 - assert response.json()["entity"] == BackendRetrievableType.SOURCE - - -@patch("matchbox.server.base.BackendManager.get_backend") -def test_get_resolution_graph( - get_backend: MatchboxDBAdapter, resolution_graph: ResolutionGraph -): - """Test the resolution graph report endpoint.""" - mock_backend = Mock() - mock_backend.get_resolution_graph = Mock(return_value=resolution_graph) - get_backend.return_value = mock_backend + dummy = model_factory() + response = client.delete(f"/models/{dummy.model.metadata.name}") - response = client.get("/report/resolutions") - assert response.status_code == 200 - assert ResolutionGraph.model_validate(response.json()) + assert response.status_code == 409 + assert response.json()["success"] is False + message = response.json()["details"] + assert "dedupe1" in message and "dedupe2" in message From d7b44b0186732c88f4b6532aa3e6add1d44a9f42 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 13:13:00 +0000 Subject: [PATCH 17/19] Fixed broken status code handling in the _handler --- src/matchbox/client/_handler.py | 2 +- test/client/test_helpers.py | 6 +-- test/client/test_model.py | 10 ++-- test/server/api/test_routes.py | 82 +++++++++++++++++++++++++++++++++ 4 files changed, 91 insertions(+), 9 deletions(-) diff --git a/src/matchbox/client/_handler.py 
b/src/matchbox/client/_handler.py index 17394e55..57f76d8a 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -60,7 +60,7 @@ def handle_http_code(res: httpx.Response) -> httpx.Response: """Handle HTTP status codes and raise appropriate exceptions.""" res.read() - if res.status_code == 200: + if 299 >= res.status_code >= 200: return res if res.status_code == 400: diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index 5ac67475..6bfe0005 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -446,7 +446,7 @@ def test_index_success(MockSource: Mock, matchbox_api: MockRouter): # Mock the initial source metadata upload source_route = matchbox_api.post("/sources").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="awaiting_upload", @@ -519,7 +519,7 @@ def test_index_with_columns( # Mock the API endpoints source_route = matchbox_api.post("/sources").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="awaiting_upload", @@ -576,7 +576,7 @@ def test_index_upload_failure(MockSource: Mock, matchbox_api: MockRouter): # Mock successful source creation source_route = matchbox_api.post("/sources").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="awaiting_upload", diff --git a/test/client/test_model.py b/test/client/test_model.py index 8b8ece0c..1a8e0505 100644 --- a/test/client/test_model.py +++ b/test/client/test_model.py @@ -33,7 +33,7 @@ def test_insert_model(matchbox_api: MockRouter): # Mock the POST /models endpoint route = matchbox_api.post("/models").mock( return_value=Response( - 200, + 201, json=ModelOperationStatus( success=True, model_name=dummy.model.metadata.name, @@ -83,7 +83,7 @@ def test_results_getter(matchbox_api: MockRouter): # Mock the GET /models/{name}/results endpoint route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock( - 
return_value=Response(200, content=table_to_buffer(dummy.data).read()) + return_value=Response(202, content=table_to_buffer(dummy.data).read()) ) # Get results @@ -123,7 +123,7 @@ def test_results_setter(matchbox_api: MockRouter): # Mock the endpoints needed for results upload init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="awaiting_upload", @@ -134,7 +134,7 @@ def test_results_setter(matchbox_api: MockRouter): upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="processing", @@ -174,7 +174,7 @@ def test_results_setter_upload_failure(matchbox_api: MockRouter): # Mock the initial POST endpoint init_route = matchbox_api.post(f"/models/{dummy.model.metadata.name}/results").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="awaiting_upload", diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index de14443a..83a5598b 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -13,6 +13,7 @@ BackendRetrievableType, ModelAncestor, ModelOperationType, + NotFoundError, UploadStatus, ) from matchbox.common.exceptions import ( @@ -924,6 +925,66 @@ def test_set_ancestors_cache(get_backend: Mock): ) +@pytest.mark.parametrize( + "endpoint", + ["results", "truth", "ancestors", "ancestors_cache"], +) +@patch("matchbox.server.base.BackendManager.get_backend") +def test_model_get_endpoints_404( + get_backend: Mock, + endpoint: str, +) -> None: + """Test 404 responses for model GET endpoints when model doesn't exist.""" + # Setup backend mock + mock_backend = Mock() + mock_method = getattr(mock_backend, f"get_model_{endpoint}") + mock_method.side_effect = MatchboxResolutionNotFoundError() + get_backend.return_value = mock_backend + + # Make request + response = 
client.get(f"/models/nonexistent-model/{endpoint}") + + # Verify response + assert response.status_code == 404 + error = NotFoundError.model_validate(response.json()) + assert error.entity == BackendRetrievableType.RESOLUTION + + +@pytest.mark.parametrize( + ("endpoint", "payload"), + [ + ("truth", 0.95), + ( + "ancestors_cache", + [ + ModelAncestor(name="parent_model", truth=0.7).model_dump(), + ModelAncestor(name="grandparent_model", truth=0.8).model_dump(), + ], + ), + ], +) +@patch("matchbox.server.base.BackendManager.get_backend") +def test_model_patch_endpoints_404( + get_backend: Mock, + endpoint: str, + payload: float | list[dict[str, Any]], +) -> None: + """Test 404 responses for model PATCH endpoints when model doesn't exist.""" + # Setup backend mock + mock_backend = Mock() + mock_method = getattr(mock_backend, f"set_model_{endpoint}") + mock_method.side_effect = MatchboxResolutionNotFoundError() + get_backend.return_value = mock_backend + + # Make request + response = client.patch(f"/models/nonexistent-model/{endpoint}", json=payload) + + # Verify response + assert response.status_code == 404 + error = NotFoundError.model_validate(response.json()) + assert error.entity == BackendRetrievableType.RESOLUTION + + @patch("matchbox.server.base.BackendManager.get_backend") def test_delete_model(get_backend: Mock): mock_backend = Mock() @@ -958,3 +1019,24 @@ def test_delete_model_needs_confirmation(get_backend: Mock): assert response.json()["success"] is False message = response.json()["details"] assert "dedupe1" in message and "dedupe2" in message + + +@pytest.mark.parametrize( + "certain", + [True, False], +) +@patch("matchbox.server.base.BackendManager.get_backend") +def test_delete_model_404(get_backend: Mock, certain: bool) -> None: + """Test 404 response when trying to delete a non-existent model.""" + # Setup backend mock + mock_backend = Mock() + mock_backend.delete_model.side_effect = MatchboxResolutionNotFoundError() + get_backend.return_value = 
mock_backend + + # Make request + response = client.delete("/models/nonexistent-model", params={"certain": certain}) + + # Verify response + assert response.status_code == 404 + error = NotFoundError.model_validate(response.json()) + assert error.entity == BackendRetrievableType.RESOLUTION From 0610a4dc450028990429ae302ad927432c9cc929 Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 14:44:34 +0000 Subject: [PATCH 18/19] Updated _handler.delete_model to match other deletion function signatures --- src/matchbox/client/_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/matchbox/client/_handler.py b/src/matchbox/client/_handler.py index 57f76d8a..e5e2ddfb 100644 --- a/src/matchbox/client/_handler.py +++ b/src/matchbox/client/_handler.py @@ -303,7 +303,7 @@ def get_model_ancestors_cache(name: str) -> list[ModelAncestor]: return [ModelAncestor.model_validate(m) for m in res.json()] -def delete_model(name: str, certain: bool | None = False) -> ModelOperationStatus: +def delete_model(name: str, certain: bool = False) -> ModelOperationStatus: """Delete a model in Matchbox.""" res = CLIENT.delete(f"/models/{name}", params={"certain": certain}) return ModelOperationStatus.model_validate(res.json()) From 7b1f91fe2517e73116247d012a754828bf6874be Mon Sep 17 00:00:00 2001 From: Will Langdale Date: Fri, 14 Feb 2025 16:21:29 +0000 Subject: [PATCH 19/19] Tidied up response codes: 202 for the upload of metadata and data, 200 for status updates --- src/matchbox/common/dtos.py | 8 +++++--- src/matchbox/server/api/routes.py | 25 +++++++++++++++++-------- test/client/test_helpers.py | 4 ++-- test/client/test_model.py | 2 +- test/server/api/test_routes.py | 6 +++--- 5 files changed, 28 insertions(+), 17 deletions(-) diff --git a/src/matchbox/common/dtos.py b/src/matchbox/common/dtos.py index 2a6ffe62..b4afb422 100644 --- a/src/matchbox/common/dtos.py +++ b/src/matchbox/common/dtos.py @@ -160,11 +160,13 @@ class UploadStatus(BaseModel): 
"complete": 200, "failed": 400, "awaiting_upload": 202, - "queued": 202, - "processing": 202, + "queued": 200, + "processing": 200, } - def get_http_code(self) -> int: + def get_http_code(self, status: bool) -> int: + if self.status == "failed": + return 400 return self._status_code_mapping[self.status] @classmethod diff --git a/src/matchbox/server/api/routes.py b/src/matchbox/server/api/routes.py index 3c622ad8..3f7b412d 100644 --- a/src/matchbox/server/api/routes.py +++ b/src/matchbox/server/api/routes.py @@ -107,7 +107,9 @@ def get_count(e: BackendCountableType) -> int: @app.post( "/upload/{upload_id}", - responses={400: {"model": UploadStatus, **UploadStatus.status_400_examples()}}, + responses={ + 400: {"model": UploadStatus, **UploadStatus.status_400_examples()}, + }, status_code=status.HTTP_202_ACCEPTED, ) async def upload_file( @@ -180,16 +182,24 @@ async def upload_file( metadata_store=metadata_store, ) - return metadata_store.get(upload_id).status + source_cache = metadata_store.get(upload_id) + + # Check for error in async task + if source_cache.status.status == "failed": + raise HTTPException( + status_code=400, + detail=source_cache.status.model_dump(), + ) + else: + return source_cache.status @app.get( "/upload/{upload_id}/status", responses={ - 200: {"model": UploadStatus}, - 202: {"model": UploadStatus}, 400: {"model": UploadStatus, **UploadStatus.status_400_examples()}, }, + status_code=status.HTTP_200_OK, ) async def get_upload_status( upload_id: str, @@ -216,10 +226,7 @@ async def get_upload_status( ).model_dump(), ) - return JSONResponse( - status_code=source_cache.status.get_http_code(), - content=source_cache.status.model_dump(), - ) + return source_cache.status # Retrieval @@ -237,6 +244,7 @@ async def query( threshold: int | None = None, limit: int | None = None, ) -> ParquetResponse: + """Query Matchbox for matches based on a source address.""" source_address = SourceAddress( full_name=full_name, warehouse_hash=warehouse_hash_b64 ) @@ 
-280,6 +288,7 @@ async def match( resolution_name: str, threshold: int | None = None, ) -> list[Match]: + """Match a source primary key against a list of target addresses.""" targets = [ SourceAddress(full_name=n, warehouse_hash=w) for n, w in zip(target_full_names, target_warehouse_hashes_b64, strict=True) diff --git a/test/client/test_helpers.py b/test/client/test_helpers.py index 6bfe0005..194e4b35 100644 --- a/test/client/test_helpers.py +++ b/test/client/test_helpers.py @@ -458,7 +458,7 @@ def test_index_success(MockSource: Mock, matchbox_api: MockRouter): # Mock the data upload upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="complete", entity=BackendUploadType.INDEX ).model_dump(), @@ -530,7 +530,7 @@ def test_index_with_columns( upload_route = matchbox_api.post("/upload/test-upload-id").mock( return_value=Response( - 200, + 202, json=UploadStatus( id="test-upload-id", status="complete", entity=BackendUploadType.INDEX ).model_dump(), diff --git a/test/client/test_model.py b/test/client/test_model.py index 1a8e0505..d192d106 100644 --- a/test/client/test_model.py +++ b/test/client/test_model.py @@ -83,7 +83,7 @@ def test_results_getter(matchbox_api: MockRouter): # Mock the GET /models/{name}/results endpoint route = matchbox_api.get(f"/models/{dummy.model.metadata.name}/results").mock( - return_value=Response(202, content=table_to_buffer(dummy.data).read()) + return_value=Response(200, content=table_to_buffer(dummy.data).read()) ) # Get results diff --git a/test/server/api/test_routes.py b/test/server/api/test_routes.py index 83a5598b..abf31bf4 100644 --- a/test/server/api/test_routes.py +++ b/test/server/api/test_routes.py @@ -193,7 +193,7 @@ def test_upload_status_check(metadata_store: Mock, _: Mock): response = client.get(f"/upload/{update_id}/status") # Should return current status - assert response.status_code == 202 + assert response.status_code == 200 
assert response.json()["status"] == "processing" metadata_store.update_status.assert_not_called() @@ -511,7 +511,7 @@ async def test_complete_source_upload_process(get_backend: Mock, s3: S3Client): current_attempt = 0 while current_attempt < max_attempts: response = client.get(f"/upload/{upload_id}/status") - assert response.status_code == 200 or response.status_code == 202 + assert response.status_code == 200 status = response.json()["status"] if status == "complete": @@ -723,7 +723,7 @@ async def test_complete_model_upload_process( while current_attempt < max_attempts: response = client.get(f"/upload/{upload_id}/status") - assert response.status_code == 200 or response.status_code == 202 + assert response.status_code == 200 status = response.json()["status"] if status == "complete":