Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

sets the default datatype in our vectorizers to float32 if not specified by users #253

Merged
merged 13 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 67 additions & 35 deletions docs/user_guide/vectorizers_04.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -305,33 +305,25 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/tyler.hutcherson/redis/redis-vl-python/.venv/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
" return self.fget.__get__(instance, owner)()\n"
]
},
{
"data": {
"text/plain": [
"[0.00037810884532518685,\n",
" -0.05080341175198555,\n",
" -0.03514723479747772,\n",
" -0.02325104922056198,\n",
" -0.044158220291137695,\n",
" 0.020487844944000244,\n",
" 0.0014617963461205363,\n",
" 0.031261757016181946,\n",
"[0.0003780885017476976,\n",
" -0.05080340430140495,\n",
" -0.035147231072187424,\n",
" -0.02325103059411049,\n",
" -0.04415831342339516,\n",
" 0.02048780582845211,\n",
" 0.0014618589775636792,\n",
" 0.03126184269785881,\n",
" 0.05605152249336243,\n",
" 0.018815357238054276]"
" 0.018815429881215096]"
]
},
"execution_count": 6,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -532,14 +524,14 @@
}
],
"source": [
"# from redisvl.utils.vectorize import MistralAITextVectorizer\n",
"from redisvl.utils.vectorize import MistralAITextVectorizer\n",
"\n",
"# mistral = MistralAITextVectorizer()\n",
"mistral = MistralAITextVectorizer()\n",
"\n",
"# # embed a sentence using their asynchronous method\n",
"# test = await mistral.aembed(\"This is a test sentence.\")\n",
"# print(\"Vector dimensions: \", len(test))\n",
"# print(test[:10])"
"# embed a sentence using their asynchronous method\n",
"test = await mistral.aembed(\"This is a test sentence.\")\n",
"print(\"Vector dimensions: \", len(test))\n",
"print(test[:10])"
]
},
{
Expand Down Expand Up @@ -588,9 +580,17 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Vector dimensions: 1024\n"
]
}
],
"source": [
"from redisvl.utils.vectorize import BedrockTextVectorizer\n",
"\n",
Expand Down Expand Up @@ -823,6 +823,43 @@
" print(doc[\"text\"], doc[\"vector_distance\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Selecting your float data type\n",
"When embedding text as byte arrays, RedisVL supports 4 different floating-point data types: `float16`, `float32`, `float64` and `bfloat16`.\n",
"The dtype set for your vectorizer must match the dtype defined in your search index. If one is not explicitly set, the default is `float32`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vectorizer = HFTextVectorizer(dtype=\"float16\")\n",
"\n",
"# subsequent calls to embed('', as_buffer=True) and embed_many('', as_buffer=True) will now encode as float16\n",
"float16_bytes = vectorizer.embed('test sentence', as_buffer=True)\n",
"\n",
"# you can override this setting on each individual method call\n",
"float64_bytes = vectorizer.embed('test sentence', as_buffer=True, dtype=\"float64\")\n",
"\n",
"float16_bytes != float64_bytes"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -836,7 +873,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.13 ('redisvl2')",
"display_name": "redisvl-dev",
"language": "python",
"name": "python3"
},
Expand All @@ -852,12 +889,7 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316"
}
}
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
Expand Down
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ sentence-transformers = { version = ">=2.2.2", optional = true }
google-cloud-aiplatform = { version = ">=1.26", optional = true }
protobuf = { version = ">=5.29.1,<6.0.0.dev0", optional = true }
cohere = { version = ">=4.44", optional = true }
mistralai = { version = ">=0.2.0", optional = true }
mistralai = { version = ">=1.0.0", optional = true }
boto3 = { version = ">=1.34.0", optional = true }

[tool.poetry.extras]
Expand Down
20 changes: 13 additions & 7 deletions redisvl/utils/vectorize/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pydantic.v1 import BaseModel, validator

from redisvl.redis.utils import array_to_buffer
from redisvl.schema.fields import VectorDataType


class Vectorizers(Enum):
Expand All @@ -19,11 +20,22 @@ class Vectorizers(Enum):
class BaseVectorizer(BaseModel, ABC):
model: str
dims: int
dtype: str

@property
def type(self) -> str:
return "base"

@validator("dtype")
justin-cechmanek marked this conversation as resolved.
Show resolved Hide resolved
def check_dtype(dtype):
try:
VectorDataType(dtype.upper())
except ValueError:
raise ValueError(
f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
)
return dtype

@validator("dims")
@classmethod
def check_dims(cls, value):
Expand Down Expand Up @@ -81,13 +93,7 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
else:
yield seq[pos : pos + size]

def _process_embedding(
self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
):
def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str):
if as_buffer:
if not dtype:
raise RuntimeError(
"dtype is required if converting from float to byte string."
)
return array_to_buffer(embedding, dtype)
return embedding
19 changes: 13 additions & 6 deletions redisvl/utils/vectorize/text/azureopenai.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
_aclient: Any = PrivateAttr()

def __init__(
self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
self,
model: str = "text-embedding-ada-002",
api_config: Optional[Dict] = None,
dtype: str = "float32",
):
"""Initialize the AzureOpenAI vectorizer.

Expand All @@ -63,13 +66,17 @@ def __init__(
api_config (Optional[Dict], optional): Dictionary containing the
API key, API version, Azure endpoint, and any other API options.
Defaults to None.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ImportError: If the openai library is not installed.
ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
ValueError: If an invalid dtype is provided.
"""
self._initialize_clients(api_config)
super().__init__(model=model, dims=self._set_model_dims(model))
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

def _initialize_clients(self, api_config: Optional[Dict]):
"""
Expand Down Expand Up @@ -190,7 +197,7 @@ def embed_many(
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

dtype = kwargs.pop("dtype", None)
tylerhutcherson marked this conversation as resolved.
Show resolved Hide resolved
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
Expand Down Expand Up @@ -234,7 +241,7 @@ def embed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = self._client.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
Expand Down Expand Up @@ -274,7 +281,7 @@ async def aembed_many(
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
Expand Down Expand Up @@ -320,7 +327,7 @@ async def aembed(
if preprocess:
text = preprocess(text)

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

result = await self._aclient.embeddings.create(input=[text], model=self.model)
return self._process_embedding(result.data[0].embedding, as_buffer, dtype)
Expand Down
11 changes: 8 additions & 3 deletions redisvl/utils/vectorize/text/bedrock.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def __init__(
self,
model: str = "amazon.titan-embed-text-v2:0",
api_config: Optional[Dict[str, str]] = None,
dtype: str = "float32",
) -> None:
"""Initialize the AWS Bedrock Vectorizer.

Expand All @@ -57,10 +58,14 @@ def __init__(
api_config (Optional[Dict[str, str]]): AWS credentials and config.
Can include: aws_access_key_id, aws_secret_access_key, aws_region
If not provided, will use environment variables.
dtype (str): the default datatype to use when embedding text as byte arrays.
Used when setting `as_buffer=True` in calls to embed() and embed_many().
Defaults to 'float32'.

Raises:
ValueError: If credentials are not provided in config or environment.
ImportError: If boto3 is not installed.
ValueError: If an invalid dtype is provided.
"""
try:
import boto3 # type: ignore
Expand Down Expand Up @@ -94,7 +99,7 @@ def __init__(
region_name=aws_region,
)

super().__init__(model=model, dims=self._set_model_dims(model))
super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

def _set_model_dims(self, model: str) -> int:
"""Initialize model and determine embedding dimensions."""
Expand Down Expand Up @@ -145,7 +150,7 @@ def embed(
response_body = json.loads(response["body"].read())
embedding = response_body["embedding"]

dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)
return self._process_embedding(embedding, as_buffer, dtype)

@retry(
Expand Down Expand Up @@ -181,7 +186,7 @@ def embed_many(
raise TypeError("Texts must be a list of strings")

embeddings: List[List[float]] = []
dtype = kwargs.pop("dtype", None)
dtype = kwargs.pop("dtype", self.dtype)

for batch in self.batchify(texts, batch_size, preprocess):
# Process each text in the batch individually since Bedrock
Expand Down
Loading
Loading