Skip to content

Commit

Permalink
removes dtype from class definitions, and uses constants instead
Browse files Browse the repository at this point in the history
  • Loading branch information
justin-cechmanek committed Sep 25, 2024
1 parent 41f5693 commit 6898f1d
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 34 deletions.
6 changes: 2 additions & 4 deletions redisvl/extensions/llmcache/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@ class CacheEntry(BaseModel):
"""Optional metadata stored on the cache entry"""
filters: Optional[Dict[str, Any]] = Field(default=None)
"""Optional filter data stored on the cache entry for customizing retrieval"""
dtype: str
"""The data type for the prompt vector."""

@root_validator(pre=True)
@classmethod
Expand All @@ -43,9 +41,9 @@ def non_empty_metadata(cls, v):
raise TypeError("Metadata must be a dictionary.")
return v

def to_dict(self) -> Dict:
def to_dict(self, dtype: str) -> Dict:
data = self.dict(exclude_none=True)
data["prompt_vector"] = array_to_buffer(self.prompt_vector, self.dtype)
data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype)
if self.metadata is not None:
data["metadata"] = serialize(self.metadata)
if self.filters is not None:
Expand Down
24 changes: 12 additions & 12 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from redisvl.utils.utils import current_timestamp, serialize, validate_vector_dims
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer

VECTOR_FIELD_NAME = "prompt_vector" ###


class SemanticCache(BaseLLMCache):
"""Semantic Cache for Large Language Models."""
Expand All @@ -23,7 +25,7 @@ class SemanticCache(BaseLLMCache):
entry_id_field_name: str = "entry_id"
prompt_field_name: str = "prompt"
response_field_name: str = "response"
vector_field_name: str = "prompt_vector"
###vector_field_name: str = "prompt_vector"
inserted_at_field_name: str = "inserted_at"
updated_at_field_name: str = "updated_at"
metadata_field_name: str = "metadata"
Expand Down Expand Up @@ -136,9 +138,10 @@ def __init__(

validate_vector_dims(
vectorizer.dims,
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore
)
self._vectorizer = vectorizer
self._dtype = self.index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype # type: ignore[union-attr]

def _modify_schema(
self,
Expand Down Expand Up @@ -290,8 +293,7 @@ def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]:
if not isinstance(prompt, str):
raise TypeError("Prompt must be a string.")

dtype = self.index.schema.fields[self.vector_field_name].attrs.datatype # type: ignore[union-attr]
return self._vectorizer.embed(prompt, dtype=dtype)
return self._vectorizer.embed(prompt, dtype=self._dtype)

async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
"""Converts a text prompt to its vector representation using the
Expand All @@ -304,7 +306,7 @@ async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]:
def _check_vector_dims(self, vector: List[float]):
"""Checks the size of the provided vector and raises an error if it
doesn't match the search index vector dimensions."""
schema_vector_dims = self._index.schema.fields[self.vector_field_name].attrs.dims # type: ignore
schema_vector_dims = self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims # type: ignore
validate_vector_dims(len(vector), schema_vector_dims)

def check(
Expand Down Expand Up @@ -367,13 +369,13 @@ def check(

query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
vector_field_name=VECTOR_FIELD_NAME,
return_fields=self.return_fields,
distance_threshold=distance_threshold,
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._dtype,
)

# Search the cache!
Expand Down Expand Up @@ -449,7 +451,7 @@ async def acheck(

query = RangeQuery(
vector=vector,
vector_field_name=self.vector_field_name,
vector_field_name=VECTOR_FIELD_NAME,
return_fields=self.return_fields,
distance_threshold=distance_threshold,
num_results=num_results,
Expand Down Expand Up @@ -539,13 +541,12 @@ def store(
prompt_vector=vector,
metadata=metadata,
filters=filters,
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
)

# Load cache entry with TTL
ttl = ttl or self._ttl
keys = self._index.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=self.entry_id_field_name,
)
Expand Down Expand Up @@ -604,13 +605,12 @@ async def astore(
prompt_vector=vector,
metadata=metadata,
filters=filters,
dtype=self.index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
)

# Load cache entry with TTL
ttl = ttl or self._ttl
keys = await aindex.load(
data=[cache_entry.to_dict()],
data=[cache_entry.to_dict(self._dtype)],
ttl=ttl,
id_field=self.entry_id_field_name,
)
Expand Down
10 changes: 6 additions & 4 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@

logger = get_logger(__name__)

VECTOR_FIELD_NAME = "vector" ###


class SemanticRouter(BaseModel):
"""Semantic Router for managing and querying route vectors."""
Expand All @@ -40,7 +42,7 @@ class SemanticRouter(BaseModel):
"""The vectorizer used to embed route references."""
routing_config: RoutingConfig = Field(default_factory=RoutingConfig)
"""Configuration for routing behavior."""
vector_field_name: str = "vector"
### vector_field_name: str = "vector"

_index: SearchIndex = PrivateAttr()

Expand Down Expand Up @@ -171,7 +173,7 @@ def _add_routes(self, routes: List[Route]):
reference_vectors = self.vectorizer.embed_many(
[reference for reference in route.references],
as_buffer=True,
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
# set route references
for i, reference in enumerate(route.references):
Expand Down Expand Up @@ -248,7 +250,7 @@ def _classify_route(
vector_field_name="vector",
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)

aggregate_request = self._build_aggregate_request(
Expand Down Expand Up @@ -301,7 +303,7 @@ def _classify_multi_route(
vector_field_name="vector",
distance_threshold=distance_threshold,
return_fields=["route_name"],
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
aggregate_request = self._build_aggregate_request(
vector_range_query, aggregation_method, max_k
Expand Down
12 changes: 7 additions & 5 deletions redisvl/extensions/session_manager/semantic_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
from redisvl.utils.utils import validate_vector_dims
from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer

VECTOR_FIELD_NAME = "vector_field" ###


class SemanticSessionManager(BaseSessionManager):
vector_field_name: str = "vector_field"
###vector_field_name: str = "vector_field"

def __init__(
self,
Expand Down Expand Up @@ -201,13 +203,13 @@ def get_relevant(

query = RangeQuery(
vector=self._vectorizer.embed(prompt),
vector_field_name=self.vector_field_name,
vector_field_name=VECTOR_FIELD_NAME,
return_fields=return_fields,
distance_threshold=distance_threshold,
num_results=top_k,
return_score=True,
filter_expression=session_filter,
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)
messages = self._index.query(query)

Expand Down Expand Up @@ -321,15 +323,15 @@ def add_messages(
content_vector = self._vectorizer.embed(message[self.content_field_name])
validate_vector_dims(
len(content_vector),
self._index.schema.fields[self.vector_field_name].attrs.dims, # type: ignore
self._index.schema.fields[VECTOR_FIELD_NAME].attrs.dims, # type: ignore
)

chat_message = ChatMessage(
role=message[self.role_field_name],
content=message[self.content_field_name],
session_tag=session_tag,
vector_field=content_vector,
dtype=self._index.schema.fields[self.vector_field_name].attrs.datatype, # type: ignore[union-attr]
dtype=self._index.schema.fields[VECTOR_FIELD_NAME].attrs.datatype, # type: ignore[union-attr]
)

if self.tool_field_name in message:
Expand Down
3 changes: 1 addition & 2 deletions redisvl/index/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import uuid
from typing import Any, Callable, Dict, Iterable, List, Optional

from numpy import frombuffer
from pydantic.v1 import BaseModel
from redis import Redis
from redis.asyncio import Redis as AsyncRedis
Expand Down Expand Up @@ -394,7 +393,7 @@ class HashStorage(BaseStorage):
"""Hash data type for the index"""

def _validate(self, obj: Dict[str, Any]):
"""Validate that the given object is a dictionary suitable for storage
"""Validate that the given object is a dictionary, suitable for storage
as a Redis hash.
Args:
Expand Down
9 changes: 2 additions & 7 deletions tests/unit/test_llmcache_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def test_valid_cache_entry_creation():
prompt="What is AI?",
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
dtype="float16",
)
assert entry.entry_id == hashify("What is AI?")
assert entry.prompt == "What is AI?"
Expand All @@ -26,7 +25,6 @@ def test_cache_entry_with_given_entry_id():
prompt="What is AI?",
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
dtype="float16",
)
assert entry.entry_id == "custom_id"

Expand All @@ -38,7 +36,6 @@ def test_cache_entry_with_invalid_metadata():
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
metadata="invalid_metadata",
dtype="float64",
)


Expand All @@ -49,9 +46,8 @@ def test_cache_entry_to_dict():
prompt_vector=[0.1, 0.2, 0.3],
metadata={"author": "John"},
filters={"category": "technology"},
dtype="float32",
)
result = entry.to_dict()
result = entry.to_dict(dtype="float32")
assert result["entry_id"] == hashify("What is AI?")
assert result["metadata"] == json.dumps({"author": "John"})
assert result["prompt_vector"] == array_to_buffer([0.1, 0.2, 0.3], "float32")
Expand Down Expand Up @@ -112,9 +108,8 @@ def test_cache_entry_with_empty_optional_fields():
prompt="What is AI?",
response="AI is artificial intelligence.",
prompt_vector=[0.1, 0.2, 0.3],
dtype="bfloat16",
)
result = entry.to_dict()
result = entry.to_dict(dtype="float32")
assert "metadata" not in result
assert "filters" not in result

Expand Down

0 comments on commit 6898f1d

Please sign in to comment.