Skip to content

Commit

Permalink
fix: Fixed metric bug (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
Pringled authored Dec 2, 2024
1 parent 03f17db commit a905cf8
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 10 deletions.
3 changes: 1 addition & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ pip install vicinity[all]
The following code snippet demonstrates how to use Vicinity for nearest neighbor search:
```python
import numpy as np
from vicinity import Vicinity
from vicinity.datatypes import Backend, Metric
from vicinity import Vicinity, Backend, Metric

# Create some dummy data
items = ["triforce", "master sword", "hylian shield", "boomerang", "hookshot"]
Expand Down
20 changes: 19 additions & 1 deletion uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

19 changes: 12 additions & 7 deletions vicinity/backends/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Literal
from typing import Any, Literal, Union

import numpy as np
from numpy import typing as npt
Expand All @@ -15,7 +15,7 @@

@dataclass
class BasicArgs(BaseArgs):
metric: Literal["cosine", "euclidean"] = "cosine"
metric: str = "cosine"


class BasicBackend(AbstractBackend[BasicArgs], ABC):
Expand Down Expand Up @@ -66,15 +66,20 @@ def _dist(self, x: npt.NDArray) -> npt.NDArray:
raise NotImplementedError()

@classmethod
def from_vectors(cls, vectors: npt.NDArray, **kwargs: Any) -> BasicBackend:
def from_vectors(cls, vectors: npt.NDArray, metric: Union[str, Metric] = "cosine", **kwargs: Any) -> BasicBackend:
"""Create a new instance from vectors."""
arguments = BasicArgs(**kwargs)
if arguments.metric == "cosine":
metric_enum = Metric.from_string(metric)
if metric_enum not in cls.supported_metrics:
raise ValueError(f"Metric '{metric_enum.value}' is not supported by BasicBackend.")

metric = metric_enum.value
arguments = BasicArgs(metric=metric)
if metric == "cosine":
return CosineBasicBackend(vectors, arguments)
elif arguments.metric == "euclidean":
elif metric == "euclidean":
return EuclideanBasicBackend(vectors, arguments)
else:
raise ValueError(f"Unsupported metric: {arguments.metric}")
raise ValueError(f"Unsupported metric: {metric}")

@classmethod
def load(cls, folder: Path) -> BasicBackend:
Expand Down

0 comments on commit a905cf8

Please sign in to comment.