diff --git a/docs/user_guide/vectorizers_04.ipynb b/docs/user_guide/vectorizers_04.ipynb
index 83e18d00..06bd198f 100644
--- a/docs/user_guide/vectorizers_04.ipynb
+++ b/docs/user_guide/vectorizers_04.ipynb
@@ -31,7 +31,7 @@
 },
 {
 "cell_type": "code",
- "execution_count": 1,
+ "execution_count": 2,
 "metadata": {},
 "outputs": [],
 "source": [
@@ -305,33 +305,25 @@
 },
 {
 "cell_type": "code",
- "execution_count": 6,
+ "execution_count": 3,
 "metadata": {},
 "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
- "/Users/tyler.hutcherson/redis/redis-vl-python/.venv/lib/python3.9/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly. To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()\n",
- " return self.fget.__get__(instance, owner)()\n"
- ]
- },
 {
 "data": {
 "text/plain": [
- "[0.00037810884532518685,\n",
- " -0.05080341175198555,\n",
- " -0.03514723479747772,\n",
- " -0.02325104922056198,\n",
- " -0.044158220291137695,\n",
- " 0.020487844944000244,\n",
- " 0.0014617963461205363,\n",
- " 0.031261757016181946,\n",
+ "[0.0003780885017476976,\n",
+ " -0.05080340430140495,\n",
+ " -0.035147231072187424,\n",
+ " -0.02325103059411049,\n",
+ " -0.04415831342339516,\n",
+ " 0.02048780582845211,\n",
+ " 0.0014618589775636792,\n",
+ " 0.03126184269785881,\n",
 " 0.05605152249336243,\n",
- " 0.018815357238054276]"
+ " 0.018815429881215096]"
 ]
 },
- "execution_count": 6,
+ "execution_count": 3,
 "metadata": {},
 "output_type": "execute_result"
 }
@@ -532,14 +524,14 @@
 }
 ],
 "source": [
- "# from redisvl.utils.vectorize import MistralAITextVectorizer\n",
+ "from redisvl.utils.vectorize import MistralAITextVectorizer\n",
 "\n",
- "# mistral = MistralAITextVectorizer()\n",
+ "mistral = MistralAITextVectorizer()\n",
 "\n",
- "# # embed a sentence using their asyncronous method\n",
- "# test = await mistral.aembed(\"This is a test sentence.\")\n",
- "# print(\"Vector dimensions: \", len(test))\n",
- "# print(test[:10])"
+ "# embed a sentence using their asynchronous method\n",
+ "test = await mistral.aembed(\"This is a test sentence.\")\n",
+ "print(\"Vector dimensions: \", len(test))\n",
+ "print(test[:10])"
 ]
 },
 {
@@ -588,9 +580,17 @@
 },
 {
 "cell_type": "code",
- "execution_count": null,
+ "execution_count": 3,
 "metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Vector dimensions: 1024\n"
+ ]
+ }
+ ],
 "source": [
 "from redisvl.utils.vectorize import BedrockTextVectorizer\n",
 "\n",
@@ -823,6 +823,43 @@
 " print(doc[\"text\"], doc[\"vector_distance\"])"
 ]
 },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Selecting your float data type\n",
+ "When embedding text as byte arrays, RedisVL supports 4 different floating point data types: `float16`, `float32`, `float64`, and `bfloat16`.\n",
+ "The dtype set on your vectorizer must match what is defined in your search index. If one is not explicitly set, the default is `float32`."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vectorizer = HFTextVectorizer(dtype=\"float16\")\n", + "\n", + "# subsequent calls to embed('', as_buffer=True) and embed_many('', as_buffer=True) will now encode as float16\n", + "float16_bytes = vectorizer.embed('test sentence', as_buffer=True)\n", + "\n", + "# you can override this setting on each individual method call\n", + "float64_bytes = vectorizer.embed('test sentence', as_buffer=True, dtype=\"float64\")\n", + "\n", + "float16_bytes != float64_bytes" + ] + }, { "cell_type": "code", "execution_count": null, @@ -836,7 +873,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3.8.13 ('redisvl2')", + "display_name": "redisvl-dev", "language": "python", "name": "python3" }, @@ -852,12 +889,7 @@ "pygments_lexer": "ipython3", "version": "3.12.2" }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "9b1e6e9c2967143209c2f955cb869d1d3234f92dc4787f49f155f3abbdfb1316" - } - } + "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 diff --git a/poetry.lock b/poetry.lock index 38c7b00f..bf582bb6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. [[package]] name = "accessible-pygments" @@ -355,8 +355,8 @@ files = [ jmespath = ">=0.7.1,<2.0.0" python-dateutil = ">=2.1,<3.0.0" urllib3 = [ - {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, {version = ">=1.25.4,<2.2.0 || >2.2.0,<3", markers = "python_version >= \"3.10\""}, + {version = ">=1.25.4,<1.27", markers = "python_version < \"3.10\""}, ] [package.extras] @@ -2550,8 +2550,8 @@ files = [ [package.dependencies] numpy = [ - {version = ">1.20", markers = "python_version < \"3.10\""}, {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">1.20", markers = "python_version < \"3.10\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, ] @@ -5738,4 +5738,4 @@ vertexai = ["google-cloud-aiplatform", "protobuf"] [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.13" -content-hash = "b8f26c50b4713ac0dd90faa202c1e4d3732481aaa4382a7bb4d2cb2fd776d3a4" +content-hash = "da2883e4b839be0a25e4dd5bb2861b27881f4dca9ea8ed159999d015b07410e5" diff --git a/pyproject.toml b/pyproject.toml index c1088641..954184fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,7 +32,7 @@ sentence-transformers = { version = ">=2.2.2", optional = true } google-cloud-aiplatform = { version = ">=1.26", optional = true } protobuf = { version = ">=5.29.1,<6.0.0.dev0", optional = true } cohere = { version = ">=4.44", optional = true } -mistralai = { version = ">=0.2.0", optional = true } +mistralai = { version = ">=1.0.0", optional = true } boto3 = { version = ">=1.34.0", optional = true } [tool.poetry.extras] diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 238aa6a0..d27553be 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -5,6 +5,7 @@ from pydantic.v1 import BaseModel, validator from redisvl.redis.utils import array_to_buffer +from redisvl.schema.fields import 
VectorDataType


class Vectorizers(Enum):
@@ -19,11 +20,22 @@ class Vectorizers(Enum):
 class BaseVectorizer(BaseModel, ABC):
 model: str
 dims: int
+ dtype: str

 @property
 def type(self) -> str:
 return "base"

+ @validator("dtype")
+ def check_dtype(cls, dtype):
+ try:
+ VectorDataType(dtype.upper())
+ except ValueError:
+ raise ValueError(
+ f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}"
+ )
+ return dtype
+
 @validator("dims")
 @classmethod
 def check_dims(cls, value):
@@ -81,13 +93,7 @@ def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None):
 else:
 yield seq[pos : pos + size]

- def _process_embedding(
- self, embedding: List[float], as_buffer: bool, dtype: Optional[str]
- ):
+ def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str):
 if as_buffer:
- if not dtype:
- raise RuntimeError(
- "dtype is required if converting from float to byte string."
- )
 return array_to_buffer(embedding, dtype)

 return embedding
diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py
index a387e238..5af97dfa 100644
--- a/redisvl/utils/vectorize/text/azureopenai.py
+++ b/redisvl/utils/vectorize/text/azureopenai.py
@@ -52,7 +52,10 @@ class AzureOpenAITextVectorizer(BaseVectorizer):
 _aclient: Any = PrivateAttr()

 def __init__(
- self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None
+ self,
+ model: str = "text-embedding-ada-002",
+ api_config: Optional[Dict] = None,
+ dtype: str = "float32",
 ):
 """Initialize the AzureOpenAI vectorizer.

@@ -63,13 +66,17 @@ def __init__(
 api_config (Optional[Dict], optional): Dictionary containing the API key,
 API version, Azure endpoint, and any other API options.
 Defaults to None.
+ dtype (str): the default datatype to use when embedding text as byte arrays.
+ Used when setting `as_buffer=True` in calls to embed() and embed_many().
+ Defaults to 'float32'.

 Raises:
 ImportError: If the openai library is not installed.
 ValueError: If the AzureOpenAI API key, version, or endpoint are not provided.
+ ValueError: If an invalid dtype is provided.
""" self._initialize_clients(api_config) - super().__init__(model=model, dims=self._set_model_dims(model)) + super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) def _initialize_clients(self, api_config: Optional[Dict]): """ @@ -190,7 +197,7 @@ def embed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -234,7 +241,7 @@ def embed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) result = self._client.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @@ -274,7 +281,7 @@ async def aembed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -320,7 +327,7 @@ async def aembed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index 3414880a..091beadb 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -49,6 +49,7 @@ def __init__( self, model: str = "amazon.titan-embed-text-v2:0", api_config: Optional[Dict[str, str]] = None, + dtype: str = "float32", ) -> None: """Initialize the AWS Bedrock Vectorizer. @@ -57,10 +58,14 @@ def __init__( api_config (Optional[Dict[str, str]]): AWS credentials and config. Can include: aws_access_key_id, aws_secret_access_key, aws_region If not provided, will use environment variables. + dtype (str): the default datatype to use when embedding text as byte arrays. + Used when setting `as_buffer=True` in calls to embed() and embed_many(). + Defaults to 'float32'. Raises: ValueError: If credentials are not provided in config or environment. ImportError: If boto3 is not installed. + ValueError: If an invalid dtype is provided. 
""" try: import boto3 # type: ignore @@ -94,7 +99,7 @@ def __init__( region_name=aws_region, ) - super().__init__(model=model, dims=self._set_model_dims(model)) + super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) def _set_model_dims(self, model: str) -> int: """Initialize model and determine embedding dimensions.""" @@ -145,7 +150,7 @@ def embed( response_body = json.loads(response["body"].read()) embedding = response_body["embedding"] - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) return self._process_embedding(embedding, as_buffer, dtype) @retry( @@ -181,7 +186,7 @@ def embed_many( raise TypeError("Texts must be a list of strings") embeddings: List[List[float]] = [] - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) for batch in self.batchify(texts, batch_size, preprocess): # Process each text in the batch individually since Bedrock diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 469035fa..c30863b7 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -47,7 +47,10 @@ class CohereTextVectorizer(BaseVectorizer): _client: Any = PrivateAttr() def __init__( - self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None + self, + model: str = "embed-english-v3.0", + api_config: Optional[Dict] = None, + dtype: str = "float32", ): """Initialize the Cohere vectorizer. @@ -57,14 +60,17 @@ def __init__( model (str): Model to use for embedding. Defaults to 'embed-english-v3.0'. api_config (Optional[Dict], optional): Dictionary containing the API key. Defaults to None. + dtype (str): the default datatype to use when embedding text as byte arrays. + Used when setting `as_buffer=True` in calls to embed() and embed_many(). + Defaults to 'float32'. Raises: ImportError: If the cohere library is not installed. ValueError: If the API key is not provided. - + ValueError: If an invalid dtype is provided. """ self._initialize_client(api_config) - super().__init__(model=model, dims=self._set_model_dims(model)) + super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) def _initialize_client(self, api_config: Optional[Dict]): """ @@ -159,7 +165,7 @@ def embed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embedding = self._client.embed( texts=[text], model=self.model, input_type=input_type @@ -228,7 +234,7 @@ def embed_many( See https://docs.cohere.com/reference/embed." ) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index a950f6df..6c8787d5 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -7,7 +7,7 @@ class CustomTextVectorizer(BaseVectorizer): - """The CustomTextVectorizer class wraps user-defined embeding methods to create + """The CustomTextVectorizer class wraps user-defined embedding methods to create embeddings for text data. This vectorizer is designed to accept a provided callable text vectorizer and @@ -44,6 +44,7 @@ def __init__( embed_many: Optional[Callable] = None, aembed: Optional[Callable] = None, aembed_many: Optional[Callable] = None, + dtype: str = "float32", ): """Initialize the Custom vectorizer. 
@@ -52,10 +53,14 @@ def __init__(
 embed_many (Optional[Callable]): a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
 aembed (Optional[Callable]): an asynchronous Callable function that accepts a string object and returns a list of floats. Defaults to None.
 aembed_many (Optional[Callable]): an asynchronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None.
+ dtype (str): the default datatype to use when embedding text as byte arrays.
+ Used when setting `as_buffer=True` in calls to embed() and embed_many().
+ Defaults to 'float32'.

 Raises:
- ValueError if any of the provided functions accept or return incorrect types.
- TypeError if any of the provided functions are not Callable objects.
+ ValueError: if any of the provided functions accept or return incorrect types.
+ TypeError: if any of the provided functions are not Callable objects.
+ ValueError: If an invalid dtype is provided.
 """
 self._validate_embed(embed)
@@ -71,7 +76,7 @@ def __init__(
 self._validate_aembed_many(aembed_many)
 self._aembed_many_func = aembed_many

- super().__init__(model=self.type, dims=self._set_model_dims())
+ super().__init__(model=self.type, dims=self._set_model_dims(), dtype=dtype)

 def _validate_embed(self, func: Callable):
 """calls the func with dummy input and validates that it returns a vector"""
@@ -173,7 +178,7 @@ def embed(
 if preprocess:
 text = preprocess(text)

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 result = self._embed_func(text, **kwargs)
 return self._process_embedding(result, as_buffer, dtype)
@@ -212,7 +217,7 @@ def embed_many(
 if not self._embed_many_func:
 raise NotImplementedError

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 embeddings: List = []
 for batch in self.batchify(texts, batch_size, preprocess):
@@ -254,7 +259,7 @@ async def aembed(
 if preprocess:
 text = preprocess(text)

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 result = await self._aembed_func(text, **kwargs)
 return self._process_embedding(result, as_buffer, dtype)
@@ -293,7 +298,7 @@ async def aembed_many(
 if not self._aembed_many_func:
 raise NotImplementedError

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 embeddings: List = []
 for batch in self.batchify(texts, batch_size, preprocess):
diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py
index b570a03f..b2fbabc0 100644
--- a/redisvl/utils/vectorize/text/huggingface.py
+++ b/redisvl/utils/vectorize/text/huggingface.py
@@ -33,7 +33,10 @@ class HFTextVectorizer(BaseVectorizer):
 _client: Any = PrivateAttr()

 def __init__(
- self, model: str = "sentence-transformers/all-mpnet-base-v2", **kwargs
+ self,
+ model: str = "sentence-transformers/all-mpnet-base-v2",
+ dtype: str = "float32",
+ **kwargs,
 ):
 """Initialize the Hugging Face text vectorizer.

 Args:
 model (str): The pre-trained model from Hugging Face's Sentence
 Transformers to be used for embedding. Defaults to
 'sentence-transformers/all-mpnet-base-v2'.
+ dtype (str): the default datatype to use when embedding text as byte arrays.
+ Used when setting `as_buffer=True` in calls to embed() and embed_many().
+ Defaults to 'float32'.

 Raises:
 ImportError: If the sentence-transformers library is not installed.
 ValueError: If there is an error setting the embedding model dimensions.
+ ValueError: If an invalid dtype is provided.
 """
 self._initialize_client(model)
- super().__init__(model=model, dims=self._set_model_dims())
+ super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype)

 def _initialize_client(self, model: str):
 """Setup the HuggingFace client"""
@@ -100,7 +107,7 @@ def embed(
 if preprocess:
 text = preprocess(text)

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 embedding = self._client.encode([text], **kwargs)[0]
 return self._process_embedding(embedding.tolist(), as_buffer, dtype)
@@ -136,7 +143,7 @@ def embed_many(
 if len(texts) > 0 and not isinstance(texts[0], str):
 raise TypeError("Must pass in a list of str values to embed.")

- dtype = kwargs.pop("dtype", None)
+ dtype = kwargs.pop("dtype", self.dtype)

 embeddings: List = []
 for batch in self.batchify(texts, batch_size, preprocess):
diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py
index 28377778..bb636b33 100644
--- a/redisvl/utils/vectorize/text/mistral.py
+++ b/redisvl/utils/vectorize/text/mistral.py
@@ -44,9 +44,13 @@ class MistralAITextVectorizer(BaseVectorizer):
 """

 _client: Any = PrivateAttr()
- _aclient: Any = PrivateAttr()

- def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = None):
+ def __init__(
+ self,
+ model: str = "mistral-embed",
+ api_config: Optional[Dict] = None,
+ dtype: str = "float32",
+ ):
 """Initialize the MistralAI vectorizer.

 Args:
 model (str): Model to use for embedding. Defaults to
@@ -54,13 +58,17 @@ def __init__(self, model: str = "mistral-embed", api_config: Optional[Dict] = No
 'mistral-embed'.
 api_config (Optional[Dict], optional): Dictionary containing the API key.
 Defaults to None.
+ dtype (str): the default datatype to use when embedding text as byte arrays.
+ Used when setting `as_buffer=True` in calls to embed() and embed_many().
+ Defaults to 'float32'.

 Raises:
 ImportError: If the mistralai library is not installed.
 ValueError: If the Mistral API key is not provided.
+ ValueError: If an invalid dtype is provided.
 """
 self._initialize_clients(api_config)
- super().__init__(model=model, dims=self._set_model_dims(model))
+ super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype)

 def _initialize_clients(self, api_config: Optional[Dict]):
 """
@@ -69,8 +77,7 @@ def _initialize_clients(self, api_config: Optional[Dict]):
 """
 # Dynamic import of the mistralai module
 try:
- from mistralai.async_client import MistralAsyncClient
- from mistralai.client import MistralClient
+ from mistralai import Mistral
 except ImportError:
 raise ImportError(
 "MistralAI vectorizer requires the mistralai library. \
@@ -88,13 +95,12 @@ def _initialize_clients(self, api_config: Optional[Dict]):
 environment variable."
) - self._client = MistralClient(api_key=api_key) - self._aclient = MistralAsyncClient(api_key=api_key) + self._client = Mistral(api_key=api_key) def _set_model_dims(self, model) -> int: try: embedding = ( - self._client.embeddings(model=model, input=["dimension test"]) + self._client.embeddings.create(model=model, inputs=["dimension test"]) .data[0] .embedding ) @@ -140,11 +146,11 @@ def embed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings(model=self.model, input=batch) + response = self._client.embeddings.create(model=self.model, inputs=batch) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -184,9 +190,9 @@ def embed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) - result = self._client.embeddings(model=self.model, input=[text]) + result = self._client.embeddings.create(model=self.model, inputs=[text]) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @retry( @@ -224,11 +230,13 @@ async def aembed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = await self._aclient.embeddings(model=self.model, input=batch) + response = await self._client.embeddings.create_async( + model=self.model, inputs=batch + ) embeddings += [ self._process_embedding(r.embedding, as_buffer, dtype) for r in response.data @@ -268,9 +276,11 @@ async def aembed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) - result = await self._aclient.embeddings(model=self.model, input=[text]) + result = await self._client.embeddings.create_async( + model=self.model, inputs=[text] + ) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @property diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index aad29198..7dd2cf93 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -47,7 +47,10 @@ class OpenAITextVectorizer(BaseVectorizer): _aclient: Any = PrivateAttr() def __init__( - self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None + self, + model: str = "text-embedding-ada-002", + api_config: Optional[Dict] = None, + dtype: str = "float32", ): """Initialize the OpenAI vectorizer. @@ -56,13 +59,17 @@ def __init__( 'text-embedding-ada-002'. api_config (Optional[Dict], optional): Dictionary containing the API key and any additional OpenAI API options. Defaults to None. + dtype (str): the default datatype to use when embedding text as byte arrays. + Used when setting `as_buffer=True` in calls to embed() and embed_many(). + Defaults to 'float32'. Raises: ImportError: If the openai library is not installed. ValueError: If the OpenAI API key is not provided. + ValueError: If an invalid dtype is provided. 
""" self._initialize_clients(api_config) - super().__init__(model=model, dims=self._set_model_dims(model)) + super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) def _initialize_clients(self, api_config: Optional[Dict]): """ @@ -144,7 +151,7 @@ def embed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -188,7 +195,7 @@ def embed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) result = self._client.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) @@ -228,7 +235,7 @@ async def aembed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -274,7 +281,7 @@ async def aembed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) result = await self._aclient.embeddings.create(input=[text], model=self.model) return self._process_embedding(result.data[0].embedding, as_buffer, dtype) diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index f0c3c475..6ddc28b6 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -44,7 +44,10 @@ class VertexAITextVectorizer(BaseVectorizer): _client: Any = PrivateAttr() def __init__( - self, model: str = "textembedding-gecko", api_config: Optional[Dict] = None + self, + model: str = "textembedding-gecko", + api_config: Optional[Dict] = None, + dtype: str = "float32", ): """Initialize the VertexAI vectorizer. @@ -53,13 +56,17 @@ def __init__( 'textembedding-gecko'. api_config (Optional[Dict], optional): Dictionary containing the API config details. Defaults to None. + dtype (str): the default datatype to use when embedding text as byte arrays. + Used when setting `as_buffer=True` in calls to embed() and embed_many(). + Defaults to 'float32'. Raises: ImportError: If the google-cloud-aiplatform library is not installed. ValueError: If the API key is not provided. + ValueError: If an invalid dtype is provided. 
""" self._initialize_client(model, api_config) - super().__init__(model=model, dims=self._set_model_dims()) + super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype) def _initialize_client(self, model: str, api_config: Optional[Dict]): """ @@ -151,7 +158,7 @@ def embed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): @@ -194,7 +201,7 @@ def embed( if preprocess: text = preprocess(text) - dtype = kwargs.pop("dtype", None) + dtype = kwargs.pop("dtype", self.dtype) result = self._client.get_embeddings([text]) return self._process_embedding(result[0].values, as_buffer, dtype) diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 94cc7fd0..dd3112fa 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -28,7 +28,7 @@ def skip_vectorizer() -> bool: CohereTextVectorizer, AzureOpenAITextVectorizer, BedrockTextVectorizer, - # MistralAITextVectorizer, + MistralAITextVectorizer, CustomTextVectorizer, ] ) @@ -238,11 +238,68 @@ def bad_return_type(text: str) -> str: ) +@pytest.mark.parametrize( + "vector_class", + [ + AzureOpenAITextVectorizer, + BedrockTextVectorizer, + CohereTextVectorizer, + CustomTextVectorizer, + HFTextVectorizer, + # MistralAITextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + ], +) +def test_dtypes(vector_class, skip_vectorizer): + if skip_vectorizer: + pytest.skip("Skipping vectorizer instantiation...") + + # test dtype defaults to float32 + if issubclass(vector_class, CustomTextVectorizer): + vectorizer = vector_class(embed=lambda x, input_type=None: [1.0, 2.0, 3.0]) + elif issubclass(vector_class, AzureOpenAITextVectorizer): + vectorizer = vector_class( + model=os.getenv("AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002") + ) + else: + vectorizer = vector_class() + assert vectorizer.dtype == "float32" + + # test initializing dtype in constructor + for dtype in ["float16", "float32", "float64", "bfloat16"]: + if issubclass(vector_class, CustomTextVectorizer): + vectorizer = vector_class(embed=lambda x: [1.0, 2.0, 3.0], dtype=dtype) + elif issubclass(vector_class, AzureOpenAITextVectorizer): + vectorizer = vector_class( + model=os.getenv( + "AZURE_OPENAI_DEPLOYMENT_NAME", "text-embedding-ada-002" + ), + dtype=dtype, + ) + else: + vectorizer = vector_class(dtype=dtype) + assert vectorizer.dtype == dtype + + # test validation of dtype on init + if issubclass(vector_class, CustomTextVectorizer): + pytest.skip("skipping custom text vectorizer") + + with pytest.raises(ValueError): + vectorizer = vector_class(dtype="float25") + + with pytest.raises(ValueError): + vectorizer = vector_class(dtype=7) + + with pytest.raises(ValueError): + vectorizer = vector_class(dtype=None) + + @pytest.fixture( params=[ OpenAITextVectorizer, BedrockTextVectorizer, - # MistralAITextVectorizer, + MistralAITextVectorizer, CustomTextVectorizer, ] )