diff --git a/crates/goose-server/src/routes/providers_and_keys.json b/crates/goose-server/src/routes/providers_and_keys.json
index e56f9f8b4..41ebf9634 100644
--- a/crates/goose-server/src/routes/providers_and_keys.json
+++ b/crates/goose-server/src/routes/providers_and_keys.json
@@ -46,5 +46,11 @@
         "description": "Connect to Azure OpenAI Service",
         "models": ["gpt-4o", "gpt-4o-mini"],
         "required_keys": ["AZURE_OPENAI_API_KEY", "AZURE_OPENAI_ENDPOINT", "AZURE_OPENAI_DEPLOYMENT_NAME"]
+    },
+    "vertex_ai": {
+        "name": "Vertex AI",
+        "description": "Access a variety of AI models through Vertex AI",
+        "models": ["claude-3-5-sonnet-v2@20241022", "claude-3-5-sonnet@20240620"],
+        "required_keys": ["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"]
     }
 }
diff --git a/crates/goose/Cargo.toml b/crates/goose/Cargo.toml
index 3d0d4a7b2..294beb6c7 100644
--- a/crates/goose/Cargo.toml
+++ b/crates/goose/Cargo.toml
@@ -17,6 +17,7 @@ mcp-core = { path = "../mcp-core" }
 anyhow = "1.0"
 thiserror = "1.0"
 futures = "0.3"
+gcp-sdk-auth = "0.1.1"
 reqwest = { version = "0.12.9", features = [
     "rustls-tls",
     "json",
diff --git a/crates/goose/src/providers/factory.rs b/crates/goose/src/providers/factory.rs
index d17fb8893..efbba57b0 100644
--- a/crates/goose/src/providers/factory.rs
+++ b/crates/goose/src/providers/factory.rs
@@ -9,6 +9,7 @@ use super::{
     ollama::OllamaProvider,
     openai::OpenAiProvider,
     openrouter::OpenRouterProvider,
+    vertexai::VertexAIProvider,
 };
 use crate::model::ModelConfig;
 use anyhow::Result;
@@ -24,6 +25,7 @@ pub fn providers() -> Vec<ProviderMetadata> {
         OllamaProvider::metadata(),
         OpenAiProvider::metadata(),
         OpenRouterProvider::metadata(),
+        VertexAIProvider::metadata(),
     ]
 }
 
@@ -38,6 +40,7 @@ pub fn create(name: &str, model: ModelConfig) -> Result<Box<dyn Provider + Send
         "ollama" => Ok(Box::new(OllamaProvider::from_env(model)?)),
         "openrouter" => Ok(Box::new(OpenRouterProvider::from_env(model)?)),
         "google" => Ok(Box::new(GoogleProvider::from_env(model)?)),
+        "vertex_ai" => Ok(Box::new(VertexAIProvider::from_env(model)?)),
         _ => Err(anyhow::anyhow!("Unknown provider: {}", name)),
     }
 }
diff --git a/crates/goose/src/providers/formats/mod.rs b/crates/goose/src/providers/formats/mod.rs
index 780f38488..986e283ef 100644
--- a/crates/goose/src/providers/formats/mod.rs
+++ b/crates/goose/src/providers/formats/mod.rs
@@ -2,3 +2,4 @@ pub mod anthropic;
 pub mod bedrock;
 pub mod google;
 pub mod openai;
+pub mod vertexai;
diff --git a/crates/goose/src/providers/formats/vertexai.rs b/crates/goose/src/providers/formats/vertexai.rs
new file mode 100644
index 000000000..61506ae40
--- /dev/null
+++ b/crates/goose/src/providers/formats/vertexai.rs
@@ -0,0 +1,83 @@
+use crate::message::Message;
+use crate::model::ModelConfig;
+use crate::providers::base::Usage;
+use anyhow::Result;
+use mcp_core::tool::Tool;
+use serde_json::Value;
+
+use super::anthropic;
+
+/// Builds the request body for the model named in `model_config`.
+///
+/// # Errors
+///
+/// Fails for any model other than the supported Claude models, or when the
+/// Anthropic request cannot be built.
+pub fn create_request(
+    model_config: &ModelConfig,
+    system: &str,
+    messages: &[Message],
+    tools: &[Tool],
+) -> Result<Value> {
+    match model_config.model_name.as_str() {
+        "claude-3-5-sonnet-v2@20241022" | "claude-3-5-sonnet@20240620" => {
+            create_anthropic_request(model_config, system, messages, tools)
+        }
+        _ => Err(anyhow::anyhow!("Vertex AI only supports Anthropic models")),
+    }
+}
+
+/// Builds an Anthropic request and adapts it to the Vertex AI Claude API.
+///
+/// # Errors
+///
+/// Fails when the Anthropic request cannot be built or is not a JSON object.
+pub fn create_anthropic_request(
+    model_config: &ModelConfig,
+    system: &str,
+    messages: &[Message],
+    tools: &[Tool],
+) -> Result<Value> {
+    let mut request = anthropic::create_request(model_config, system, messages, tools)?;
+
+    // the Vertex AI for Claude API has small differences from the Anthropic API
+    // ref: https://docs.anthropic.com/en/api/claude-on-vertex-ai
+    let body = request
+        .as_object_mut()
+        .ok_or_else(|| anyhow::anyhow!("Anthropic request is not a JSON object"))?;
+    body.remove("model");
+    body.insert(
+        "anthropic_version".to_string(),
+        Value::String("vertex-2023-10-16".to_string()),
+    );
+
+    Ok(request)
+}
+
+/// Converts a Vertex AI response into a [`Message`] using the Anthropic format.
+pub fn response_to_message(response: Value) -> Result<Message> {
+    anthropic::response_to_message(response)
+}
+
+/// Extracts token usage from a Vertex AI response using the Anthropic format.
+pub fn get_usage(data: &Value) -> Result<Usage> {
+    anthropic::get_usage(data)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_create_request() {
+        let model_config = ModelConfig::new("claude-3-5-sonnet-v2@20241022".to_string());
+        let system = "You are a helpful assistant.";
+        let messages = vec![Message::user().with_text("Hello, how are you?")];
+        let tools = vec![];
+
+        let request = create_request(&model_config, system, &messages, &tools).unwrap();
+
+        assert!(request.get("anthropic_version").is_some());
+        assert!(request.get("model").is_none());
+    }
+}
diff --git a/crates/goose/src/providers/mod.rs b/crates/goose/src/providers/mod.rs
index 634224fd7..47083a9d0 100644
--- a/crates/goose/src/providers/mod.rs
+++ b/crates/goose/src/providers/mod.rs
@@ -13,5 +13,6 @@ pub mod ollama;
 pub mod openai;
 pub mod openrouter;
 pub mod utils;
+pub mod vertexai;
 
 pub use factory::{create, providers};
diff --git a/crates/goose/src/providers/vertexai.rs b/crates/goose/src/providers/vertexai.rs
new file mode 100644
index 000000000..0d9a35ed2
--- /dev/null
+++ b/crates/goose/src/providers/vertexai.rs
@@ -0,0 +1,196 @@
+use std::time::Duration;
+
+use anyhow::Result;
+use async_trait::async_trait;
+use gcp_sdk_auth::credentials::create_access_token_credential;
+use reqwest::Client;
+use serde_json::Value;
+
+use crate::message::Message;
+use crate::model::ModelConfig;
+use crate::providers::base::{ConfigKey, Provider, ProviderMetadata, ProviderUsage};
+use crate::providers::errors::ProviderError;
+use crate::providers::formats::vertexai::{create_request, get_usage, response_to_message};
+use crate::providers::utils::emit_debug_trace;
+use mcp_core::tool::Tool;
+
+pub const VERTEXAI_DEFAULT_MODEL: &str = "claude-3-5-sonnet-v2@20241022";
+pub const VERTEXAI_KNOWN_MODELS: &[&str] = &[
+    "claude-3-5-sonnet-v2@20241022",
+    "claude-3-5-sonnet@20240620",
+];
+pub const VERTEXAI_DOC_URL: &str = "https://cloud.google.com/vertex-ai";
+pub const VERTEXAI_DEFAULT_REGION: &str = "us-east5";
+
+/// Provider for Claude models served through Vertex AI.
+#[derive(Debug, serde::Serialize)]
+pub struct VertexAIProvider {
+    #[serde(skip)]
+    client: Client,
+    host: String,
+    project_id: String,
+    region: String,
+    model: ModelConfig,
+}
+
+impl VertexAIProvider {
+    /// Builds a provider from the global configuration.
+    ///
+    /// `VERTEXAI_PROJECT_ID` is required. `VERTEXAI_REGION` falls back to
+    /// [`VERTEXAI_DEFAULT_REGION`] and `VERTEXAI_API_HOST` to the regional
+    /// `aiplatform.googleapis.com` endpoint.
+    pub fn from_env(model: ModelConfig) -> Result<Self> {
+        let config = crate::config::Config::global();
+
+        let project_id = config.get("VERTEXAI_PROJECT_ID")?;
+        let region = config
+            .get("VERTEXAI_REGION")
+            .unwrap_or_else(|_| VERTEXAI_DEFAULT_REGION.to_string());
+        let host = config
+            .get("VERTEXAI_API_HOST")
+            .unwrap_or_else(|_| format!("https://{}-aiplatform.googleapis.com", region));
+
+        let client = Client::builder()
+            .timeout(Duration::from_secs(600))
+            .build()?;
+
+        Ok(VertexAIProvider {
+            client,
+            host,
+            project_id,
+            region,
+            model,
+        })
+    }
+
+    /// Posts `payload` to the model's `streamRawPredict` endpoint and returns the JSON reply.
+    async fn post(&self, payload: Value) -> Result<Value, ProviderError> {
+        let base_url = url::Url::parse(&self.host)
+            .map_err(|e| ProviderError::RequestFailed(format!("Invalid base URL: {e}")))?;
+        let path = format!(
+            "v1/projects/{}/locations/{}/publishers/{}/models/{}:streamRawPredict",
+            self.project_id,
+            self.region,
+            self.get_model_provider(),
+            self.model.model_name
+        );
+        let url = base_url.join(&path).map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to construct endpoint URL: {e}"))
+        })?;
+
+        let creds = create_access_token_credential().await.map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to create access token credential: {}", e))
+        })?;
+        let token = creds.get_token().await.map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to get access token: {}", e))
+        })?;
+
+        let response = self
+            .client
+            .post(url)
+            .json(&payload)
+            .header("Authorization", format!("Bearer {}", token.token))
+            .send()
+            .await
+            .map_err(|e| ProviderError::RequestFailed(format!("Request failed: {}", e)))?;
+
+        let status = response.status();
+        // Read the body as text first: error replies are not guaranteed to be JSON, and a
+        // parse failure must not hide the HTTP status.
+        let body = response.text().await.map_err(|e| {
+            ProviderError::RequestFailed(format!("Failed to read response: {}", e))
+        })?;
+
+        match status {
+            reqwest::StatusCode::OK => serde_json::from_str(&body).map_err(|e| {
+                ProviderError::RequestFailed(format!("Failed to parse response: {}", e))
+            }),
+            reqwest::StatusCode::UNAUTHORIZED | reqwest::StatusCode::FORBIDDEN => {
+                tracing::debug!(
+                    "Provider request failed with status: {}. Payload: {:?}",
+                    status,
+                    payload
+                );
+                Err(ProviderError::Authentication(format!(
+                    "Authentication failed: {}",
+                    body
+                )))
+            }
+            _ => {
+                tracing::debug!("Request failed with status {}: {}", status, body);
+                Err(ProviderError::RequestFailed(format!(
+                    "Request failed with status {}: {}",
+                    status, body
+                )))
+            }
+        }
+    }
+
+    /// Publisher segment of the endpoint path for the configured model.
+    fn get_model_provider(&self) -> &'static str {
+        // TODO: switch this by model_name
+        "anthropic"
+    }
+}
+
+impl Default for VertexAIProvider {
+    fn default() -> Self {
+        let model = ModelConfig::new(Self::metadata().default_model);
+        VertexAIProvider::from_env(model).expect("Failed to initialize VertexAI provider")
+    }
+}
+
+#[async_trait]
+impl Provider for VertexAIProvider {
+    fn metadata() -> ProviderMetadata
+    where
+        Self: Sized,
+    {
+        ProviderMetadata::new(
+            "vertex_ai",
+            "Vertex AI",
+            "Access a variety of AI models such as Claude through Vertex AI",
+            VERTEXAI_DEFAULT_MODEL,
+            VERTEXAI_KNOWN_MODELS
+                .iter()
+                .map(|&s| s.to_string())
+                .collect(),
+            VERTEXAI_DOC_URL,
+            vec![
+                ConfigKey::new("VERTEXAI_PROJECT_ID", true, false, None),
+                ConfigKey::new(
+                    "VERTEXAI_REGION",
+                    true,
+                    false,
+                    Some(VERTEXAI_DEFAULT_REGION),
+                ),
+            ],
+        )
+    }
+
+    #[tracing::instrument(
+        skip(self, system, messages, tools),
+        fields(model_config, input, output, input_tokens, output_tokens, total_tokens)
+    )]
+    async fn complete(
+        &self,
+        system: &str,
+        messages: &[Message],
+        tools: &[Tool],
+    ) -> Result<(Message, ProviderUsage), ProviderError> {
+        let request = create_request(&self.model, system, messages, tools)?;
+        let response = self.post(request.clone()).await?;
+        let usage = get_usage(&response)?;
+
+        emit_debug_trace(self, &request, &response, &usage);
+
+        let message = response_to_message(response)?;
+        let provider_usage = ProviderUsage::new(self.model.model_name.clone(), usage);
+
+        Ok((message, provider_usage))
+    }
+
+    fn get_model_config(&self) -> ModelConfig {
+        self.model.clone()
+    }
+}
diff --git a/crates/goose/tests/truncate_agent.rs b/crates/goose/tests/truncate_agent.rs
index 4c8be6e9d..46b3168de 100644
--- a/crates/goose/tests/truncate_agent.rs
+++ b/crates/goose/tests/truncate_agent.rs
@@ -6,6 +6,7 @@ use goose::agents::AgentFactory;
 use goose::message::Message;
 use goose::model::ModelConfig;
 use goose::providers::base::Provider;
+use goose::providers::vertexai::VertexAIProvider;
 use goose::providers::{anthropic::AnthropicProvider, databricks::DatabricksProvider};
 use goose::providers::{
     azure::AzureProvider, bedrock::BedrockProvider, ollama::OllamaProvider, openai::OpenAiProvider,
@@ -24,6 +25,7 @@ enum ProviderType {
     Groq,
     Ollama,
     OpenRouter,
+    VertexAI,
 }
 
 impl ProviderType {
@@ -42,6 +44,7 @@ impl ProviderType {
             ProviderType::Groq => &["GROQ_API_KEY"],
             ProviderType::Ollama => &[],
             ProviderType::OpenRouter => &["OPENROUTER_API_KEY"],
+            ProviderType::VertexAI => &["VERTEXAI_PROJECT_ID", "VERTEXAI_REGION"],
         }
     }
 
@@ -74,6 +77,7 @@ impl ProviderType {
             ProviderType::Groq => Box::new(GroqProvider::from_env(model_config)?),
             ProviderType::Ollama => Box::new(OllamaProvider::from_env(model_config)?),
             ProviderType::OpenRouter => Box::new(OpenRouterProvider::from_env(model_config)?),
+            ProviderType::VertexAI => Box::new(VertexAIProvider::from_env(model_config)?),
         })
     }
 }
@@ -290,4 +294,14 @@ mod tests {
         })
         .await
     }
+
+    #[tokio::test]
+    async fn test_truncate_agent_with_vertexai() -> Result<()> {
+        run_test_with_config(TestConfig {
+            provider_type: ProviderType::VertexAI,
+            model: "claude-3-5-sonnet-v2@20241022",
+            context_window: 200_000,
+        })
+        .await
+    }
 }
diff --git a/ui/desktop/src/components/settings/api_keys/utils.tsx b/ui/desktop/src/components/settings/api_keys/utils.tsx
index dc8805eb6..185cb37da 100644
--- a/ui/desktop/src/components/settings/api_keys/utils.tsx
+++ b/ui/desktop/src/components/settings/api_keys/utils.tsx
@@ -8,6 +8,8 @@ export function isSecretKey(keyName: string): boolean {
     'OLLAMA_HOST',
     'AZURE_OPENAI_ENDPOINT',
     'AZURE_OPENAI_DEPLOYMENT_NAME',
+    'VERTEXAI_PROJECT_ID',
+    'VERTEXAI_REGION',
   ];
   return !nonSecretKeys.includes(keyName);
 }
diff --git a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
index f8b342d35..ae60fb916 100644
--- a/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
+++ b/ui/desktop/src/components/settings/models/hardcoded_stuff.tsx
@@ -19,6 +19,8 @@ export const goose_models: Model[] = [
   { id: 17, name: 'qwen2.5', provider: 'Ollama' },
   { id: 18, name: 'anthropic/claude-3.5-sonnet', provider: 'OpenRouter' },
   { id: 19, name: 'gpt-4o', provider: 'Azure OpenAI' },
+  { id: 20, name: 'claude-3-5-sonnet-v2@20241022', provider: 'Vertex AI' },
+  { id: 21, name: 'claude-3-5-sonnet@20240620', provider: 'Vertex AI' },
 ];
 
 export const openai_models = ['gpt-4o-mini', 'gpt-4o', 'gpt-4-turbo', 'o1'];
@@ -47,6 +49,8 @@ export const openrouter_models = ['anthropic/claude-3.5-sonnet'];
 
 export const azure_openai_models = ['gpt-4o'];
 
+export const vertexai_models = ['claude-3-5-sonnet-v2@20241022', 'claude-3-5-sonnet@20240620'];
+
 export const default_models = {
   openai: 'gpt-4o',
   anthropic: 'claude-3-5-sonnet-latest',
@@ -56,6 +60,7 @@ export const default_models = {
   openrouter: 'anthropic/claude-3.5-sonnet',
   ollama: 'qwen2.5',
   azure_openai: 'gpt-4o',
+  vertex_ai: 'claude-3-5-sonnet-v2@20241022',
 };
 
 export function getDefaultModel(key: string): string | undefined {
@@ -73,6 +78,7 @@ export const required_keys = {
   Google: ['GOOGLE_API_KEY'],
   OpenRouter: ['OPENROUTER_API_KEY'],
   'Azure OpenAI': ['AZURE_OPENAI_API_KEY', 'AZURE_OPENAI_ENDPOINT', 'AZURE_OPENAI_DEPLOYMENT_NAME'],
+  'Vertex AI': ['VERTEXAI_PROJECT_ID', 'VERTEXAI_REGION'],
 };
 
 export const supported_providers = [
@@ -84,6 +90,7 @@ export const supported_providers = [
   'Ollama',
   'OpenRouter',
   'Azure OpenAI',
+  'Vertex AI',
 ];
 
 export const model_docs_link = [
@@ -97,6 +104,7 @@ export const model_docs_link = [
   },
   { name: 'OpenRouter', href: 'https://openrouter.ai/models' },
   { name: 'Ollama', href: 'https://ollama.com/library' },
+  { name: 'Vertex AI', href: 'https://cloud.google.com/vertex-ai' },
 ];
 
 export const provider_aliases = [
@@ -108,4 +116,5 @@ export const provider_aliases = [
   { provider: 'OpenRouter', alias: 'openrouter' },
   { provider: 'Google', alias: 'google' },
   { provider: 'Azure OpenAI', alias: 'azure_openai' },
+  { provider: 'Vertex AI', alias: 'vertex_ai' },
 ];