sets the default datatype in our vectorizers to float32 if not specified by users #253

Merged: 13 commits, Jan 8, 2025
Changes from 9 commits
46 changes: 39 additions & 7 deletions docs/user_guide/vectorizers_04.ipynb

Collaborator:

Two small things on this:

1. This does have the warning, which doesn't really matter, but a heads up.
2. Is there a reason Mistral is commented out? Is there something we need to fix with that?

Collaborator:

Yeah, let's get rid of this warning. Good catch. On 2 -- I think it had to do with stale API keys. @justin-cechmanek, are those active again?

Collaborator (Author):

Looks like the Mistral client is deprecated. I'll look into updating our vectorizer class for it.

Collaborator:

Nice. I see this one here: #255

After we get this warning removed from the notebook, I am happy to merge!

@@ -823,6 +823,43 @@
" print(doc[\"text\"], doc[\"vector_distance\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Selecting your float data type\n",
"When embedding text as byte arrays RedisVL supports 4 different floating point data types, `float16`, `float32`, `float64` and `bfloat16`.\n",
"Your dtype set for your vectorizer must match what is defined in your search index. If one is not explicitly set the default is `float32`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vectorizer = HFTextVectorizer(dtype=\"float16\")\n",
"\n",
"# subsequent calls to embed('', as_buffer=True) and embed_many('', as_buffer=True) will now encode as float16\n",
"float16_bytes = vectorizer.embed('test sentence', as_buffer=True)\n",
"\n",
"# you can override this setting on each individual method call\n",
"float64_bytes = vectorizer.embed('test sentence', as_buffer=True, dtype=\"float64\")\n",
"\n",
"float16_bytes != float64_bytes"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -836,7 +873,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('redisvl2')",
"display_name": "redisvl-dev",
"language": "python",
"name": "python3"
},
@@ -852,12 +889,7 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
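
An editor's sketch of the pairing the markdown cell above calls out: the vectorizer's dtype and the vector field's datatype in the search index must agree. This assumes redisvl's dict-based schema format and that sentence-transformers is installed for the default HF model; the index and field names are made up for illustration.

from redisvl.index import SearchIndex
from redisvl.utils.vectorize import HFTextVectorizer

# Vectorizer and index must agree on the datatype ("float16" in both places).
vectorizer = HFTextVectorizer(dtype="float16")

schema = {
    "index": {"name": "docs", "prefix": "doc"},
    "fields": [
        {"name": "text", "type": "text"},
        {
            "name": "embedding",
            "type": "vector",
            "attrs": {
                "dims": vectorizer.dims,
                "algorithm": "flat",
                "distance_metric": "cosine",
                "datatype": "float16",  # must match the vectorizer dtype above
            },
        },
    ],
}

index = SearchIndex.from_dict(schema)
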
20 changes: 13 additions & 7 deletions redisvl/utils/vectorize/base.py
@@ -5,6 +5,7 @@
from pydantic.v1 import BaseModel, validator

from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType


class Vectorizers(Enum):
@@ -19,11 +20,22 @@ class Vectorizers(Enum):
class BaseVectorizer(BaseModel, ABC):
model: str
dims: int
dtype: str

@property
def type(self) -> str:
return "base"

@validator("dtype")
def check_dtype(dtype):
try:
VectorDataType(dtype.upper())
except ValueError:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)
return dtype

@validator("dims")
@classmethod
def check_dims(cls, value):
@@ -81,13 +93,7 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
else:
yield seq[pos : pos + size]

def _process_embedding(
self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
):
def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str):
if as_buffer:
if not dtype:
raise RuntimeError(
"dtype is required if converting from float to byte string."
)
return array_to_buffer(embedding, dtype)
return embedding
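
A hedged sketch of what the new check_dtype validator buys callers, assuming HFTextVectorizer as the concrete vectorizer: an unsupported dtype now fails fast at construction time instead of surfacing later inside array_to_buffer.

from redisvl.utils.vectorize import HFTextVectorizer

# Accepted: one of the four supported floating point types.
vec = HFTextVectorizer(dtype="bfloat16")

# Rejected by the new check_dtype validator before any embedding call is made.
# pydantic v1 wraps the message in a ValidationError, which subclasses ValueError.
try:
    HFTextVectorizer(dtype="float8")
except ValueError as err:
    print(err)  # "... Invalid data type: float8. Supported types are: [...]"
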
19 changes: 13 additions & 6 deletions redisvl/utils/vectorize/text/azureopenai.py
@@ -52,7 +52,10 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
_aclient: Any = PrivateAttr()

def __init__(
self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
self,
model: str = "text-embedding-ada-002",
api_config: Optional[Dict] = None,
dtype: str = "float32",
):
"""Initialize the AzureOpenAI vectorizer.

@@ -63,13 +66,17 @@ def __init__(
api_config (Optional[Dict], optional): Dictionary containing the
API key, API version, Azure endpoint, and any other API options.
Defaults to None.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ImportError: If the openai library is not installed.
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
ValueError: If an invalid dtype is provided.
"""
self._initialize_clients(api_config)
super().__init__(model=model, dims=self._set_model_dims(model))
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

def _initialize_clients(self, api_config: Optional[Dict]):
"""
@@ -190,7 +197,7 @@ def embed_many(
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
@@ -234,7 +241,7 @@ def embed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = self._client.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
@@ -274,7 +281,7 @@ async def aembed_many(
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
@@ -320,7 +327,7 @@ async def aembed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = await self._aclient.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
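
The change in this file repeats in the Bedrock, Cohere, and custom vectorizers below: dtype = kwargs.pop("dtype", self.dtype) falls back to the instance default instead of None. A sketch of the call-site effect, assuming the usual Azure OpenAI environment variables (AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION) are set.

from redisvl.utils.vectorize import AzureOpenAITextVectorizer

# Instance-level default, set once at construction time.
az = AzureOpenAITextVectorizer(model="text-embedding-ada-002", dtype="float16")

# as_buffer=True now encodes with the instance default (float16); no per-call dtype needed.
buf16 = az.embed("hello world", as_buffer=True)

# A per-call dtype still overrides the default, exactly as before this change.
buf64 = az.embed("hello world", as_buffer=True, dtype="float64")

assert buf16 != buf64
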
11 changes: 8 additions & 3 deletions redisvl/utils/vectorize/text/bedrock.py
@@ -49,6 +49,7 @@ def __init__(
self,
model: str = "amazon.titan-embed-text-v2:0",
api_config: Optional[Dict[str, str]] = None,
dtype: str = "float32",
) -> None:
"""Initialize the AWS Bedrock Vectorizer.

@@ -57,10 +58,14 @@ def __init__(
api_config (Optional[Dict[str, str]]): AWS credentials and config.
Can include: aws_access_key_id, aws_secret_access_key, aws_region
If not provided, will use environment variables.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ValueError: If credentials are not provided in config or environment.
ImportError: If boto3 is not installed.
ValueError: If an invalid dtype is provided.
"""
try:
import boto3 # type: ignore
@@ -94,7 +99,7 @@ def __init__(
region_name=aws_region,
)

super().__init__(model=model, dims=self._set_model_dims(model))
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

def _set_model_dims(self, model: str) -> int:
"""Initialize model and determine embedding dimensions."""
@@ -145,7 +150,7 @@ def embed(
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)
return self._process_embedding(embedding, as_buffer, dtype)

@retry(
@@ -181,7 +186,7 @@ def embed_many(
raise TypeError("Texts must be a list of strings")

embeddings: List[List[float]] = []
dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

for batch in self.batchify(texts, batch_size, preprocess):
# Process each text in the batch individually since Bedrock
16 changes: 11 additions & 5 deletions redisvl/utils/vectorize/text/cohere.py
@@ -47,7 +47,10 @@ class CohereTextVectorizer(BaseVectorizer):
_client: Any = PrivateAttr()

def __init__(
self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
self,
model: str = "embed-english-v3.0",
api_config: Optional[Dict] = None,
dtype: str = "float32",
):
"""Initialize the Cohere vectorizer.

@@ -57,14 +60,17 @@ def __init__(
model (str): Model to use for embedding. Defaults to 'embed-english-v3.0'.
api_config (Optional[Dict], optional): Dictionary containing the API key.
Defaults to None.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ImportError: If the cohere library is not installed.
ValueError: If the API key is not provided.

ValueError: If an invalid dtype is provided.
"""
self._initialize_client(api_config)
super().__init__(model=model, dims=self._set_model_dims(model))
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

def _initialize_client(self, api_config: Optional[Dict]):
"""
@@ -159,7 +165,7 @@ def embed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embedding = self._client.embed(
texts=[text], model=self.model, input_type=input_type
@@ -228,7 +234,7 @@ def embed_many(
See https://docs.cohere.com/reference/embed."
)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
21 changes: 13 additions & 8 deletions redisvl/utils/vectorize/text/custom.py
@@ -7,7 +7,7 @@


class CustomTextVectorizer(BaseVectorizer):
"""The CustomTextVectorizer class wraps user-defined embeding methods to create
"""The CustomTextVectorizer class wraps user-defined embedding methods to create
embeddings for text data.

This vectorizer is designed to accept a provided callable text vectorizer and
@@ -44,6 +44,7 @@ def __init__(
embed_many: Optional[Callable] = None,
aembed: Optional[Callable] = None,
aembed_many: Optional[Callable] = None,
dtype: str = "float32",
):
"""Initialize the Custom vectorizer.

@@ -52,10 +53,14 @@
embed_many (Optional[Callable]): a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
aembed (Optional[Callable]): an asynchronous Callable function that accepts a string object and returns a list of floats. Defaults to None.
aembed_many (Optional[Callable]): an asynchronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ValueError if any of the provided functions accept or return incorrect types.
TypeError if any of the provided functions are not Callable objects.
ValueError: if any of the provided functions accept or return incorrect types.
TypeError: if any of the provided functions are not Callable objects.
ValueError: If an invalid dtype is provided.
"""

self._validate_embed(embed)
Expand All @@ -71,7 +76,7 @@ def __init__(
self._validate_aembed_many(aembed_many)
self._aembed_many_func = aembed_many

super().__init__(model=self.type, dims=self._set_model_dims())
super().__init__(model=self.type, dims=self._set_model_dims(), dtype=dtype)

def _validate_embed(self, func: Callable):
"""calls the func with dummy input and validates that it returns a vector"""
@@ -173,7 +178,7 @@ def embed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = self._embed_func(text, **kwargs)
return self._process_embedding(result, as_buffer, dtype)
@@ -212,7 +217,7 @@ def embed_many(
if not self._embed_many_func:
raise NotImplementedError

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
@@ -254,7 +259,7 @@ async def aembed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = await self._aembed_func(text, **kwargs)
return self._process_embedding(result, as_buffer, dtype)
@@ -293,7 +298,7 @@ async def aembed_many(
if not self._aembed_many_func:
raise NotImplementedError

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
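
Because CustomTextVectorizer needs no external API key, it is the simplest way to exercise the new default end to end. A minimal sketch with a toy embedding function (the constant vector is obviously not a real embedder).

from redisvl.utils.vectorize import CustomTextVectorizer

# Toy embedding function: returns a fixed 4-dimensional vector for any input.
def fake_embed(text: str) -> list:
    return [0.1, 0.2, 0.3, 0.4]

vec = CustomTextVectorizer(embed=fake_embed, dtype="float16")

floats = vec.embed("some text")                                   # plain list of floats
buf16 = vec.embed("some text", as_buffer=True)                    # bytes encoded as float16 by default
buf64 = vec.embed("some text", as_buffer=True, dtype="float64")   # per-call override

print(len(buf16), len(buf64))  # the float64 buffer is four times the size of the float16 one
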