From 660bb718a5257fbc36a67d8c4746682f1c2bd979 Mon Sep 17 00:00:00 2001 From: Yifan Mai Date: Sun, 19 Jan 2025 15:52:06 -0800 Subject: [PATCH] Add Llama 3.1 on Vertex AI to CrossProviderInferenceEngine Signed-off-by: Yifan Mai --- src/unitxt/inference.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/unitxt/inference.py b/src/unitxt/inference.py index 1d4666fe9b..61c1adf8e5 100644 --- a/src/unitxt/inference.py +++ b/src/unitxt/inference.py @@ -2895,6 +2895,7 @@ def _infer( "watsonx-sdk", "rits", "azure", + "vertex-ai", ] @@ -2909,7 +2910,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): user requests. Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama", - "bam", "watsonx-sdk", "rits"] + "bam", "watsonx-sdk", "rits", "vertex-ai"] Args: provider (Optional): @@ -3014,6 +3015,11 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "gpt-3.5-turbo-16k-0613": "azure/gpt-3.5-turbo-16k-0613", "gpt-4-vision": "azure/gpt-4-vision", }, + "vertex-ai": { + "llama-3-1-8b-instruct": "vertex_ai/meta/llama-3.1-8b-instruct-maas", + "llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas", + "llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas", + }, } _provider_to_base_class = { @@ -3026,6 +3032,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin): "watsonx-sdk": WMLInferenceEngine, "rits": RITSInferenceEngine, "azure": LiteLLMInferenceEngine, + "vertex-ai": LiteLLMInferenceEngine, } _provider_param_renaming = {