Skip to content

Commit

Permalink
Add Llama 3.1 on Vertex AI to CrossProviderInferenceEngine
Browse files Browse the repository at this point in the history
Signed-off-by: Yifan Mai <yifan@cs.stanford.edu>
  • Loading branch information
yifanmai committed Jan 19, 2025
1 parent 4d50008 commit 660bb71
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/unitxt/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -2895,6 +2895,7 @@ def _infer(
"watsonx-sdk",
"rits",
"azure",
"vertex-ai",
]


Expand All @@ -2909,7 +2910,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
user requests.
Current _supported_apis = ["watsonx", "together-ai", "open-ai", "aws", "ollama",
"bam", "watsonx-sdk", "rits"]
"bam", "watsonx-sdk", "rits", "vertex-ai"]
Args:
provider (Optional):
Expand Down Expand Up @@ -3014,6 +3015,11 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
"gpt-3.5-turbo-16k-0613": "azure/gpt-3.5-turbo-16k-0613",
"gpt-4-vision": "azure/gpt-4-vision",
},
"vertex-ai": {
"llama-3-1-8b-instruct": "vertex_ai/meta/llama-3.1-8b-instruct-maas",
"llama-3-1-70b-instruct": "vertex_ai/meta/llama-3.1-70b-instruct-maas",
"llama-3-1-405b-instruct": "vertex_ai/meta/llama-3.1-405b-instruct-maas",
},
}

_provider_to_base_class = {
Expand All @@ -3026,6 +3032,7 @@ class CrossProviderInferenceEngine(InferenceEngine, StandardAPIParamsMixin):
"watsonx-sdk": WMLInferenceEngine,
"rits": RITSInferenceEngine,
"azure": LiteLLMInferenceEngine,
"vertex-ai": LiteLLMInferenceEngine,
}

_provider_param_renaming = {
Expand Down

0 comments on commit 660bb71

Please sign in to comment.