Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support VertexAI for Claude #1138

Merged
merged 10 commits into from
Feb 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/goose-server/src/routes/providers_and_keys.json
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,11 @@
"description": "Connect to Azure OpenAI Service",
"models": ["gpt-4o", "gpt-4o-mini"],
"required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
},
"vertex_ai": {
"name": "Vertex AI",
"description": "Access variety of AI models through Vertex AI",
"models": ["claude-3-5-sonnet-v2@20241022", "claude-3-5-sonnet@20240620"],
"required_keys": ["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"]
}
}
1 change: 1 addition & 0 deletions crates/goose/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
anyhow = "1.0"
thiserror = "1.0"
futures = "0.3"
gcp-sdk-auth = "0.1.1"
reqwest = { version = "0.12.9", features = [
"rustls-tls",
"json",
Expand Down
3 changes: 3 additions & 0 deletions crates/goose/src/providers/factory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use super::{
ollama::OllamaProvider,
openai::OpenAiProvider,
openrouter::OpenRouterProvider,
vertexai::VertexAIProvider,
};
use crate::model::ModelConfig;
use anyhow::Result;
Expand All @@ -24,6 +25,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
OllamaProvider::metadata(),
OpenAiProvider::metadata(),
OpenRouterProvider::metadata(),
VertexAIProvider::metadata(),
]
}

Expand All @@ -38,6 +40,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
"ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
"openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
"google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
"vertex_ai" => Ok(Box::new(VertexAIProvider::from_env(model)?)),
_ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
}
}
1 change: 1 addition & 0 deletions crates/goose/src/providers/formats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ pub mod anthropic;
pub mod bedrock;
pub mod google;
pub mod openai;
pub mod vertexai;
67 changes: 67 additions & 0 deletions crates/goose/src/providers/formats/vertexai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::Usage;
use anyhow::Result;
use mcp_core::tool::Tool;
use serde_json::Value;

use super::anthropic;

/// Build a Vertex AI request body for the configured model.
///
/// Vertex AI currently only serves Anthropic (Claude) models through this
/// provider, so any non-Claude model name is rejected up front with a clear
/// error instead of failing later at the HTTP layer.
pub fn create_request(
    model_config: &ModelConfig,
    system: &str,
    messages: &[Message],
    tools: &[Tool],
) -> Result<Value> {
    // Route by publisher prefix rather than pinning exact model ids, so newer
    // Claude releases (any `claude-*@<version-date>` id) work without edits
    // here. Previously only two hard-coded ids were accepted, which
    // contradicted the error message below.
    if model_config.model_name.starts_with("claude-") {
        create_anthropic_request(model_config, system, messages, tools)
    } else {
        Err(anyhow::anyhow!("Vertex AI only supports Anthropic models"))
    }
}

pub fn create_anthropic_request(
model_config: &ModelConfig,
system: &str,
messages: &[Message],
tools: &[Tool],
) -> Result<Value> {
let mut request = anthropic::create_request(model_config, system, messages, tools)?;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's add one more check on whether it is an Anthropic model? By checking the model_config name, we can let the user know right now that only Anthropic models are supported via Vertex AI

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.


// the Vertex AI for Claude API has small differences from the Anthropic API
// ref: https://docs.anthropic.com/en/api/claude-on-vertex-ai
request.as_object_mut().unwrap().remove("model");
request.as_object_mut().unwrap().insert(
"anthropic_version".to_string(),
Value::String("vertex-2023-10-16".to_string()),
);

Ok(request)
}

/// Convert a Vertex AI response body into a `Message`.
///
/// The Vertex AI Claude response payload matches the Anthropic API shape,
/// so this delegates directly to the anthropic format module.
pub fn response_to_message(response: Value) -> Result<Message> {
    anthropic::response_to_message(response)
}

/// Extract token usage from a Vertex AI response body.
///
/// Usage fields follow the Anthropic API shape, so this delegates directly
/// to the anthropic format module.
pub fn get_usage(data: &Value) -> Result<Usage> {
    anthropic::get_usage(data)
}

#[cfg(test)]
mod tests {
    use super::*;

    /// The generated body must carry Vertex AI's `anthropic_version` field
    /// and must NOT carry the `model` field used by the native Anthropic API
    /// (the model is selected via the endpoint URL instead).
    #[test]
    fn test_create_request() {
        let config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
        let user_messages = vec![Message::user().with_text("Hello, how are you?")];
        let no_tools: Vec<Tool> = vec![];

        let request = create_request(
            &config,
            "You are a helpful assistant.",
            &user_messages,
            &no_tools,
        )
        .unwrap();

        assert!(request.get("anthropic_version").is_some());
        assert!(request.get("model").is_none());
    }
}
1 change: 1 addition & 0 deletions crates/goose/src/providers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod utils;
pub mod vertexai;

pub use factory::{create, providers};
189 changes: 189 additions & 0 deletions crates/goose/src/providers/vertexai.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use std::time::Duration;

use anyhow::Result;
use async_trait::async_trait;
use gcp_sdk_auth::credentials::create_access_token_credential;
use reqwest::Client;
use serde_json::Value;

use crate::message::Message;
use crate::model::ModelConfig;
use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
use crate::providers::errors::ProviderError;
use crate::providers::formats::vertexai::{create_request, get_usage, response_to_message};
use crate::providers::utils::emit_debug_trace;
use mcp_core::tool::Tool;

// Default and known model ids for Claude on Vertex AI.
// Ids use Vertex AI's `<model-name>@<version-date>` format.
pub const VERTEXAI_DEFAULT_MODEL: &str = "claude-3-5-sonnet-v2@20241022";
pub const VERTEXAI_KNOWN_MODELS: &[&str] = &[
    "claude-3-5-sonnet-v2@20241022",
    "claude-3-5-sonnet@20240620",
];
pub const VERTEXAI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
// Fallback region when VERTEXAI_REGION is not configured.
pub const VERTEXAI_DEFAULT_REGION: &str = "us-east5";

/// Provider for Anthropic Claude models served through Google Vertex AI.
///
/// Authentication uses Google Application Default Credentials (resolved per
/// request in `post`) rather than a stored API key, so only routing
/// information is kept here.
#[derive(Debug, serde::Serialize)]
pub struct VertexAIProvider {
    #[serde(skip)]
    client: Client,
    // Base API URL, e.g. "https://us-east5-aiplatform.googleapis.com".
    host: String,
    // GCP project hosting the Vertex AI endpoint (VERTEXAI_PROJECT_ID).
    project_id: String,
    // GCP region serving the model (VERTEXAI_REGION).
    region: String,
    model: ModelConfig,
}

impl VertexAIProvider {
    /// Build a provider from global configuration.
    ///
    /// Required: `VERTEXAI_PROJECT_ID`.
    /// Optional: `VERTEXAI_REGION` (defaults to `VERTEXAI_DEFAULT_REGION`) and
    /// `VERTEXAI_API_HOST` (defaults to the regional aiplatform endpoint).
    pub fn from_env(model: ModelConfig) -> Result<Self> {
        let config = crate::config::Config::global();

        let project_id = config.get("VERTEXAI_PROJECT_ID")?;
        let region = config
            .get("VERTEXAI_REGION")
            .unwrap_or_else(|_| VERTEXAI_DEFAULT_REGION.to_string());
        let host = config
            .get("VERTEXAI_API_HOST")
            .unwrap_or_else(|_| format!("https://{}-aiplatform.googleapis.com", region));

        // Long timeout: large-context completions can take several minutes.
        let client = Client::builder()
            .timeout(Duration::from_secs(600))
            .build()?;

        Ok(VertexAIProvider {
            client,
            host,
            project_id,
            region,
            model,
        })
    }

    /// POST `payload` to the model's `streamRawPredict` endpoint,
    /// authenticating with Google Application Default Credentials.
    ///
    /// Returns the parsed JSON body on HTTP 200; maps 401/403 to
    /// `ProviderError::Authentication` and everything else to
    /// `ProviderError::RequestFailed`.
    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
        let base_url = url::Url::parse(&self.host)
            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
        let path = format!(
            "v1/projects/{}/locations/{}/publishers/{}/models/{}:streamRawPredict",
            self.project_id,
            self.region,
            self.get_model_provider(),
            self.model.model_name
        );
        let url = base_url.join(&path).map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
        })?;

        // NOTE(review): credentials are resolved on every request; if request
        // volume makes this a bottleneck, consider caching the credential.
        let creds = create_access_token_credential().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to create access token credential: {}", e))
        })?;
        let token = creds.get_token().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to get access token: {}", e))
        })?;

        let response = self
            .client
            .post(url)
            .json(&payload)
            .header("Authorization", format!("Bearer {}", token.token))
            .send()
            .await
            .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;

        let status = response.status();
        let response_json = response.json::<Value>().await.map_err(|e| {
            ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
        })?;

        match status {
            reqwest::StatusCode::OK => Ok(response_json),
            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
                // Log the payload at debug level only, so it never leaks into
                // the user-facing error. Pass args to the macro directly
                // instead of the redundant `debug!("{}", format!(...))`.
                tracing::debug!(
                    "Provider request failed with status: {}. Payload: {:?}",
                    status,
                    payload
                );
                Err(ProviderError::Authentication(format!(
                    "Authentication failed: {:?}",
                    response_json
                )))
            }
            _ => {
                tracing::debug!("Request failed with status {}: {:?}", status, response_json);
                Err(ProviderError::RequestFailed(format!(
                    "Request failed with status {}: {:?}",
                    status, response_json
                )))
            }
        }
    }

    /// Publisher segment of the endpoint path. Only Anthropic models are
    /// currently routed through this provider.
    fn get_model_provider(&self) -> String {
        // TODO: switch this by model_name
        "anthropic".to_string()
    }
}

impl Default for VertexAIProvider {
    /// Construct the provider with its default model from global config.
    ///
    /// Panics if required configuration (e.g. `VERTEXAI_PROJECT_ID`) is
    /// missing or the HTTP client cannot be built; callers that need a
    /// fallible path should use `from_env` directly.
    fn default() -> Self {
        let model = ModelConfig::new(Self::metadata().default_model);
        VertexAIProvider::from_env(model).expect("Failed to initialize VertexAI provider")
    }
}

#[async_trait]
impl Provider for VertexAIProvider {
    /// Static registry entry: id "vertex_ai", display name, known models, and
    /// the config keys surfaced to the user (project id required; region
    /// required but pre-filled with the default).
    fn metadata() -> ProviderMetadata
    where
        Self: Sized,
    {
        ProviderMetadata::new(
            "vertex_ai",
            "Vertex AI",
            "Access variety of AI models such as Claude through Vertex AI",
            VERTEXAI_DEFAULT_MODEL,
            VERTEXAI_KNOWN_MODELS
                .iter()
                .map(|&s| s.to_string())
                .collect(),
            VERTEXAI_DOC_URL,
            vec![
                ConfigKey::new("VERTEXAI_PROJECT_ID", true, false, None),
                ConfigKey::new(
                    "VERTEXAI_REGION",
                    true,
                    false,
                    Some(VERTEXAI_DEFAULT_REGION),
                ),
            ],
        )
    }

    /// Send one completion request and return the assistant message plus the
    /// token usage reported by the API.
    #[tracing::instrument(
        skip(self, system, messages, tools),
        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
    )]
    async fn complete(
        &self,
        system: &str,
        messages: &[Message],
        tools: &[Tool],
    ) -> Result<(Message, ProviderUsage), ProviderError> {
        let request = create_request(&self.model, system, messages, tools)?;
        let response = self.post(request.clone()).await?;
        let usage = get_usage(&response)?;

        // Trace request/response before message conversion so a parse failure
        // still leaves a debug record.
        emit_debug_trace(self, &request, &response, &usage);

        let message = response_to_message(response.clone())?;
        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);

        Ok((message, provider_usage))
    }

    fn get_model_config(&self) -> ModelConfig {
        self.model.clone()
    }
}
14 changes: 14 additions & 0 deletions crates/goose/tests/truncate_agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use goose::agents::AgentFactory;
use goose::message::Message;
use goose::model::ModelConfig;
use goose::providers::base::Provider;
use goose::providers::vertexai::VertexAIProvider;
use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
use goose::providers::{
azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
Expand All @@ -24,6 +25,7 @@ enum ProviderType {
Groq,
Ollama,
OpenRouter,
VertexAI,
}

impl ProviderType {
Expand All @@ -42,6 +44,7 @@ impl ProviderType {
ProviderType::Groq => &["GROQ_API_KEY"],
ProviderType::Ollama => &[],
ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
ProviderType::VertexAI => &["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"],
}
}

Expand Down Expand Up @@ -74,6 +77,7 @@ impl ProviderType {
ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
ProviderType::VertexAI => Box::new(VertexAIProvider::from_env(model_config)?),
})
}
}
Expand Down Expand Up @@ -290,4 +294,14 @@ mod tests {
})
.await
}

    // Integration test against a live Vertex AI endpoint; requires
    // VERTEXAI_PROJECT_ID / VERTEXAI_REGION (see ProviderType::required_keys)
    // plus Google Application Default Credentials — presumably skipped by the
    // shared harness when those are absent; verify against run_test_with_config.
    #[tokio::test]
    async fn test_truncate_agent_with_vertexai() -> Result<()> {
        run_test_with_config(TestConfig {
            provider_type: ProviderType::VertexAI,
            model: "claude-3-5-sonnet-v2@20241022",
            context_window: 200_000,
        })
        .await
    }
}
2 changes: 2 additions & 0 deletions ui/desktop/src/components/settings/api_keys/utils.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ export function isSecretKey(keyName: string): boolean {
'OLLAMA_HOST',
'AZURE_OPENAI_ENDPOINT',
'AZURE_OPENAI_DEPLOYMENT_NAME',
'VERTEXAI_PROJECT_ID',
'VERTEXAI_REGION',
];
return !nonSecretKeys.includes(keyName);
}
Expand Down
Loading
Loading