Add Cohere Rerank 3 Support #466

Draft · wants to merge 8 commits into base: main
5 changes: 5 additions & 0 deletions bin/config.ts
@@ -84,6 +84,11 @@ export function getConfig(): SystemConfig {
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
},
{
provider: "cohere",
name: "rerank-english-v3.0",
default: false,
},
],
},
};
4 changes: 4 additions & 0 deletions cli/magic-config.ts
@@ -818,6 +818,10 @@ async function processCreateOptions(options: any): Promise<void> {
name: "cross-encoder/ms-marco-MiniLM-L-12-v2",
default: true,
};
config.rag.crossEncoderModels[1] = {
provider: "cohere",
name: "rerank-english-v3.0",
};
config.rag.embeddingsModels = embeddingModels;
config.rag.embeddingsModels.forEach((m: any) => {
if (m.name === models.defaultEmbedding) {
16 changes: 16 additions & 0 deletions docs/documentation/inference-script.md
@@ -30,3 +30,19 @@ The API is JSON body based:
"passages": ["I love Paris", "I love London"]
}
```

## Cohere Rerank 3

To use the Cohere Rerank 3 model, obtain an API key from Cohere, configure it as the `COHERE_API_KEY` external API key, and include the following in the JSON request body:

```json
{
"type": "cross-encoder",
"model": "rerank-english-v3.0",
"input": "What is the capital of the United States?",
"passages": [
"Carson City is the capital city of the American state of Nevada.",
"Washington, D.C. is the capital of the United States.",
...
]
}
```
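
For reference, the request above maps onto a direct Cohere SDK call roughly like the sketch below. This is not part of the deployed inference script; it is a quick way to verify that your API key and the `rerank-english-v3.0` model name work, assuming the `cohere` Python package (v5) and a `COHERE_API_KEY` environment variable:

```python
import os

import cohere

# Assumes COHERE_API_KEY is exported in your environment; adjust to your setup.
co = cohere.Client(os.environ["COHERE_API_KEY"])

response = co.rerank(
    model="rerank-english-v3.0",
    query="What is the capital of the United States?",
    documents=[
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
)

# Each result carries the original document index and a relevance score.
for result in response.results:
    print(result.index, result.relevance_score)
```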
2 changes: 1 addition & 1 deletion lib/model-interfaces/idefics/index.ts
@@ -129,7 +129,7 @@ export class IdeficsInterface extends Construct {
new iam.PolicyStatement({
actions: ["kms:Decrypt", "kms:ReEncryptFrom"],
effect: iam.Effect.ALLOW,
resources: ["arn:aws:kms:*"],
resources: ["*"],
})
);

37 changes: 21 additions & 16 deletions lib/rag-engines/sagemaker-rag-models/model/inference.py
@@ -27,7 +27,7 @@
"intfloat/multilingual-e5-large",
"sentence-transformers/all-MiniLM-L6-v2",
]
cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2"]
cross_encoder_models = ["cross-encoder/ms-marco-MiniLM-L-12-v2", "rerank-english-v3.0"]


def process_model_list(model_list):
@@ -130,21 +130,26 @@ def predict_fn(input_object, config):
    passages = input_object["passages"]
    data = [[current_input, passage] for passage in passages]

    if current_model_id == "rerank-english-v3.0":
        # Use the Cohere Rerank 3 API instead of a local cross-encoder model.
        co = cohere.Client(os.environ["COHERE_API_KEY"])
        response = co.rerank(
            query=current_input,
            documents=passages,
            top_n=len(passages),
            model="rerank-english-v3.0",
        )

        # Cohere returns results sorted by relevance with the original document
        # index attached; map the scores back to the input passage order.
        ret_value = [0.0] * len(passages)
        for result in response.results:
            ret_value[result.index] = result.relevance_score
    else:
        with torch.inference_mode():
            features = current_tokenizer(
                data, padding=True, truncation=True, return_tensors="pt"
            )

            features = features.to(device)

            scores = current_model(**features).logits.cpu().numpy()
            ret_value = list(
                map(
                    lambda val: val[-1] if isinstance(val, list) else val,
                    scores.tolist(),
                )
            )

    return ret_value

return []
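
As a quick way to exercise the new branch end to end, the existing SageMaker RAG models endpoint can be invoked with the Cohere model name. This is a sketch: the endpoint name is a placeholder, and it assumes the endpoint's container has `COHERE_API_KEY` set, as the code above expects:

```python
import json

import boto3

client = boto3.client("sagemaker-runtime")

payload = {
    "type": "cross-encoder",
    "model": "rerank-english-v3.0",
    "input": "What is the capital of the United States?",
    "passages": [
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
}

# Placeholder endpoint name; use the RAG models endpoint deployed by the stack.
response = client.invoke_endpoint(
    EndpointName="<sagemaker-rag-models-endpoint>",
    ContentType="application/json",
    Body=json.dumps(payload),
)

# One relevance score per passage, in the original passage order.
print(json.loads(response["Body"].read().decode()))
```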
1 change: 1 addition & 0 deletions lib/shared/layers/common/requirements.txt
@@ -11,6 +11,7 @@ pgvector==0.2.2
pydantic==2.3.0
urllib3<2
openai==0.28.1
cohere==5.3.0
beautifulsoup4==4.12.2
requests==2.31.0
attrs==23.1.0
10 changes: 10 additions & 0 deletions lib/shared/layers/python-sdk/python/genai_core/clients.py
@@ -1,4 +1,5 @@
import boto3
import cohere
import openai
import genai_core.types
import genai_core.parameters
@@ -52,3 +53,12 @@ def get_bedrock_client(service_name="bedrock-runtime"):
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)

def get_cohere_client():
api_key = genai_core.parameters.get_external_api_key("COHERE_API_KEY")
if not api_key:
return None

cohere_client = cohere.Client(api_key)

return cohere_client
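
A small usage sketch for the helper above, assuming the `COHERE_API_KEY` external API key has been configured. Note the contract: the helper returns `None` rather than raising when the key is missing, so callers are expected to check for that, as `_rank_passages_cohere` below does:

```python
import genai_core.clients

co = genai_core.clients.get_cohere_client()
if co is None:
    # Mirrors the check in cross_encoder.py: missing key means no client.
    raise RuntimeError("Cohere API key not set")

response = co.rerank(
    model="rerank-english-v3.0",
    query="What is the capital of the United States?",
    documents=["Washington, D.C. is the capital of the United States."],
)
print(response.results[0].relevance_score)
```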
27 changes: 27 additions & 0 deletions lib/shared/layers/python-sdk/python/genai_core/cross_encoder.py
@@ -18,6 +18,8 @@ def rank_passages(

if model.provider == "sagemaker":
return _rank_passages_sagemaker(model, input, passages)
elif model.provider == "cohere":
return _rank_passages_cohere(model, input, passages)

raise genai_core.types.CommonError(f"Unknown provider")

@@ -29,6 +31,10 @@ def get_cross_encoder_models():
if not SAGEMAKER_RAG_MODELS_ENDPOINT:
models = list(filter(lambda x: x["provider"] != "sagemaker", models))

for model in models:
if 'default' not in model:
model['default'] = False

return models


@@ -66,3 +72,24 @@ def _rank_passages_sagemaker(
ret_value = json.loads(response["Body"].read().decode())

return ret_value

def _rank_passages_cohere(
    model: genai_core.types.CrossEncoderModel, input: str, passages: List[str]
):
    cohere_client = genai_core.clients.get_cohere_client()
    if not cohere_client:
        raise genai_core.types.CommonError("Cohere API key not set")

    response = cohere_client.rerank(
        query=input,
        documents=passages,
        model=model.name,
    )

    # Cohere returns results sorted by relevance with the original document
    # index attached; pair each score with the passage it belongs to.
    return [
        genai_core.types.RankedPassage(
            passage=passages[result.index],
            score=result.relevance_score,
        )
        for result in response.results
    ]
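
A minimal sketch of how the new provider branch gets exercised. The `CrossEncoderModel` construction is an assumption about the type's fields (`provider`, `name`), inferred from how `rank_passages` reads them above; adjust to the actual type definition:

```python
import genai_core.cross_encoder
import genai_core.types

# Hypothetical construction: assumes CrossEncoderModel accepts provider/name,
# mirroring how rank_passages inspects model.provider and model.name.
model = genai_core.types.CrossEncoderModel(
    provider="cohere",
    name="rerank-english-v3.0",
)

ranked = genai_core.cross_encoder.rank_passages(
    model,
    "What is the capital of the United States?",
    [
        "Carson City is the capital city of the American state of Nevada.",
        "Washington, D.C. is the capital of the United States.",
    ],
)
for item in ranked:
    print(item.passage, item.score)
```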
2 changes: 1 addition & 1 deletion lib/shared/types.ts
@@ -1,6 +1,6 @@
import * as sagemaker from "aws-cdk-lib/aws-sagemaker";

export type ModelProvider = "sagemaker" | "bedrock" | "openai";
export type ModelProvider = "sagemaker" | "bedrock" | "openai" | "cohere";

export enum SupportedSageMakerModels {
FalconLite = "FalconLite [ml.g5.12xlarge]",