"""
These functions match what the spec of hnswlib is.
"""
from typing import Union, cast

import numpy as np
from numpy.typing import NDArray

Vector = NDArray[Union[np.int32, np.float32, np.int16, np.float16]]

# This epsilon is used to prevent division by zero, and the value is the same
# https://github.com/nmslib/hnswlib/blob/359b2ba87358224963986f709e593d799064ace6/python_bindings/bindings.cpp#L238
_NORM_EPS = 1e-30


def _to_float64(v: Vector) -> NDArray[np.float64]:
    """Return ``v`` as a float64 array.

    Every dtype admitted by ``Vector`` (int16, int32, float16, float32) is
    exactly representable in float64, so the conversion loses nothing. Doing
    the arithmetic in float64 avoids integer wrap-around and float16 overflow
    (float16 cannot hold values above 65504) in intermediate results.
    """
    return np.asarray(v, dtype=np.float64)


def l2(x: Vector, y: Vector) -> float:
    """Return the squared Euclidean distance between ``x`` and ``y``.

    Args:
        x: First vector.
        y: Second vector, broadcast-compatible with ``x``.

    Returns:
        The sum of squared element-wise differences as a Python ``float``.
    """
    # Upcast *before* subtracting: int16/int32 subtraction would otherwise
    # wrap around silently, and float16 squares would overflow to inf.
    diff = _to_float64(x) - _to_float64(y)
    # Summing the squares directly avoids the sqrt-then-square roundtrip of
    # ``norm(...) ** 2``.
    return cast(float, np.sum(diff * diff).item())


def cosine(x: Vector, y: Vector) -> float:
    """Return the cosine distance ``1 - cos(x, y)``.

    A small epsilon is added to the product of the norms so that zero vectors
    produce a distance of ``1.0`` instead of a division by zero.

    Args:
        x: First vector (1-D).
        y: Second vector (1-D), same length as ``x``.

    Returns:
        The cosine distance as a Python ``float``.
    """
    a = _to_float64(x)
    b = _to_float64(y)
    # Working in float64 means a single epsilon suits every input dtype; no
    # float16-specific epsilon (which would swamp small norms) is required.
    denominator = np.linalg.norm(a) * np.linalg.norm(b) + _NORM_EPS
    return cast(float, (1.0 - np.dot(a, b) / denominator).item())


def ip(x: Vector, y: Vector) -> float:
    """Return the inner-product distance ``1 - <x, y>``.

    Args:
        x: First vector (1-D).
        y: Second vector (1-D), same length as ``x``.

    Returns:
        One minus the dot product, as a Python ``float``.
    """
    # The dot product is taken in float64 so that integer inputs cannot wrap
    # around and float16 inputs cannot overflow.
    return cast(float, (1.0 - np.dot(_to_float64(x), _to_float64(y))).item())


# Regression test carried by the same change set; it belongs to
# chromadb/test/utils/distance_functions.py.
def test_cosine_zero() -> None:
    x = np.array([0.0, 0.0], dtype=np.float16)
    assert cosine(x, x) == 1.0