From 7a7494643f7e6553266f5861520062200e3f8b58 Mon Sep 17 00:00:00 2001 From: Stephan Tulkens Date: Fri, 6 Dec 2024 14:14:08 +0100 Subject: [PATCH] feat: clean-up annoy metrics (#42) --- vicinity/backends/annoy.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/vicinity/backends/annoy.py b/vicinity/backends/annoy.py index 3961cc0..053a6db 100644 --- a/vicinity/backends/annoy.py +++ b/vicinity/backends/annoy.py @@ -17,17 +17,17 @@ class AnnoyArgs(BaseArgs): dim: int = 0 metric: str = "cosine" + internal_metric: str = "dot" trees: int = 100 length: int | None = None class AnnoyBackend(AbstractBackend[AnnoyArgs]): argument_class = AnnoyArgs - supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN, Metric.INNER_PRODUCT} + supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN} inverse_metric_mapping = { Metric.COSINE: "dot", Metric.EUCLIDEAN: "euclidean", - Metric.INNER_PRODUCT: "dot", } def __init__( @@ -56,18 +56,21 @@ def from_vectors( if metric_enum not in cls.supported_metrics: raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.") - metric = cls._map_metric_to_string(metric_enum) + metric_string = metric_enum.value + internal_metric = cls._map_metric_to_string(metric_enum) - if metric == "dot": + if metric_enum == Metric.COSINE: vectors = normalize(vectors) dim = vectors.shape[1] - index = AnnoyIndex(f=dim, metric=metric) # type: ignore + index = AnnoyIndex(f=dim, metric=internal_metric) # type: ignore for i, vector in enumerate(vectors): index.add_item(i, vector) index.build(trees) - arguments = AnnoyArgs(dim=dim, metric=metric, trees=trees, length=len(vectors)) # type: ignore + arguments = AnnoyArgs( + dim=dim, metric=metric_string, trees=trees, length=len(vectors), internal_metric=internal_metric + ) # type: ignore return AnnoyBackend(index, arguments=arguments) @property @@ -89,7 +92,7 @@ def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend: """Load the vectors from a path.""" path = Path(base_path) / "index.bin" arguments = AnnoyArgs.load(base_path / "arguments.json") - index = AnnoyIndex(arguments.dim, arguments.metric) # type: ignore + index = AnnoyIndex(arguments.dim, arguments.internal_metric) # type: ignore index.load(str(path)) return cls(index, arguments=arguments) @@ -106,11 +109,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult: """Query the backend.""" out = [] for vec in vectors: - if self.arguments.metric == "dot": + if self.arguments.metric == "cosine": vec = normalize(vec) indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True) scores_array = np.asarray(scores) - if self.arguments.metric == "dot": + if self.arguments.metric == "cosine": # Convert cosine similarity to cosine distance scores_array = 1 - scores_array out.append((np.asarray(indices), scores_array))