From 9bd7dca4cd45e4a7991d3c316d62ad372057f03c Mon Sep 17 00:00:00 2001 From: Nikhil Thorat Date: Thu, 29 Feb 2024 12:30:12 -0500 Subject: [PATCH] save --- lilac/embeddings/vector_store.py | 6 +++--- lilac/signals/semantic_similarity_test.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lilac/embeddings/vector_store.py b/lilac/embeddings/vector_store.py index 2f9c3222f..efa8c54ee 100644 --- a/lilac/embeddings/vector_store.py +++ b/lilac/embeddings/vector_store.py @@ -3,7 +3,7 @@ import abc import os import pickle -from typing import Iterable, Optional, Sequence, Type, cast +from typing import Iterable, Iterator, Optional, Sequence, Type, cast import numpy as np @@ -50,7 +50,7 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: pass @abc.abstractmethod - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: """Return the embeddings for given keys. Args: @@ -160,7 +160,7 @@ def get(self, keys: Iterable[PathKey]) -> Iterable[list[SpanVector]]: all_vector_keys.append([(*path_key, i) for i in range(len(spans))]) flat_vector_keys = [key for vector_keys in all_vector_keys for key in (vector_keys or [])] - all_vectors = iter(self._vector_store.get(flat_vector_keys)) + all_vectors = self._vector_store.get(flat_vector_keys) for spans in all_spans: yield [{'span': span, 'vector': next(all_vectors)} for span in spans] diff --git a/lilac/signals/semantic_similarity_test.py b/lilac/signals/semantic_similarity_test.py index 2a54a9d1c..cc7c45dd7 100644 --- a/lilac/signals/semantic_similarity_test.py +++ b/lilac/signals/semantic_similarity_test.py @@ -49,9 +49,11 @@ def add(self, keys: list[VectorKey], embeddings: np.ndarray) -> None: pass @override - def get(self, keys: Optional[Iterable[VectorKey]] = None) -> np.ndarray: + def get(self, keys: Optional[Iterable[VectorKey]] = None) -> Iterator[np.ndarray]: keys = keys or [] - return np.array([EMBEDDINGS[tuple(path_key)][cast(int, index)] for *path_key, index in keys]) + yield from [ + np.array(EMBEDDINGS[tuple(path_key)][cast(int, index)]) for *path_key, index in keys + ] @override def delete(self, base_path: str) -> None: