diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index e90c8bf6..77903f7c 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -26,8 +26,6 @@ class CacheEntry(BaseModel): """Optional metadata stored on the cache entry""" filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - dtype: str - """The data type for the prompt vector.""" @root_validator(pre=True) @classmethod @@ -43,9 +41,9 @@ def non_empty_metadata(cls, v): raise TypeError("Metadata must be a dictionary.") return v - def to_dict(self) -> Dict: + def to_dict(self, dtype: str) -> Dict: data = self.dict(exclude_none=True) - data["prompt_vector"] = array_to_buffer(self.prompt_vector, self.dtype) + data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) if self.metadata is not None: data["metadata"] = serialize(self.metadata) if self.filters is not None: diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index d57861af..3c165a09 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -15,6 +15,8 @@ from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +VECTOR_FIELD_NAME = "prompt_vector" ### + class SemanticCache(BaseLLMCache): """Semantic Cache for Large Language Models.""" @@ -23,7 +25,7 @@ class SemanticCache(BaseLLMCache): entry_id_field_name: str = "entry_id" prompt_field_name: str = "prompt" response_field_name: str = "response" - vector_field_name: str = "prompt_vector" + ###vector_field_name: str = "prompt_vector" inserted_at_field_name: str = "inserted_at" updated_at_field_name: str = "updated_at" metadata_field_name: str = "metadata" @@ -136,9 +138,10 @@ def __init__( validate_vector_dims( vectorizer.dims, - self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore ) self._vectorizer = vectorizer + self._dtype = self.index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr] def _modify_schema( self, @@ -290,8 +293,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: if not isinstance(prompt, str): raise TypeError("Prompt must be a string.") - dtype = self.index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr] - return self._vectorizer.embed(prompt, dtype=dtype) + return self._vectorizer.embed(prompt, dtype=self._dtype) async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: """Converts a text prompt to its vector representation using the @@ -304,7 +306,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: def _check_vector_dims(self, vector: List[float]): """Checks the size of the provided vector and raises an error if it doesn't match the search index vector dimensions.""" - schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore + schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims # type: ignore validate_vector_dims(len(vector), schema_vector_dims) def check( @@ -367,13 +369,13 @@ def check( query = RangeQuery( vector=vector, - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=self.return_fields, distance_threshold=distance_threshold, num_results=num_results, return_score=True, filter_expression=filter_expression, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._dtype, ) # Search the cache! @@ -449,7 +451,7 @@ async def acheck( query = RangeQuery( vector=vector, - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=self.return_fields, distance_threshold=distance_threshold, num_results=num_results, @@ -539,13 +541,12 @@ def store( prompt_vector=vector, metadata=metadata, filters=filters, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # Load cache entry with TTL ttl = ttl or self._ttl keys = self._index.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=self.entry_id_field_name, ) @@ -604,13 +605,12 @@ async def astore( prompt_vector=vector, metadata=metadata, filters=filters, - dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] ) # Load cache entry with TTL ttl = ttl or self._ttl keys = await aindex.load( - data=[cache_entry.to_dict()], + data=[cache_entry.to_dict(self._dtype)], ttl=ttl, id_field=self.entry_id_field_name, ) diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 7e3fd9b1..66d38507 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -28,6 +28,8 @@ logger = get_logger(__name__) +VECTOR_FIELD_NAME = "vector" ### + class SemanticRouter(BaseModel): """Semantic Router for managing and querying route vectors.""" @@ -40,7 +42,7 @@ class SemanticRouter(BaseModel): """The vectorizer used to embed route references.""" routing_config: RoutingConfig = Field(default_factory=RoutingConfig) """Configuration for routing behavior.""" - vector_field_name: str = "vector" + ### vector_field_name: str = "vector" _index: SearchIndex = PrivateAttr() @@ -171,7 +173,7 @@ def _add_routes(self, routes: List[Route]): reference_vectors = self.vectorizer.embed_many( [reference for reference in route.references], as_buffer=True, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) # set route references for i, reference in enumerate(route.references): @@ -248,7 +250,7 @@ def _classify_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( @@ -301,7 +303,7 @@ def _classify_multi_route( vector_field_name="vector", distance_threshold=distance_threshold, return_fields=["route_name"], - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) aggregate_request = self._build_aggregate_request( vector_range_query, aggregation_method, max_k diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 1d2c553b..9835dc11 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -13,9 +13,11 @@ from redisvl.utils.utils import validate_vector_dims from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +VECTOR_FIELD_NAME = "vector_field" ### + class SemanticSessionManager(BaseSessionManager): - vector_field_name: str = "vector_field" + ###vector_field_name: str = "vector_field" def __init__( self, @@ -201,13 +203,13 @@ def get_relevant( query = RangeQuery( vector=self._vectorizer.embed(prompt), - vector_field_name=self.vector_field_name, + vector_field_name=VECTOR_FIELD_NAME, return_fields=return_fields, distance_threshold=distance_threshold, num_results=top_k, return_score=True, filter_expression=session_filter, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) messages = self._index.query(query) @@ -321,7 +323,7 @@ def add_messages( content_vector = self._vectorizer.embed(message[self.content_field_name]) validate_vector_dims( len(content_vector), - self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore + self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore ) chat_message = ChatMessage( @@ -329,7 +331,7 @@ def add_messages( content=message[self.content_field_name], session_tag=session_tag, vector_field=content_vector, - dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr] + dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr] ) if self.tool_field_name in message: diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 209ea6f4..12ef2052 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -2,7 +2,6 @@ import uuid from typing import Any, Callable, Dict, Iterable, List, Optional -from numpy import frombuffer from pydantic.v1 import BaseModel from redis import Redis from redis.asyncio import Redis as AsyncRedis @@ -394,7 +393,7 @@ class HashStorage(BaseStorage): """Hash data type for the index""" def _validate(self, obj: Dict[str, Any]): - """Validate that the given object is a dictionary suitable for storage + """Validate that the given object is a dictionary, suitable for storage as a Redis hash. Args: diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index 7c1755d2..f210efce 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -12,7 +12,6 @@ def test_valid_cache_entry_creation(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="float16", ) assert entry.entry_id == hashify("What is AI?") assert entry.prompt == "What is AI?" @@ -26,7 +25,6 @@ def test_cache_entry_with_given_entry_id(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="float16", ) assert entry.entry_id == "custom_id" @@ -38,7 +36,6 @@ def test_cache_entry_with_invalid_metadata(): response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], metadata="invalid_metadata", - dtype="float64", ) @@ -49,9 +46,8 @@ def test_cache_entry_to_dict(): prompt_vector=[0.1, 0.2, 0.3], metadata={"author": "John"}, filters={"category": "technology"}, - dtype="float32", ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert result["entry_id"] == hashify("What is AI?") assert result["metadata"] == json.dumps({"author": "John"}) assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32") @@ -112,9 +108,8 @@ def test_cache_entry_with_empty_optional_fields(): prompt="What is AI?", response="AI is artificial intelligence.", prompt_vector=[0.1, 0.2, 0.3], - dtype="bfloat16", ) - result = entry.to_dict() + result = entry.to_dict(dtype="float32") assert "metadata" not in result assert "filters" not in result