Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
nabsabraham committed Jan 12, 2024
1 parent c039f13 commit 4c5b045
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 26 deletions.
9 changes: 7 additions & 2 deletions redisvl/vectorize/text/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from redisvl.vectorize.text.cohere import CohereTextVectorizer
from redisvl.vectorize.text.huggingface import HFTextVectorizer
from redisvl.vectorize.text.openai import OpenAITextVectorizer
from redisvl.vectorize.text.vertexai import VertexAITextVectorizer
from redisvl.vectorize.text.cohere import CohereTextVectorizer

__all__ = ["OpenAITextVectorizer", "HFTextVectorizer", "VertexAITextVectorizer", "CohereTextVectorizer"]
__all__ = [
"OpenAITextVectorizer",
"HFTextVectorizer",
"VertexAITextVectorizer",
"CohereTextVectorizer",
]
58 changes: 36 additions & 22 deletions redisvl/vectorize/text/cohere.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,29 @@
import os
from typing import Callable, Dict, List, Optional

from tenacity import retry, stop_after_attempt, wait_random_exponential
from tenacity.retry import retry_if_not_exception_type

from redisvl.vectorize.base import BaseVectorizer


class CohereTextVectorizer(BaseVectorizer):
"""
The CohereTextVectorizer class utilizes Cohere's API to generate embeddings
for text data.
This vectorizer is designed to interact with Cohere's /embed API, requiring an API key for authentication. The key
can be provided directly in the `api_config` dictionary or through the `COHERE_API_KEY` environment variable. Users
must obtain an API key from Cohere's website (https://dashboard.cohere.com/). Additionally, the `cohere` python
This vectorizer is designed to interact with Cohere's /embed API, requiring an API key for authentication. The key
can be provided directly in the `api_config` dictionary or through the `COHERE_API_KEY` environment variable. Users
must obtain an API key from Cohere's website (https://dashboard.cohere.com/). Additionally, the `cohere` python
client must be installed with `pip install cohere`.
The vectorizer supports only synchronous operations, allows for batch processing of texts and flexibility in
The vectorizer supports only synchronous operations, allows for batch processing of texts and flexibility in
handling preprocessing tasks.
"""

def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict] = None):
def __init__(
self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None
):
"""Initialize the Cohere vectorizer. Visit https://cohere.ai/embed to learn about embeddings.
Args:
Expand All @@ -40,7 +44,7 @@ def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict]
raise ImportError(
"Cohere vectorizer requires the cohere library. Please install with `pip install cohere`"
)

# Fetch the API key from api_config or environment variable
api_key = (
api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY")
Expand All @@ -50,36 +54,38 @@ def __init__(self, model: str = 'embed-english-v3.0', api_config: Optional[Dict]
"Cohere API key is required. "
"Provide it in api_config or set the COHERE_API_KEY environment variable."
)

self._model = model
self._model_client = cohere.Client(api_key)
self._dims = self._set_model_dims()

# Determine the embedding dimensionality of the configured model by embedding
# a fixed probe string and returning the length of the first vector returned.
# Any failure while doing so is surfaced to the caller as a ValueError.
def _set_model_dims(self) -> int:
try:
embedding = self._model_client.embed(
# NOTE(review): this page is a rendered diff without +/- markers; the
# one-line argument list below appears to be the pre-lint form of the three
# lines that follow it — confirm against the actual repository file.
texts=["dimension test"], model=self._model, input_type="search_document",
texts=["dimension test"],
model=self._model,
input_type="search_document",
).embeddings[0]
# A KeyError/IndexError (e.g. indexing an empty `embeddings` list) is reported
# as an unexpected API response.
except (KeyError, IndexError) as ke:
raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}")
except Exception as e: # pylint: disable=broad-except
# fall back (TODO get more specific)
# Every other exception type is converted to ValueError carrying the
# original message text.
raise ValueError(f"Error setting embedding model dimensions: {str(e)}")
# The dimensionality is the number of components in the probe embedding.
return len(embedding)

def embed(
self,
text: str,
input_type: str,
input_type: str,
preprocess: Optional[Callable] = None,
as_buffer: bool = False,
) -> List[float]:
"""Embed a chunk of text using the Cohere API.
Args:
text (str): Chunk of text to embed.
input_type (str): Specifies the type of input you're giving to the model.
Not required for older versions of the embedding models (i.e. anything lower than v3), but is required
input_type (str): Specifies the type of input you're giving to the model.
Not required for older versions of the embedding models (i.e. anything lower than v3), but is required
for more recent versions (i.e. anything bigger than v2).
preprocess (Optional[Callable], optional): Optional preprocessing callable to
perform before vectorization. Defaults to None.
Expand All @@ -95,10 +101,14 @@ def embed(
if not isinstance(text, str):
raise TypeError("Must pass in a str value to embed.")
if not isinstance(input_type, str):
raise TypeError("Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed.")
raise TypeError(
"Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed."
)
if preprocess:
text = preprocess(text)
embedding = self._model_client.embed(texts=[text], model=self._model, input_type=input_type).embeddings[0]
embedding = self._model_client.embed(
texts=[text], model=self._model, input_type=input_type
).embeddings[0]
return self._process_embedding(embedding, as_buffer)

@retry(
Expand All @@ -109,7 +119,7 @@ def embed(
def embed_many(
self,
texts: List[str],
input_type: str,
input_type: str,
preprocess: Optional[Callable] = None,
batch_size: int = 10,
as_buffer: bool = False,
Expand All @@ -118,9 +128,9 @@ def embed_many(
Args:
texts (List[str]): List of text chunks to embed.
input_type (str): Specifies the type of input you're giving to the model.
Not required for older versions of the embedding models (i.e. anything lower than v3), but is required
for more recent versions (i.e. anything bigger than v2).
input_type (str): Specifies the type of input you're giving to the model.
Not required for older versions of the embedding models (i.e. anything lower than v3), but is required
for more recent versions (i.e. anything bigger than v2).
preprocess (Optional[Callable], optional): Optional preprocessing callable to
perform before vectorization. Defaults to None.
batch_size (int, optional): Batch size of texts to use when creating
Expand All @@ -139,12 +149,16 @@ def embed_many(
if len(texts) > 0 and not isinstance(texts[0], str):
raise TypeError("Must pass in a list of str values to embed.")
if not isinstance(input_type, str):
raise TypeError("Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed.")
raise TypeError(
"Must pass in a str value for input_type. See https://docs.cohere.com/reference/embed."
)
embeddings: List = []
for batch in self.batchify(texts, batch_size, preprocess):
response = self._model_client.embed(texts=batch, model=self._model, input_type=input_type)
response = self._model_client.embed(
texts=batch, model=self._model, input_type=input_type
)
embeddings += [
self._process_embedding(embedding, as_buffer)
for embedding in response.embeddings
]
return embeddings
return embeddings
14 changes: 12 additions & 2 deletions tests/integration/test_vectorizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import pytest

from redisvl.vectorize.text import (
CohereTextVectorizer,
HFTextVectorizer,
OpenAITextVectorizer,
VertexAITextVectorizer,
CohereTextVectorizer
)


Expand All @@ -20,7 +20,15 @@ def skip_vectorizer() -> bool:

# Condition object handed to the `pytest.mark.skipif` markers on the vectorizer
# tests in this file.
# NOTE(review): the markers receive this lambda itself, not its result; a
# function object is always truthy, so this presumably skips those tests
# unconditionally — confirm. Also, `pytest.config` looks like the global that
# was removed in pytest 5.0; verify against the pinned pytest version.
skip_vectorizer_test = lambda: pytest.config.getfixturevalue("skip_vectorizer")

@pytest.fixture(params=[HFTextVectorizer, OpenAITextVectorizer, VertexAITextVectorizer, CohereTextVectorizer])

@pytest.fixture(
params=[
HFTextVectorizer,
OpenAITextVectorizer,
VertexAITextVectorizer,
CohereTextVectorizer,
]
)
def vectorizer(request):
if request.param == HFTextVectorizer:
return request.param()
Expand All @@ -31,6 +39,7 @@ def vectorizer(request):
elif request.param == CohereTextVectorizer:
return request.param()


@pytest.mark.skipif(skip_vectorizer_test, reason="Skipping vectorizer tests")
def test_vectorizer_embed(vectorizer):
text = "This is a test sentence."
Expand Down Expand Up @@ -76,6 +85,7 @@ def avectorizer(request, openai_key):
if request.param == OpenAITextVectorizer:
return request.param()


@pytest.mark.skipif(skip_vectorizer_test, reason="Skipping vectorizer tests")
@pytest.mark.asyncio
async def test_vectorizer_aembed(avectorizer):
Expand Down

0 comments on commit 4c5b045

Please sign in to comment.