diff --git a/vicinity/backends/annoy.py b/vicinity/backends/annoy.py index 053a6db..ce2184b 100644 --- a/vicinity/backends/annoy.py +++ b/vicinity/backends/annoy.py @@ -16,7 +16,7 @@ @dataclass class AnnoyArgs(BaseArgs): dim: int = 0 - metric: str = "cosine" + metric: Metric = Metric.COSINE internal_metric: str = "dot" trees: int = 100 length: int | None = None @@ -25,7 +25,7 @@ class AnnoyArgs(BaseArgs): class AnnoyBackend(AbstractBackend[AnnoyArgs]): argument_class = AnnoyArgs supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} - inverse_metric_mapping = { + inverse_metric_mapping: dict[Metric, str] = { Metric.COSINE: "dot", Metric.EUCLIDEAN: "euclidean", } @@ -56,7 +56,6 @@ def from_vectors( if metric_enum not in cls.supported_metrics: raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.") - metric_string = metric_enum.value internal_metric = cls._map_metric_to_string(metric_enum) if metric_enum == Metric.COSINE: @@ -68,9 +67,7 @@ def from_vectors( index.add_item(i, vector) index.build(trees) - arguments = AnnoyArgs( - dim=dim, metric=metric_string, trees=trees, length=len(vectors), internal_metric=internal_metric - ) # type: ignore + arguments = AnnoyArgs(dim=dim, metric=metric_enum, trees=trees, length=len(vectors), internal_metric=internal_metric) # type: ignore return AnnoyBackend(index, arguments=arguments) @property @@ -91,8 +88,10 @@ def __len__(self) -> int: def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend: """Load the vectors from a path.""" path = Path(base_path) / "index.bin" + arguments = AnnoyArgs.load(base_path / "arguments.json") - index = AnnoyIndex(arguments.dim, arguments.internal_metric) # type: ignore + metric = cls._map_metric_to_string(arguments.metric) + index = AnnoyIndex(arguments.dim, metric) # type: ignore index.load(str(path)) return cls(index, arguments=arguments) @@ -109,11 +108,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend.""" out = [] for vec in 
vectors: - if self.arguments.metric == "cosine": + if self.arguments.metric == Metric.COSINE: vec = normalize(vec) indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True) scores_array = np.asarray(scores) - if self.arguments.metric == "cosine": + if self.arguments.metric == Metric.COSINE: # Convert cosine similarity to cosine distance scores_array = 1 - scores_array out.append((np.asarray(indices), scores_array)) diff --git a/vicinity/backends/base.py b/vicinity/backends/base.py index 4876047..788efb6 100644 --- a/vicinity/backends/base.py +++ b/vicinity/backends/base.py @@ -14,19 +14,25 @@ @dataclass class BaseArgs: + metric: Metric + def dump(self, file: Path) -> None: """Dump the arguments to a file.""" with open(file, "w") as f: - json.dump(asdict(self), f) + d = self.dict() + d["metric"] = d["metric"].value + json.dump(d, f) @classmethod def load(cls: type[ArgType], file: Path) -> ArgType: """Load the arguments from a file.""" with open(file, "r") as f: - return cls(**json.load(f)) + data = json.load(f) + data["metric"] = Metric.from_string(data["metric"]) + return cls(**data) def dict(self) -> dict[str, Any]: - """Dump the arguments to a string.""" + """Dump the arguments to a dict.""" return asdict(self) diff --git a/vicinity/backends/basic.py b/vicinity/backends/basic.py index 3ca4325..f727a63 100644 --- a/vicinity/backends/basic.py +++ b/vicinity/backends/basic.py @@ -15,7 +15,7 @@ @dataclass class BasicArgs(BaseArgs): - metric: str = "cosine" + metric: Metric = Metric.COSINE class BasicBackend(AbstractBackend[BasicArgs], ABC): @@ -72,11 +72,10 @@ def from_vectors(cls, vectors: npt.NDArray, metric: Union[str, Metric] = "cosine if metric_enum not in cls.supported_metrics: raise ValueError(f"Metric '{metric_enum.value}' is not supported by BasicBackend.") - metric = metric_enum.value - arguments = BasicArgs(metric=metric) - if metric == "cosine": + arguments = BasicArgs(metric=metric_enum) + if metric_enum == Metric.COSINE: return 
CosineBasicBackend(vectors, arguments) - elif metric == "euclidean": + elif metric_enum == Metric.EUCLIDEAN: return EuclideanBasicBackend(vectors, arguments) else: raise ValueError(f"Unsupported metric: {metric}") @@ -88,9 +87,9 @@ def load(cls, folder: Path) -> BasicBackend: arguments = BasicArgs.load(folder / "arguments.json") with open(path, "rb") as f: vectors = np.load(f) - if arguments.metric == "cosine": + if arguments.metric == Metric.COSINE: return CosineBasicBackend(vectors, arguments) - elif arguments.metric == "euclidean": + elif arguments.metric == Metric.EUCLIDEAN: return EuclideanBasicBackend(vectors, arguments) else: raise ValueError(f"Unsupported metric: {arguments.metric}") diff --git a/vicinity/backends/faiss.py b/vicinity/backends/faiss.py index 4466347..fc0ca31 100644 --- a/vicinity/backends/faiss.py +++ b/vicinity/backends/faiss.py @@ -36,7 +36,7 @@ class FaissArgs(BaseArgs): dim: int = 0 index_type: str = "flat" - metric: str = "cosine" + metric: Metric = Metric.COSINE nlist: int = 100 m: int = 8 nbits: int = 8 @@ -122,7 +122,7 @@ def from_vectors( # noqa: C901 arguments = FaissArgs( dim=dim, index_type=index_type, - metric=metric_enum.value, + metric=metric_enum, nlist=nlist, m=m, nbits=nbits, diff --git a/vicinity/backends/hnsw.py b/vicinity/backends/hnsw.py index 09b4512..e9c124b 100644 --- a/vicinity/backends/hnsw.py +++ b/vicinity/backends/hnsw.py @@ -15,7 +15,7 @@ @dataclass class HNSWArgs(BaseArgs): dim: int = 0 - metric: str = "cosine" + metric: Metric = Metric.COSINE ef_construction: int = 200 m: int = 16 @@ -58,7 +58,7 @@ def from_vectors( index = HnswIndex(space=metric, dim=dim) index.init_index(max_elements=vectors.shape[0], ef_construction=ef_construction, M=m) index.add_items(vectors) - arguments = HNSWArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m) + arguments = HNSWArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m) return HNSWBackend(index, arguments=arguments) @property @@ -80,7 +80,8 
@@ def load(cls: type[HNSWBackend], base_path: Path) -> HNSWBackend: """Load the vectors from a path.""" path = Path(base_path) / "index.bin" arguments = HNSWArgs.load(base_path / "arguments.json") - index = HnswIndex(space=arguments.metric, dim=arguments.dim) + mapped_metric = cls.inverse_metric_mapping[arguments.metric] + index = HnswIndex(space=mapped_metric, dim=arguments.dim) index.load_index(str(path)) return cls(index, arguments=arguments) diff --git a/vicinity/backends/pynndescent.py b/vicinity/backends/pynndescent.py index 964ab08..3fc1c7c 100644 --- a/vicinity/backends/pynndescent.py +++ b/vicinity/backends/pynndescent.py @@ -16,7 +16,7 @@ @dataclass class PyNNDescentArgs(BaseArgs): n_neighbors: int = 15 - metric: str = "cosine" + metric: Metric = Metric.COSINE class PyNNDescentBackend(AbstractBackend[PyNNDescentArgs]): @@ -49,7 +49,7 @@ def from_vectors( metric = metric_enum.value index = NNDescent(vectors, n_neighbors=n_neighbors, metric=metric, **kwargs) - arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric) + arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric_enum) return cls(index=index, arguments=arguments) def __len__(self) -> int: @@ -105,10 +105,7 @@ def load(cls: type[PyNNDescentBackend], base_path: Path) -> PyNNDescentBackend: arguments = PyNNDescentArgs.load(base_path / "arguments.json") vectors = np.load(Path(base_path) / "vectors.npy") - metric_enum = Metric.from_string(arguments.metric) - pynndescent_metric = metric_enum.value - - index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=pynndescent_metric) + index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=arguments.metric.value) # Load the neighbor graph if it was saved neighbor_graph_path = base_path / "neighbor_graph.npy" diff --git a/vicinity/backends/usearch.py b/vicinity/backends/usearch.py index 2671bc8..2781cdc 100644 --- a/vicinity/backends/usearch.py +++ b/vicinity/backends/usearch.py @@ -16,7 +16,7 @@ @dataclass class 
UsearchArgs(BaseArgs): dim: int = 0 - metric: str = "cos" + metric: Metric = Metric.COSINE connectivity: int = 16 expansion_add: int = 128 expansion_search: int = 64 @@ -67,10 +67,10 @@ def from_vectors( expansion_add=expansion_add, expansion_search=expansion_search, ) - index.add(keys=None, vectors=vectors) # type: ignore + index.add(keys=None, vectors=vectors) # type: ignore # None keys are allowed but not typed arguments = UsearchArgs( dim=dim, - metric=metric, + metric=metric_enum, connectivity=connectivity, expansion_add=expansion_add, expansion_search=expansion_search, @@ -99,7 +99,7 @@ def load(cls: type[UsearchBackend], base_path: Path) -> UsearchBackend: index = UsearchIndex( ndim=arguments.dim, - metric=arguments.metric, + metric=cls._map_metric_to_string(arguments.metric), connectivity=arguments.connectivity, expansion_add=arguments.expansion_add, expansion_search=arguments.expansion_search, @@ -122,7 +122,7 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: def insert(self, vectors: npt.NDArray) -> None: """Insert vectors into the backend.""" - self.index.add(None, vectors) # type: ignore + self.index.add(None, vectors) # type: ignore # None keys are allowed, but not typed. 
def delete(self, indices: list[int]) -> None: """Delete vectors from the index (not supported by Usearch).""" diff --git a/vicinity/backends/voyager.py b/vicinity/backends/voyager.py index 3c73bfa..5b2c316 100644 --- a/vicinity/backends/voyager.py +++ b/vicinity/backends/voyager.py @@ -4,19 +4,18 @@ from pathlib import Path from typing import Any, Union -import numpy as np from numpy import typing as npt from voyager import Index, Space from vicinity.backends.base import AbstractBackend, BaseArgs from vicinity.datatypes import Backend, QueryResult -from vicinity.utils import Metric, normalize +from vicinity.utils import Metric @dataclass class VoyagerArgs(BaseArgs): dim: int = 0 - metric: str = "cosine" + metric: Metric = Metric.COSINE ef_construction: int = 200 m: int = 16 @@ -24,14 +23,9 @@ class VoyagerArgs(BaseArgs): class VoyagerBackend(AbstractBackend[VoyagerArgs]): argument_class = VoyagerArgs supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} - inverse_metric_mapping = { - Metric.COSINE: "cosine", - Metric.EUCLIDEAN: "l2", - } - - metric_int_mapping = { - "l2": 0, - "cosine": 2, + _metric_to_space = { + Metric.COSINE: Space.Cosine, + Metric.EUCLIDEAN: Space.Euclidean, } def __init__( @@ -56,13 +50,10 @@ def from_vectors( metric_enum = Metric.from_string(metric) if metric_enum not in cls.supported_metrics: - raise ValueError( - f"Metric '{metric_enum.value}' is not supported by VoyagerBackend." 
- ) + raise ValueError(f"Metric '{metric_enum.value}' is not supported by VoyagerBackend.") - metric = cls._map_metric_to_string(metric_enum) + space = cls._metric_to_space[metric_enum] dim = vectors.shape[1] - space = Space(value=cls.metric_int_mapping[metric]) index = Index( space=space, num_dimensions=dim, @@ -72,7 +63,7 @@ def from_vectors( index.add_items(vectors) return cls( index, - VoyagerArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m), + VoyagerArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m), ) def query(self, query: npt.NDArray, k: int) -> QueryResult: