Skip to content

Commit

Permalink
feat: clean-up annoy metrics (#42)
Browse files Browse the repository at this point in the history
  • Loading branch information
stephantul authored Dec 6, 2024
1 parent 8c7e13b commit 7a74946
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions vicinity/backends/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,17 @@
class AnnoyArgs(BaseArgs):
dim: int = 0
metric: str = "cosine"
internal_metric: str = "dot"
trees: int = 100
length: int | None = None


class AnnoyBackend(AbstractBackend[AnnoyArgs]):
argument_class = AnnoyArgs
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN, Metric.INNER_PRODUCT}
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
inverse_metric_mapping = {
Metric.COSINE: "dot",
Metric.EUCLIDEAN: "euclidean",
Metric.INNER_PRODUCT: "dot",
}

def __init__(
Expand Down Expand Up @@ -56,18 +56,21 @@ def from_vectors(
if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.")

metric = cls._map_metric_to_string(metric_enum)
metric_string = metric_enum.value
internal_metric = cls._map_metric_to_string(metric_enum)

if metric == "dot":
if metric_enum == Metric.COSINE:
vectors = normalize(vectors)

dim = vectors.shape[1]
index = AnnoyIndex(f=dim, metric=metric) # type: ignore
index = AnnoyIndex(f=dim, metric=internal_metric) # type: ignore
for i, vector in enumerate(vectors):
index.add_item(i, vector)
index.build(trees)

arguments = AnnoyArgs(dim=dim, metric=metric, trees=trees, length=len(vectors)) # type: ignore
arguments = AnnoyArgs(
dim=dim, metric=metric_string, trees=trees, length=len(vectors), internal_metric=internal_metric
) # type: ignore
return AnnoyBackend(index, arguments=arguments)

@property
Expand All @@ -89,7 +92,7 @@ def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend:
"""Load the vectors from a path."""
path = Path(base_path) / "index.bin"
arguments = AnnoyArgs.load(base_path / "arguments.json")
index = AnnoyIndex(arguments.dim, arguments.metric) # type: ignore
index = AnnoyIndex(arguments.dim, arguments.internal_metric) # type: ignore
index.load(str(path))

return cls(index, arguments=arguments)
Expand All @@ -106,11 +109,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend."""
out = []
for vec in vectors:
if self.arguments.metric == "dot":
if self.arguments.metric == "cosine":
vec = normalize(vec)
indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True)
scores_array = np.asarray(scores)
if self.arguments.metric == "dot":
if self.arguments.metric == "cosine":
# Convert cosine similarity to cosine distance
scores_array = 1 - scores_array
out.append((np.asarray(indices), scores_array))
Expand Down

0 comments on commit 7a74946

Please sign in to comment.