Skip to content

Commit

Permalink
feat: make Metric the basic unit for metrics, not str (#44)
Browse files Browse the repository at this point in the history
* feat: make Metric the basic unit for metrics, not str

* Remove old type

* feat: use Metric internally everywhere
  • Loading branch information
stephantul authored Dec 6, 2024
1 parent 7a74946 commit aeaccaa
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 52 deletions.
17 changes: 8 additions & 9 deletions vicinity/backends/annoy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@dataclass
class AnnoyArgs(BaseArgs):
dim: int = 0
metric: str = "cosine"
metric: Metric = Metric.COSINE
internal_metric: str = "dot"
trees: int = 100
length: int | None = None
Expand All @@ -25,7 +25,7 @@ class AnnoyArgs(BaseArgs):
class AnnoyBackend(AbstractBackend[AnnoyArgs]):
argument_class = AnnoyArgs
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
inverse_metric_mapping = {
inverse_metric_mapping: dict[Metric, str] = {
Metric.COSINE: "dot",
Metric.EUCLIDEAN: "euclidean",
}
Expand Down Expand Up @@ -56,7 +56,6 @@ def from_vectors(
if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by AnnoyBackend.")

metric_string = metric_enum.value
internal_metric = cls._map_metric_to_string(metric_enum)

if metric_enum == Metric.COSINE:
Expand All @@ -68,9 +67,7 @@ def from_vectors(
index.add_item(i, vector)
index.build(trees)

arguments = AnnoyArgs(
dim=dim, metric=metric_string, trees=trees, length=len(vectors), internal_metric=internal_metric
) # type: ignore
# Store the normalized Metric enum rather than the raw `metric` argument, so BaseArgs.dump can serialize it via `.value`.
arguments = AnnoyArgs(dim=dim, metric=metric_enum, trees=trees, length=len(vectors), internal_metric=internal_metric)  # type: ignore
return AnnoyBackend(index, arguments=arguments)

@property
Expand All @@ -91,8 +88,10 @@ def __len__(self) -> int:
def load(cls: type[AnnoyBackend], base_path: Path) -> AnnoyBackend:
"""Load the vectors from a path."""
path = Path(base_path) / "index.bin"

arguments = AnnoyArgs.load(base_path / "arguments.json")
index = AnnoyIndex(arguments.dim, arguments.internal_metric) # type: ignore
metric = cls._map_metric_to_string(arguments.metric)
index = AnnoyIndex(arguments.dim, metric) # type: ignore
index.load(str(path))

return cls(index, arguments=arguments)
Expand All @@ -109,11 +108,11 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:
"""Query the backend."""
out = []
for vec in vectors:
if self.arguments.metric == "cosine":
if self.arguments.metric == Metric.COSINE:
vec = normalize(vec)
indices, scores = self.index.get_nns_by_vector(vec, k, include_distances=True)
scores_array = np.asarray(scores)
if self.arguments.metric == "cosine":
if self.arguments.metric == Metric.COSINE:
# Convert cosine similarity to cosine distance
scores_array = 1 - scores_array
out.append((np.asarray(indices), scores_array))
Expand Down
12 changes: 9 additions & 3 deletions vicinity/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@

@dataclass
class BaseArgs:
metric: Metric

def dump(self, file: Path) -> None:
"""Dump the arguments to a file."""
with open(file, "w") as f:
json.dump(asdict(self), f)
d = self.dict()
d["metric"] = d["metric"].value
json.dump(d, f)

@classmethod
def load(cls: type[ArgType], file: Path) -> ArgType:
"""Load the arguments from a file."""
with open(file, "r") as f:
return cls(**json.load(f))
data = json.load(f)
data["metric"] = Metric.from_string(data["metric"])
return cls(**data)

def dict(self) -> dict[str, Any]:
"""Dump the arguments to a string."""
"""Dump the arguments to a dict."""
return asdict(self)


Expand Down
13 changes: 6 additions & 7 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

@dataclass
class BasicArgs(BaseArgs):
metric: str = "cosine"
metric: Metric = Metric.COSINE


class BasicBackend(AbstractBackend[BasicArgs], ABC):
Expand Down Expand Up @@ -72,11 +72,10 @@ def from_vectors(cls, vectors: npt.NDArray, metric: Union[str, Metric] = "cosine
if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by BasicBackend.")

metric = metric_enum.value
arguments = BasicArgs(metric=metric)
if metric == "cosine":
arguments = BasicArgs(metric=metric_enum)
if metric_enum == Metric.COSINE:
return CosineBasicBackend(vectors, arguments)
elif metric == "euclidean":
elif metric_enum == Metric.EUCLIDEAN:
return EuclideanBasicBackend(vectors, arguments)
else:
raise ValueError(f"Unsupported metric: {metric}")
Expand All @@ -88,9 +87,9 @@ def load(cls, folder: Path) -> BasicBackend:
arguments = BasicArgs.load(folder / "arguments.json")
with open(path, "rb") as f:
vectors = np.load(f)
if arguments.metric == "cosine":
if arguments.metric == Metric.COSINE:
return CosineBasicBackend(vectors, arguments)
elif arguments.metric == "euclidean":
elif arguments.metric == Metric.EUCLIDEAN:
return EuclideanBasicBackend(vectors, arguments)
else:
raise ValueError(f"Unsupported metric: {arguments.metric}")
Expand Down
4 changes: 2 additions & 2 deletions vicinity/backends/faiss.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
class FaissArgs(BaseArgs):
dim: int = 0
index_type: str = "flat"
metric: str = "cosine"
metric: Metric = Metric.COSINE
nlist: int = 100
m: int = 8
nbits: int = 8
Expand Down Expand Up @@ -122,7 +122,7 @@ def from_vectors( # noqa: C901
arguments = FaissArgs(
dim=dim,
index_type=index_type,
metric=metric_enum.value,
metric=metric_enum,
nlist=nlist,
m=m,
nbits=nbits,
Expand Down
7 changes: 4 additions & 3 deletions vicinity/backends/hnsw.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
@dataclass
class HNSWArgs(BaseArgs):
dim: int = 0
metric: str = "cosine"
metric: Metric = Metric.COSINE
ef_construction: int = 200
m: int = 16

Expand Down Expand Up @@ -58,7 +58,7 @@ def from_vectors(
index = HnswIndex(space=metric, dim=dim)
index.init_index(max_elements=vectors.shape[0], ef_construction=ef_construction, M=m)
index.add_items(vectors)
arguments = HNSWArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m)
arguments = HNSWArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m)
return HNSWBackend(index, arguments=arguments)

@property
Expand All @@ -80,7 +80,8 @@ def load(cls: type[HNSWBackend], base_path: Path) -> HNSWBackend:
"""Load the vectors from a path."""
path = Path(base_path) / "index.bin"
arguments = HNSWArgs.load(base_path / "arguments.json")
index = HnswIndex(space=arguments.metric, dim=arguments.dim)
mapped_metric = cls.inverse_metric_mapping[arguments.metric]
index = HnswIndex(space=mapped_metric, dim=arguments.dim)
index.load_index(str(path))
return cls(index, arguments=arguments)

Expand Down
9 changes: 3 additions & 6 deletions vicinity/backends/pynndescent.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@dataclass
class PyNNDescentArgs(BaseArgs):
n_neighbors: int = 15
metric: str = "cosine"
metric: Metric = Metric.COSINE


class PyNNDescentBackend(AbstractBackend[PyNNDescentArgs]):
Expand Down Expand Up @@ -49,7 +49,7 @@ def from_vectors(
metric = metric_enum.value

index = NNDescent(vectors, n_neighbors=n_neighbors, metric=metric, **kwargs)
arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric)
arguments = PyNNDescentArgs(n_neighbors=n_neighbors, metric=metric_enum)
return cls(index=index, arguments=arguments)

def __len__(self) -> int:
Expand Down Expand Up @@ -105,10 +105,7 @@ def load(cls: type[PyNNDescentBackend], base_path: Path) -> PyNNDescentBackend:
arguments = PyNNDescentArgs.load(base_path / "arguments.json")
vectors = np.load(Path(base_path) / "vectors.npy")

metric_enum = Metric.from_string(arguments.metric)
pynndescent_metric = metric_enum.value

index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=pynndescent_metric)
index = NNDescent(vectors, n_neighbors=arguments.n_neighbors, metric=arguments.metric.value)

# Load the neighbor graph if it was saved
neighbor_graph_path = base_path / "neighbor_graph.npy"
Expand Down
10 changes: 5 additions & 5 deletions vicinity/backends/usearch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
@dataclass
class UsearchArgs(BaseArgs):
dim: int = 0
metric: str = "cos"
metric: Metric = Metric.COSINE
connectivity: int = 16
expansion_add: int = 128
expansion_search: int = 64
Expand Down Expand Up @@ -67,10 +67,10 @@ def from_vectors(
expansion_add=expansion_add,
expansion_search=expansion_search,
)
index.add(keys=None, vectors=vectors) # type: ignore
index.add(keys=None, vectors=vectors) # type: ignore # None keys are allowed but not typed
arguments = UsearchArgs(
dim=dim,
metric=metric,
metric=metric_enum,
connectivity=connectivity,
expansion_add=expansion_add,
expansion_search=expansion_search,
Expand Down Expand Up @@ -99,7 +99,7 @@ def load(cls: type[UsearchBackend], base_path: Path) -> UsearchBackend:

index = UsearchIndex(
ndim=arguments.dim,
metric=arguments.metric,
metric=cls._map_metric_to_string(arguments.metric),
connectivity=arguments.connectivity,
expansion_add=arguments.expansion_add,
expansion_search=arguments.expansion_search,
Expand All @@ -122,7 +122,7 @@ def query(self, vectors: npt.NDArray, k: int) -> QueryResult:

def insert(self, vectors: npt.NDArray) -> None:
"""Insert vectors into the backend."""
self.index.add(None, vectors) # type: ignore
self.index.add(None, vectors) # type: ignore # None keys are allowed, but not typed.

def delete(self, indices: list[int]) -> None:
"""Delete vectors from the index (not supported by Usearch)."""
Expand Down
25 changes: 8 additions & 17 deletions vicinity/backends/voyager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,34 +4,28 @@
from pathlib import Path
from typing import Any, Union

import numpy as np
from numpy import typing as npt
from voyager import Index, Space

from vicinity.backends.base import AbstractBackend, BaseArgs
from vicinity.datatypes import Backend, QueryResult
from vicinity.utils import Metric, normalize
from vicinity.utils import Metric


@dataclass
class VoyagerArgs(BaseArgs):
dim: int = 0
metric: str = "cosine"
metric: Metric = Metric.COSINE
ef_construction: int = 200
m: int = 16


class VoyagerBackend(AbstractBackend[VoyagerArgs]):
argument_class = VoyagerArgs
supported_metrics = {Metric.COSINE, Metric.EUCLIDEAN}
inverse_metric_mapping = {
Metric.COSINE: "cosine",
Metric.EUCLIDEAN: "l2",
}

metric_int_mapping = {
"l2": 0,
"cosine": 2,
_metric_to_space = {
Metric.COSINE: Space.Cosine,
Metric.EUCLIDEAN: Space.Euclidean,
}

def __init__(
Expand All @@ -56,13 +50,10 @@ def from_vectors(
metric_enum = Metric.from_string(metric)

if metric_enum not in cls.supported_metrics:
raise ValueError(
f"Metric '{metric_enum.value}' is not supported by VoyagerBackend."
)
raise ValueError(f"Metric '{metric_enum.value}' is not supported by VoyagerBackend.")

metric = cls._map_metric_to_string(metric_enum)
space = cls._metric_to_space[metric_enum]
dim = vectors.shape[1]
space = Space(value=cls.metric_int_mapping[metric])
index = Index(
space=space,
num_dimensions=dim,
Expand All @@ -72,7 +63,7 @@ def from_vectors(
index.add_items(vectors)
return cls(
index,
VoyagerArgs(dim=dim, metric=metric, ef_construction=ef_construction, m=m),
VoyagerArgs(dim=dim, metric=metric_enum, ef_construction=ef_construction, m=m),
)

def query(self, query: npt.NDArray, k: int) -> QueryResult:
Expand Down

0 comments on commit aeaccaa

Please sign in to comment.