Skip to content

Commit

Permalink
add Cohere embedding model support to allow users to create embedding…
Browse files Browse the repository at this point in the history
… indexes using cohere models (microsoft#829)

Co-authored-by: Sophie Chen <sophiechen@microsoft.com>
  • Loading branch information
SophieGarden and Sophie Chen authored May 7, 2024
1 parent 08f75ac commit acf6d79
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 8 deletions.
20 changes: 20 additions & 0 deletions scripts/.env.sample
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# resource switch
FLAG_EMBEDDING_MODEL = "AOAI" # "AOAI" or "COHERE"
FLAG_COHERE = "ENGLISH" # "MULTILINGUAL" or "ENGLISH" options for Cohere embedding models

# update vector dimension based on model chosen
VECTOR_DIMENSION = 1536 # set to the dimension of the chosen model, e.g., 1536 for AOAI ada-002, 1024 for Cohere

# AOAI resource
AZURE_OPENAI_API_VERSION = '2023-05-15'
AZURE_OPENAI_ENDPOINT = ""
AZURE_OPENAI_API_KEY = ""

# Cohere multilingual resource
COHERE_MULTILINGUAL_ENDPOINT = ""
COHERE_MULTILINGUAL_API_KEY = ""

# Cohere English resource
COHERE_ENGLISH_ENDPOINT = ""
COHERE_ENGLISH_API_KEY = ""

8 changes: 6 additions & 2 deletions scripts/data_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,21 @@
import json
import os
import subprocess
import time

import requests
import time
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.core.credentials import AzureKeyCredential
from azure.identity import AzureCliCredential
from azure.search.documents import SearchClient
from dotenv import load_dotenv
from tqdm import tqdm

from data_utils import chunk_directory, chunk_blob_container

# Configure environment variables
load_dotenv() # take environment variables from .env.

SUPPORTED_LANGUAGE_CODES = {
"ar": "Arabic",
"hy": "Armenian",
Expand Down Expand Up @@ -228,7 +232,7 @@ def create_or_update_search_index(
"type": "Collection(Edm.Single)",
"searchable": True,
"retrievable": True,
"dimensions": 1536,
"dimensions": os.getenv("VECTOR_DIMENSION", 1536),
"vectorSearchConfiguration": vector_config_name
})

Expand Down
68 changes: 62 additions & 6 deletions scripts/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,32 @@
import json
import os
import re
import requests
from openai import AzureOpenAI
import re
import ssl
import subprocess
import tempfile
import time
import urllib.request
from abc import ABC, abstractmethod
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from functools import partial
from typing import Callable, List, Dict, Optional, Generator, Tuple, Union
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import markdown
import requests
import tiktoken
from azure.identity import DefaultAzureCredential
from azure.ai.formrecognizer import DocumentAnalysisClient
from azure.core.credentials import AzureKeyCredential
from azure.identity import DefaultAzureCredential
from azure.storage.blob import ContainerClient
from bs4 import BeautifulSoup
from dotenv import load_dotenv
from langchain.text_splitter import TextSplitter, MarkdownTextSplitter, RecursiveCharacterTextSplitter, PythonCodeTextSplitter
from openai import AzureOpenAI
from tqdm import tqdm
from typing import Any

# Configure environment variables
load_dotenv() # take environment variables from .env.

FILE_FORMAT_DICT = {
"md": "markdown",
Expand Down Expand Up @@ -632,7 +636,59 @@ def unmask_urls(text, url_dict={}):
if total_size > 0:
yield current_chunk, total_size

def get_payload_and_headers_cohere(
    text, aad_token) -> Tuple[Dict, Dict]:
    """Build the JSON body and auth headers for a Cohere embedding request.

    Args:
        text: The single document to embed; sent as a one-element "texts" list.
        aad_token: Bearer token (API key) placed in the Authorization header.

    Returns:
        Tuple[Dict, Dict]: (request body, request headers) ready for posting
        to the Cohere embeddings endpoint.
    """
    request_body = {"texts": [text], "input_type": "search_document"}
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {aad_token}",
    }
    return request_body, request_headers

def get_embedding(text, embedding_model_endpoint=None, embedding_model_key=None, azure_credential=None):
    """Embed `text` with either an Azure OpenAI or a Cohere embedding model.

    The model family is selected by the FLAG_EMBEDDING_MODEL env var ("AOAI"
    or "COHERE"); for Cohere, FLAG_COHERE picks the "MULTILINGUAL" or
    "ENGLISH" deployment's key.

    Args:
        text: Input text to embed.
        embedding_model_endpoint: Embeddings URL; falls back to the
            EMBEDDING_MODEL_ENDPOINT environment variable.
        embedding_model_key: API key; falls back to the model-specific
            environment variable (AZURE_OPENAI_API_KEY or COHERE_*_API_KEY).
        azure_credential: Optional Azure credential used (AOAI only) to mint
            an AAD token instead of using an API key.

    Returns:
        The embedding vector (list of floats) for `text`.

    Raises:
        Exception: If required configuration is missing, the flags are
            unrecognized, or the remote call fails.
    """
    endpoint = embedding_model_endpoint if embedding_model_endpoint else os.environ.get("EMBEDDING_MODEL_ENDPOINT")

    FLAG_EMBEDDING_MODEL = os.getenv("FLAG_EMBEDDING_MODEL", "AOAI")
    FLAG_COHERE = os.getenv("FLAG_COHERE", "ENGLISH")

    # Resolve the API key up front so the validation below cannot hit an
    # undefined name (previously `key` was checked before ever being assigned,
    # and an unrecognized FLAG_COHERE left it undefined).
    if embedding_model_key:
        key = embedding_model_key
    elif FLAG_EMBEDDING_MODEL == "AOAI":
        key = os.getenv("AZURE_OPENAI_API_KEY")
    elif FLAG_COHERE == "MULTILINGUAL":
        key = os.getenv("COHERE_MULTILINGUAL_API_KEY")
    else:
        key = os.getenv("COHERE_ENGLISH_API_KEY")

    if azure_credential is None and (endpoint is None or key is None):
        raise Exception("EMBEDDING_MODEL_ENDPOINT and EMBEDDING_MODEL_KEY are required for embedding")

    try:
        if FLAG_EMBEDDING_MODEL == "AOAI":
            # The AOAI endpoint embeds the deployment id and api-version in
            # its URL, e.g.
            # https://<res>.openai.azure.com/openai/deployments/<dep>/embeddings?api-version=...
            endpoint_parts = endpoint.split("/openai/deployments/")
            base_url = endpoint_parts[0]
            deployment_id = endpoint_parts[1].split("/embeddings")[0]
            api_version = endpoint_parts[1].split("api-version=")[1].split("&")[0]
            if azure_credential is not None:
                # AAD token auth.
                aad_token = azure_credential.get_token("https://cognitiveservices.azure.com/.default").token
                client = AzureOpenAI(api_version=api_version, azure_endpoint=base_url, azure_ad_token=aad_token)
            else:
                # API-key auth: pass the key as api_key, not as an AAD token
                # (previously the key was wrongly passed as azure_ad_token).
                client = AzureOpenAI(api_version=api_version, azure_endpoint=base_url, api_key=key)
            embeddings = client.embeddings.create(model=deployment_id, input=text)

            return embeddings.dict()['data'][0]['embedding']

        if FLAG_EMBEDDING_MODEL == "COHERE":
            data, headers = get_payload_and_headers_cohere(text, key)
            body = str.encode(json.dumps(data))
            req = urllib.request.Request(endpoint, body, headers)
            # Close the HTTP response deterministically.
            with urllib.request.urlopen(req) as response:
                result_content = json.loads(response.read().decode('utf-8'))

            return result_content["embeddings"][0]

        # Previously an unrecognized flag fell through and returned None;
        # fail loudly instead so misconfiguration is caught immediately.
        raise ValueError(
            f"Unsupported FLAG_EMBEDDING_MODEL={FLAG_EMBEDDING_MODEL!r}; expected 'AOAI' or 'COHERE'"
        )

    except Exception as e:
        raise Exception(f"Error getting embeddings with endpoint={endpoint} with error={e}")
def get_embedding(text, embedding_model_endpoint=None, embedding_model_key=None, azure_credential=None):
endpoint = embedding_model_endpoint if embedding_model_endpoint else os.environ.get("EMBEDDING_MODEL_ENDPOINT")
key = embedding_model_key if embedding_model_key else os.environ.get("EMBEDDING_MODEL_KEY")
Expand Down
1 change: 1 addition & 0 deletions scripts/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Follow the instructions in this section to prepare your data locally. This is ea
- Install the necessary packages listed in requirements.txt, e.g. `pip install --user -r requirements.txt`

## Configure
- Create a .env file similar to the .env.sample file. Fill in the values for the environment variables.
- Create a config file like `config.json`. The format should be a list of JSON objects, with each object specifying a configuration of local data path and target search service and index.

```
Expand Down

0 comments on commit acf6d79

Please sign in to comment.