From 4c5b045576db8d829c40a956ffbc37e474f2e189 Mon Sep 17 00:00:00 2001 From: nabsabraham Date: Fri, 12 Jan 2024 17:01:50 -0500 Subject: [PATCH] lint --- redisvl/vectorize/text/__init__.py | 9 ++++- redisvl/vectorize/text/cohere.py | 58 +++++++++++++++++---------- tests/integration/test_vectorizers.py | 14 ++++++- 3 files changed, 55 insertions(+), 26 deletions(-) diff --git a/redisvl/vectorize/text/__init__.py b/redisvl/vectorize/text/__init__.py index 3cbf5fa7..633c7c1e 100644 --- a/redisvl/vectorize/text/__init__.py +++ b/redisvl/vectorize/text/__init__.py @@ -1,6 +1,11 @@ +from redisvl.vectorize.text.cohere import CohereTextVectorizer from redisvl.vectorize.text.huggingface import HFTextVectorizer from redisvl.vectorize.text.openai import OpenAITextVectorizer from redisvl.vectorize.text.vertexai import VertexAITextVectorizer -from redisvl.vectorize.text.cohere import CohereTextVectorizer -__all__ = ["OpenAITextVectorizer", "HFTextVectorizer", "VertexAITextVectorizer", "CohereTextVectorizer"] +__all__ = [ + "OpenAITextVectorizer", + "HFTextVectorizer", + "VertexAITextVectorizer", + "CohereTextVectorizer", +] diff --git a/redisvl/vectorize/text/cohere.py b/redisvl/vectorize/text/cohere.py index 954d4aff..31c84324 100644 --- a/redisvl/vectorize/text/cohere.py +++ b/redisvl/vectorize/text/cohere.py @@ -1,25 +1,29 @@ import os from typing import Callable, Dict, List, Optional + from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type from redisvl.vectorize.base import BaseVectorizer + class CohereTextVectorizer(BaseVectorizer): """ The CohereTextVectorizer class utilizes Cohere's API to generate embeddings for text data. - This vectorizer is designed to interact with Cohere's /embed API, requiring an API key for authentication. The key - can be provided directly in the `api_config` dictionary or through the `COHERE_API_KEY` environment variable. Users - must obtain an API key from Cohere's website (https://dashboard.cohere.com/). Additionally, the `cohere` python + This vectorizer is designed to interact with Cohere's /embed API, requiring an API key for authentication. The key + can be provided directly in the `api_config` dictionary or through the `COHERE_API_KEY` environment variable. Users + must obtain an API key from Cohere's website (https://dashboard.cohere.com/). Additionally, the `cohere` python client must be installed with `pip install cohere`. - The vectorizer supports only synchronous operations, allows for batch processing of texts and flexibility in + The vectorizer supports only synchronous operations, allows for batch processing of texts and flexibility in handling preprocessing tasks. """ - def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict] = None): + def __init__( + self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None + ): """Initialize the Cohere vectorizer. Visit https://cohere.ai/embed to learn about embeddings. Args: @@ -40,7 +44,7 @@ def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict] raise ImportError( "Cohere vectorizer requires the cohere library. Please install with `pip install cohere`" ) - + # Fetch the API key from api_config or environment variable api_key = ( api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY") @@ -50,15 +54,17 @@ def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict] "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." ) - + self._model = model self._model_client = cohere.Client(api_key) self._dims = self._set_model_dims() - + def _set_model_dims(self) -> int: try: embedding = self._model_client.embed( - texts=["dimension test"], model=self._model, input_type="search_document", + texts=["dimension test"], + model=self._model, + input_type="search_document", ).embeddings[0] except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}") @@ -66,11 +72,11 @@ def _set_model_dims(self) -> int: # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) - + def embed( self, text: str, - input_type: str, + input_type: str, preprocess: Optional[Callable] = None, as_buffer: bool = False, ) -> List[float]: @@ -78,8 +84,8 @@ def embed( Args: text (str): Chunk of text to embed. - input_type (str): Specifies the type of input you're giving to the model. - Not required for older versions of the embedding models (i.e. anything lower than v3), but is required + input_type (str): Specifies the type of input you're giving to the model. + Not required for older versions of the embedding models (i.e. anything lower than v3), but is required for more recent versions (i.e. anything bigger than v2). preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. @@ -95,10 +101,14 @@ def embed( if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") if not isinstance(input_type, str): - raise TypeError("Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed.") + raise TypeError( + "Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed." + ) if preprocess: text = preprocess(text) - embedding = self._model_client.embed(texts=[text], model=self._model, input_type=input_type).embeddings[0] + embedding = self._model_client.embed( + texts=[text], model=self._model, input_type=input_type + ).embeddings[0] return self._process_embedding(embedding, as_buffer) @retry( @@ -109,7 +119,7 @@ def embed( def embed_many( self, texts: List[str], - input_type: str, + input_type: str, preprocess: Optional[Callable] = None, batch_size: int = 10, as_buffer: bool = False, @@ -118,9 +128,9 @@ def embed_many( Args: texts (List[str]): List of text chunks to embed. - input_type (str): Specifies the type of input you're giving to the model. - Not required for older versions of the embedding models (i.e. anything lower than v3), but is required - for more recent versions (i.e. anything bigger than v2). + input_type (str): Specifies the type of input you're giving to the model. + Not required for older versions of the embedding models (i.e. anything lower than v3), but is required + for more recent versions (i.e. anything bigger than v2). preprocess (Optional[Callable], optional): Optional preprocessing callable to perform before vectorization. Defaults to None. batch_size (int, optional): Batch size of texts to use when creating @@ -139,12 +149,16 @@ def embed_many( if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") if not isinstance(input_type, str): - raise TypeError("Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed.") + raise TypeError( + "Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed." + ) embeddings: List = [] for batch in self.batchify(texts, batch_size, preprocess): - response = self._model_client.embed(texts=batch, model=self._model, input_type=input_type) + response = self._model_client.embed( + texts=batch, model=self._model, input_type=input_type + ) embeddings += [ self._process_embedding(embedding, as_buffer) for embedding in response.embeddings ] - return embeddings \ No newline at end of file + return embeddings diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index bd65d5e3..d84ff980 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -3,10 +3,10 @@ import pytest from redisvl.vectorize.text import ( + CohereTextVectorizer, HFTextVectorizer, OpenAITextVectorizer, VertexAITextVectorizer, - CohereTextVectorizer ) @@ -20,7 +20,15 @@ def skip_vectorizer() -> bool: skip_vectorizer_test = lambda: pytest.config.getfixturevalue("skip_vectorizer") -@pytest.fixture(params=[HFTextVectorizer, OpenAITextVectorizer, VertexAITextVectorizer, CohereTextVectorizer]) + +@pytest.fixture( + params=[ + HFTextVectorizer, + OpenAITextVectorizer, + VertexAITextVectorizer, + CohereTextVectorizer, + ] +) def vectorizer(request): if request.param == HFTextVectorizer: return request.param() @@ -31,6 +39,7 @@ def vectorizer(request): elif request.param == CohereTextVectorizer: return request.param() + @pytest.mark.skipif(skip_vectorizer_test, reason="Skipping vectorizer tests") def test_vectorizer_embed(vectorizer): text = "This is a test sentence." @@ -76,6 +85,7 @@ def avectorizer(request, openai_key): if request.param == OpenAITextVectorizer: return request.param() + @pytest.mark.skipif(skip_vectorizer_test, reason="Skipping vectorizer tests") @pytest.mark.asyncio async def test_vectorizer_aembed(avectorizer):