From 34a4bac93131a170b898f0b61d289ad58f574b54 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 15:04:26 -0800 Subject: [PATCH 01/14] feat: add prompts support to mcp-client, ahere to MCP spec for prompts - add new endpoints `list_prompts` and `get_prompt` in the MCP client - update prompt model in mcp-core to make `description` and `arguments` optional, following MCP spec --- crates/goose-mcp/src/developer/mod.rs | 14 ++------ crates/mcp-client/src/client.rs | 48 +++++++++++++++++++++++++-- crates/mcp-core/src/prompt.rs | 28 ++++++++++------ crates/mcp-server/src/router.rs | 27 ++++++++------- 4 files changed, 81 insertions(+), 36 deletions(-) diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index ee326ddd4..5cbdef982 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -70,9 +70,9 @@ pub fn load_prompt_files() -> HashMap { description: arg.description, required: arg.required, }) - .collect(); + .collect::>(); - let prompt = Prompt::new(&template.id, &template.template, arguments); + let prompt = Prompt::new(&template.id, Some(&template.template), Some(arguments)); if prompts.contains_key(&prompt.name) { eprintln!("Duplicate prompt name '{}' found. Skipping.", prompt.name); @@ -854,15 +854,7 @@ impl Router for DeveloperRouter { Some(Box::pin(async move { match prompts.get(&prompt_name) { - Some(prompt) => { - if prompt.description.trim().is_empty() { - Err(PromptError::InternalError(format!( - "Prompt '{prompt_name}' has an empty description" - ))) - } else { - Ok(prompt.description.clone()) - } - } + Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()), None => Err(PromptError::NotFound(format!( "Prompt '{prompt_name}' not found" ))), diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index 0a00e8c77..0d722e558 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,7 +1,7 @@ use mcp_core::protocol::{ - CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage, - JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult, - ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, + CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND, }; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -93,6 +93,10 @@ pub trait McpClientTrait: Send + Sync { async fn list_tools(&self, next_cursor: Option) -> Result; async fn call_tool(&self, name: &str, arguments: Value) -> Result; + + async fn list_prompts(&self, next_cursor: Option) -> Result; + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } /// The MCP client is the interface for MCP operations. @@ -346,4 +350,42 @@ where // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 self.send_request("tools/call", params).await } + + async fn list_prompts(&self, next_cursor: Option) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let payload = next_cursor + .map(|cursor| serde_json::json!({"cursor": cursor})) + .unwrap_or_else(|| serde_json::json!({})); + + self.send_request("prompts/list", payload).await + } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + if !self.completed_initialization() { + return Err(Error::NotInitialized); + } + + // If prompts is not supported, return an error + if self.server_capabilities.as_ref().unwrap().prompts.is_none() { + return Err(Error::RpcError { + code: METHOD_NOT_FOUND, + message: "Server does not support 'prompts' capability".to_string(), + }); + } + + let params = serde_json::json!({ "name": name, "arguments": arguments }); + + self.send_request("prompts/get", params).await + } } diff --git a/crates/mcp-core/src/prompt.rs b/crates/mcp-core/src/prompt.rs index 7b814fd44..4a0106e34 100644 --- a/crates/mcp-core/src/prompt.rs +++ b/crates/mcp-core/src/prompt.rs @@ -10,22 +10,28 @@ use serde::{Deserialize, Serialize}; pub struct Prompt { /// The name of the prompt pub name: String, - /// A description of what the prompt does - pub description: String, - /// The arguments that can be passed to customize the prompt - pub arguments: Vec, + /// Optional description of what the prompt does + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + /// Optional arguments that can be passed to customize the prompt + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, } impl Prompt { /// Create a new prompt with the given name, description and arguments - pub fn new(name: N, description: D, arguments: Vec) -> Self + pub fn new( + name: N, + description: Option, + arguments: Option>, + ) -> Self where N: Into, D: Into, { Prompt { name: name.into(), - description: description.into(), + description: description.map(Into::into), arguments, } } @@ -37,9 +43,11 @@ pub struct PromptArgument { /// The name of the argument pub name: String, /// A description of what the argument is used for - pub description: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, /// Whether this argument is required - pub required: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, } /// Represents the role of a message sender in a prompt conversation @@ -151,6 +159,6 @@ pub struct PromptTemplate { #[derive(Debug, Serialize, Deserialize)] pub struct PromptArgumentTemplate { pub name: String, - pub description: String, - pub required: bool, + pub description: Option, + pub required: Option, } diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index d2918311c..0060ffd92 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -305,18 +305,21 @@ pub trait Router: Send + Sync + 'static { }; // Validate required arguments - for arg in &prompt.arguments { - if arg.required - && (!arguments.contains_key(&arg.name) - || arguments - .get(&arg.name) - .and_then(Value::as_str) - .is_none_or(str::is_empty)) - { - return Err(RouterError::InvalidParams(format!( - "Missing required argument: '{}'", - arg.name - ))); + if let Some(args) = &prompt.arguments { + for arg in args { + if arg.required.is_some() + && arg.required.unwrap() + && (!arguments.contains_key(&arg.name) + || arguments + .get(&arg.name) + .and_then(Value::as_str) + .is_none_or(str::is_empty)) + { + return Err(RouterError::InvalidParams(format!( + "Missing required argument: '{}'", + arg.name + ))); + } } } From 5c4258b4e876a6a427e823a6401c899bc0285abe Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 20 Feb 2025 10:37:16 -0800 Subject: [PATCH 02/14] feat: handle JsonRpcMessage::Error messages to propagate to the user --- crates/mcp-client/src/transport/sse.rs | 20 +++++++++++++++----- crates/mcp-client/src/transport/stdio.rs | 14 +++++++++++--- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/crates/mcp-client/src/transport/sse.rs b/crates/mcp-client/src/transport/sse.rs index ed08e4800..90dc5f2f2 100644 --- a/crates/mcp-client/src/transport/sse.rs +++ b/crates/mcp-client/src/transport/sse.rs @@ -111,13 +111,23 @@ impl SseActor { // Attempt to parse the SSE data as a JsonRpcMessage match serde_json::from_str::(&e.data) { Ok(message) => { - // If it's a response, complete the pending request - if let JsonRpcMessage::Response(resp) = &message { - if let Some(id) = &resp.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests + .respond(&id.to_string(), Ok(message)) + .await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } - // If it's something else (notification, etc.), handle as needed } Err(err) => { warn!("Failed to parse SSE message: {err}"); diff --git a/crates/mcp-client/src/transport/stdio.rs b/crates/mcp-client/src/transport/stdio.rs index 59d900540..7980816bf 100644 --- a/crates/mcp-client/src/transport/stdio.rs +++ b/crates/mcp-client/src/transport/stdio.rs @@ -87,10 +87,18 @@ impl StdioActor { "Received incoming message" ); - if let JsonRpcMessage::Response(response) = &message { - if let Some(id) = &response.id { - pending_requests.respond(&id.to_string(), Ok(message)).await; + match &message { + JsonRpcMessage::Response(response) => { + if let Some(id) = &response.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } } + JsonRpcMessage::Error(error) => { + if let Some(id) = &error.id { + pending_requests.respond(&id.to_string(), Ok(message)).await; + } + } + _ => {} // TODO: Handle other variants (Request, etc.) } } line.clear(); From 2523e61ce6fafa74f35ad8d664e7b66fa6049b9c Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 20 Feb 2025 16:07:48 -0800 Subject: [PATCH 03/14] test: update MockClient in test with list_prompts and get_prompt --- crates/goose/src/agents/capabilities.rs | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 783f15def..fb487979a 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -556,7 +556,8 @@ mod tests { use mcp_client::client::Error; use mcp_client::client::McpClientTrait; use mcp_core::protocol::{ - CallToolResult, InitializeResult, ListResourcesResult, ListToolsResult, ReadResourceResult, + CallToolResult, GetPromptResult, InitializeResult, ListPromptsResult, ListResourcesResult, + ListToolsResult, ReadResourceResult, }; use serde_json::json; @@ -625,6 +626,20 @@ mod tests { _ => Err(Error::NotInitialized), } } + async fn list_prompts( + &self, + _next_cursor: Option, + ) -> Result { + Err(Error::NotInitialized) + } + + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + ) -> Result { + Err(Error::NotInitialized) + } } #[test] From 529b7fd50578ebc315f202f688930f7cb2e18a67 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 09:14:32 -0800 Subject: [PATCH 04/14] feat: remove concrete impl of get_prompt and list_prompts, and require implementing types to define them, similar to other methods --- crates/mcp-server/src/router.rs | 27 ++++++++++----------------- 1 file changed, 10 insertions(+), 17 deletions(-) diff --git a/crates/mcp-server/src/router.rs b/crates/mcp-server/src/router.rs index 0060ffd92..2c277d1c4 100644 --- a/crates/mcp-server/src/router.rs +++ b/crates/mcp-server/src/router.rs @@ -97,12 +97,8 @@ pub trait Router: Send + Sync + 'static { &self, uri: &str, ) -> Pin> + Send + 'static>>; - fn list_prompts(&self) -> Option> { - None - } - fn get_prompt(&self, _prompt_name: &str) -> Option { - None - } + fn list_prompts(&self) -> Vec; + fn get_prompt(&self, prompt_name: &str) -> PromptFuture; // Helper method to create base response fn create_response(&self, id: Option) -> JsonRpcResponse { @@ -257,7 +253,7 @@ pub trait Router: Send + Sync + 'static { req: JsonRpcRequest, ) -> impl Future> + Send { async move { - let prompts = self.list_prompts().unwrap_or_default(); + let prompts = self.list_prompts(); let result = ListPromptsResult { prompts }; @@ -294,15 +290,13 @@ pub trait Router: Send + Sync + 'static { .ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?; // Fetch the prompt definition first - let prompt = match self.list_prompts() { - Some(prompts) => prompts - .into_iter() - .find(|p| p.name == prompt_name) - .ok_or_else(|| { - RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) - })?, - None => return Err(RouterError::PromptNotFound("No prompts available".into())), - }; + let prompt = self + .list_prompts() + .into_iter() + .find(|p| p.name == prompt_name) + .ok_or_else(|| { + RouterError::PromptNotFound(format!("Prompt '{}' not found", prompt_name)) + })?; // Validate required arguments if let Some(args) = &prompt.arguments { @@ -326,7 +320,6 @@ pub trait Router: Send + Sync + 'static { // Now get the prompt content let description = self .get_prompt(prompt_name) - .ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))? .await .map_err(|e| RouterError::Internal(e.to_string()))?; From a0110ac819c275a5ec5f05e75fff7d39a3446a21 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 09:16:34 -0800 Subject: [PATCH 05/14] test: add impl of list/get prompt to main.rs and stdio_integration to test both new methods --- .../mcp-client/examples/stdio_integration.rs | 11 ++++++ crates/mcp-server/src/main.rs | 35 ++++++++++++++++++- 2 files changed, 45 insertions(+), 1 deletion(-) diff --git a/crates/mcp-client/examples/stdio_integration.rs b/crates/mcp-client/examples/stdio_integration.rs index 9acd2086d..ffdcc10c3 100644 --- a/crates/mcp-client/examples/stdio_integration.rs +++ b/crates/mcp-client/examples/stdio_integration.rs @@ -82,5 +82,16 @@ async fn main() -> Result<(), ClientError> { let resource = client.read_resource("memo://insights").await?; println!("Resource: {resource:?}\n"); + let prompts = client.list_prompts(None).await?; + println!("Prompts: {prompts:?}\n"); + + let prompt = client + .get_prompt( + "example_prompt", + serde_json::json!({"message": "hello there!"}), + ) + .await?; + println!("Prompt: {prompt:?}\n"); + Ok(()) } diff --git a/crates/mcp-server/src/main.rs b/crates/mcp-server/src/main.rs index eee250025..907cc1b1c 100644 --- a/crates/mcp-server/src/main.rs +++ b/crates/mcp-server/src/main.rs @@ -1,6 +1,7 @@ use anyhow::Result; use mcp_core::content::Content; -use mcp_core::handler::ResourceError; +use mcp_core::handler::{PromptError, ResourceError}; +use mcp_core::prompt::{Prompt, PromptArgument}; use mcp_core::{handler::ToolError, protocol::ServerCapabilities, resource::Resource, tool::Tool}; use mcp_server::router::{CapabilitiesBuilder, RouterService}; use mcp_server::{ByteTransport, Router, Server}; @@ -61,6 +62,7 @@ impl Router for CounterRouter { CapabilitiesBuilder::new() .with_tools(false) .with_resources(false, false) + .with_prompts(false) .build() } @@ -153,6 +155,37 @@ impl Router for CounterRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![Prompt::new( + "example_prompt", + Some("This is an example prompt that takes one required agrument, message"), + Some(vec![PromptArgument { + name: "message".to_string(), + description: Some("A message to put in the prompt".to_string()), + required: Some(true), + }]), + )] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + match prompt_name.as_str() { + "example_prompt" => { + let prompt = "This is an example prompt with your message here: '{message}'"; + Ok(prompt.to_string()) + } + _ => Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))), + } + }) + } } #[tokio::main] From 0f776c07de63a6ed9960b90a932a6d2e9d157777 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 10:59:19 -0800 Subject: [PATCH 06/14] refactor: implement list_prompts and get_prompts for mcp servers --- .../goose-mcp/src/computercontroller/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/developer/mod.rs | 18 +++++++---------- crates/goose-mcp/src/google_drive/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/jetbrains/mod.rs | 20 ++++++++++++++++++- crates/goose-mcp/src/memory/mod.rs | 19 +++++++++++++++++- crates/goose-mcp/src/tutorial/mod.rs | 20 ++++++++++++++++++- 6 files changed, 101 insertions(+), 16 deletions(-) diff --git a/crates/goose-mcp/src/computercontroller/mod.rs b/crates/goose-mcp/src/computercontroller/mod.rs index be74395b7..77d372db0 100644 --- a/crates/goose-mcp/src/computercontroller/mod.rs +++ b/crates/goose-mcp/src/computercontroller/mod.rs @@ -9,7 +9,8 @@ use std::{ use tokio::process::Command; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -807,4 +808,21 @@ impl Router for ComputerControllerRouter { } }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } diff --git a/crates/goose-mcp/src/developer/mod.rs b/crates/goose-mcp/src/developer/mod.rs index 5cbdef982..a0a58b2a6 100644 --- a/crates/goose-mcp/src/developer/mod.rs +++ b/crates/goose-mcp/src/developer/mod.rs @@ -827,39 +827,35 @@ impl Router for DeveloperRouter { Box::pin(async move { Ok("".to_string()) }) } - fn list_prompts(&self) -> Option> { - if self.prompts.is_empty() { - None - } else { - Some(self.prompts.values().cloned().collect()) - } + fn list_prompts(&self) -> Vec { + self.prompts.values().cloned().collect() } fn get_prompt( &self, prompt_name: &str, - ) -> Option> + Send + 'static>>> { + ) -> Pin> + Send + 'static>> { let prompt_name = prompt_name.trim().to_owned(); // Validate prompt name is not empty if prompt_name.is_empty() { - return Some(Box::pin(async move { + return Box::pin(async move { Err(PromptError::InvalidParameters( "Prompt name cannot be empty".to_string(), )) - })); + }); } let prompts = Arc::clone(&self.prompts); - Some(Box::pin(async move { + Box::pin(async move { match prompts.get(&prompt_name) { Some(prompt) => Ok(prompt.description.clone().unwrap_or_default()), None => Err(PromptError::NotFound(format!( "Prompt '{prompt_name}' not found" ))), } - })) + }) } } diff --git a/crates/goose-mcp/src/google_drive/mod.rs b/crates/goose-mcp/src/google_drive/mod.rs index 2ba36a574..2ed1e7f14 100644 --- a/crates/goose-mcp/src/google_drive/mod.rs +++ b/crates/goose-mcp/src/google_drive/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{env, fs, future::Future, io::Write, path::Path, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::Tool, @@ -618,6 +619,23 @@ impl Router for GoogleDriveRouter { let uri_clone = uri.to_string(); Box::pin(async move { this.read_google_resource(uri_clone).await }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for GoogleDriveRouter { diff --git a/crates/goose-mcp/src/jetbrains/mod.rs b/crates/goose-mcp/src/jetbrains/mod.rs index 319cdcd36..0cdf80189 100644 --- a/crates/goose-mcp/src/jetbrains/mod.rs +++ b/crates/goose-mcp/src/jetbrains/mod.rs @@ -3,7 +3,8 @@ mod proxy; use anyhow::Result; use mcp_core::{ content::Content, - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -176,6 +177,23 @@ impl Router for JetBrainsRouter { ) -> Pin> + Send + 'static>> { Box::pin(async { Err(ResourceError::NotFound("Resource not found".into())) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for JetBrainsRouter { diff --git a/crates/goose-mcp/src/memory/mod.rs b/crates/goose-mcp/src/memory/mod.rs index 4a7411a54..a9fd1fa39 100644 --- a/crates/goose-mcp/src/memory/mod.rs +++ b/crates/goose-mcp/src/memory/mod.rs @@ -12,7 +12,8 @@ use std::{ }; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, tool::{Tool, ToolCall}, @@ -493,6 +494,22 @@ impl Router for MemoryRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } #[derive(Debug)] diff --git a/crates/goose-mcp/src/tutorial/mod.rs b/crates/goose-mcp/src/tutorial/mod.rs index 9d6ba3d7c..2f32b03ac 100644 --- a/crates/goose-mcp/src/tutorial/mod.rs +++ b/crates/goose-mcp/src/tutorial/mod.rs @@ -5,7 +5,8 @@ use serde_json::{json, Value}; use std::{future::Future, pin::Pin}; use mcp_core::{ - handler::{ResourceError, ToolError}, + handler::{PromptError, ResourceError, ToolError}, + prompt::Prompt, protocol::ServerCapabilities, resource::Resource, role::Role, @@ -156,6 +157,23 @@ impl Router for TutorialRouter { ) -> Pin> + Send + 'static>> { Box::pin(async move { Ok("".to_string()) }) } + + fn list_prompts(&self) -> Vec { + vec![] + } + + fn get_prompt( + &self, + prompt_name: &str, + ) -> Pin> + Send + 'static>> { + let prompt_name = prompt_name.to_string(); + Box::pin(async move { + Err(PromptError::NotFound(format!( + "Prompt {} not found", + prompt_name + ))) + }) + } } impl Clone for TutorialRouter { From 7d0cb115b1dd365205cff45ec8023e6335adf42f Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 16:01:29 -0800 Subject: [PATCH 07/14] feat: add list prompts command - extend CLI input handling to support `/prompts` for listing available prompts - add `ListPrompts` variant in the input enum and update help documentation - implement prompt rendering in the session output module - update agent traits and capabilities to aggregate and list prompts from all extensions --- crates/goose-cli/src/session/input.rs | 4 ++ crates/goose-cli/src/session/mod.rs | 8 +++ crates/goose-cli/src/session/output.rs | 12 ++++ crates/goose/src/agents/agent.rs | 5 ++ crates/goose/src/agents/capabilities.rs | 81 ++++++++++++++++++++++++- crates/goose/src/agents/reference.rs | 18 ++++++ crates/goose/src/agents/truncate.rs | 18 ++++++ 7 files changed, 145 insertions(+), 1 deletion(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 7cfa94d35..0245a4563 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -9,6 +9,8 @@ pub enum InputResult { AddBuiltin(String), ToggleTheme, Retry, + ListPrompts, + //UsePrompt(String), } pub fn get_input( @@ -59,6 +61,7 @@ fn handle_slash_command(input: &str) -> Option { Some(InputResult::Retry) } "/t" => Some(InputResult::ToggleTheme), + "/prompts" => Some(InputResult::ListPrompts), s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, @@ -72,6 +75,7 @@ fn print_help() { /t - Toggle Light/Dark/Ansi theme /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) +/prompts - List all available prompts by name /? or /help - Display this help message Navigation: diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d359a297a..d6db761c7 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -14,6 +14,7 @@ use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; use rand::{distributions::Alphanumeric, Rng}; +use std::collections::HashMap; use std::path::PathBuf; use tokio; @@ -103,6 +104,10 @@ impl Session { Ok(()) } + pub async fn list_prompts(&mut self) -> HashMap> { + self.agent.list_extension_prompts().await + } + pub async fn start(&mut self) -> Result<()> { let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; @@ -165,6 +170,9 @@ impl Session { continue; } input::InputResult::Retry => continue, + input::InputResult::ListPrompts => { + output::render_prompts(&self.list_prompts().await) + } } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index f6ccbdfcb..47ad3248a 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -5,6 +5,7 @@ use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; +use std::collections::HashMap; use std::path::Path; // Re-export theme for use in main @@ -151,6 +152,17 @@ pub fn render_error(message: &str) { println!("\n {} {}\n", style("error:").red().bold(), message); } +pub fn render_prompts(prompts: &HashMap>) { + println!(); + for (extension, prompts) in prompts { + println!(" {}", style(extension).green()); + for prompt in prompts { + println!(" - {}", style(prompt).cyan()); + } + } + println!(); +} + pub fn render_extension_success(name: &str) { println!(); println!( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 4500f95d1..589c98f75 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use anyhow::Result; use async_trait::async_trait; use futures::stream::BoxStream; @@ -34,4 +36,7 @@ pub trait Agent: Send + Sync { /// Override the system prompt with custom text async fn override_system_prompt(&mut self, template: String); + + /// Lists all prompts from all extensions + async fn list_extension_prompts(&self) -> HashMap>; } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index fb487979a..caf3b24ed 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -1,3 +1,4 @@ +use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use mcp_client::McpService; @@ -13,7 +14,7 @@ use crate::prompt_template::{load_prompt, load_prompt_file}; use crate::providers::base::{Provider, ProviderUsage}; use mcp_client::client::{ClientCapabilities, ClientInfo, McpClient, McpClientTrait}; use mcp_client::transport::{SseTransport, StdioTransport, Transport}; -use mcp_core::{Content, Tool, ToolCall, ToolError, ToolResult}; +use mcp_core::{prompt::Prompt, Content, Tool, ToolCall, ToolError, ToolResult}; use serde_json::Value; // By default, we set it to Jan 1, 2020 if the resource does not have a timestamp @@ -544,6 +545,69 @@ impl Capabilities { result } + + pub async fn list_prompts_from_extension( + &self, + extension_name: &str, + ) -> Result, ToolError> { + let client = self.clients.get(extension_name).ok_or_else(|| { + ToolError::InvalidParameters(format!("Extension {} is not valid", extension_name)) + })?; + + let client_guard = client.lock().await; + client_guard + .list_prompts(None) + .await + .map_err(|e| { + ToolError::ExecutionError(format!( + "Unable to list prompts for {}, {:?}", + extension_name, e + )) + }) + .map(|lp| lp.prompts) + } + + pub async fn list_prompts(&self) -> Result>, ToolError> { + let mut futures = FuturesUnordered::new(); + + for extension_name in self.clients.keys() { + futures.push(async move { + ( + extension_name, + self.list_prompts_from_extension(extension_name).await, + ) + }); + } + + let mut all_prompts = HashMap::new(); + let mut errors = Vec::new(); + + // Process results as they complete + while let Some(result) = futures.next().await { + let (name, prompts) = result; + match prompts { + Ok(content) => { + all_prompts.insert(name.to_string(), content); + } + Err(tool_error) => { + errors.push(tool_error); + } + } + } + + // Log any errors that occurred + if !errors.is_empty() { + tracing::error!( + errors = ?errors + .into_iter() + .map(|e| format!("{:?}", e)) + .collect::>(), + "errors from listing prompts" + ); + } + + Ok(all_prompts) + } } #[cfg(test)] @@ -617,6 +681,21 @@ mod tests { Err(Error::NotInitialized) } + async fn list_prompts( + &self, + _next_cursor: Option, + ) -> Result { + Err(Error::NotInitialized) + } + + async fn get_prompt( + &self, + _name: &str, + _arguments: Value, + ) -> Result { + Err(Error::NotInitialized) + } + async fn call_tool(&self, name: &str, _arguments: Value) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 6c30435d9..7f9074a7d 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, instrument}; @@ -194,6 +195,23 @@ impl Agent for ReferenceAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .map(|prompts| { + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + }) + .expect("Failed to list prompts") + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index 685524d4b..c19bd53c1 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -2,6 +2,7 @@ /// It makes no attempt to handle context limits, and cannot read resources use async_trait::async_trait; use futures::stream::BoxStream; +use std::collections::HashMap; use tokio::sync::Mutex; use tracing::{debug, error, instrument, warn}; @@ -302,6 +303,23 @@ impl Agent for TruncateAgent { let mut capabilities = self.capabilities.lock().await; capabilities.set_system_prompt_override(template); } + + async fn list_extension_prompts(&self) -> HashMap> { + let capabilities = self.capabilities.lock().await; + capabilities + .list_prompts() + .await + .map(|prompts| { + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + }) + .expect("Failed to list prompts") + } } register_agent!("truncate", TruncateAgent); From ead9c1da43c04e431b4f36f897c0a1fc37954bc3 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 19 Feb 2025 17:09:03 -0800 Subject: [PATCH 08/14] feat: add /prompt $NAME --info and placeholder for exeuction --- crates/goose-cli/src/session/input.rs | 89 +++++++++++++++++++++++--- crates/goose-cli/src/session/mod.rs | 40 +++++++++++- crates/goose-cli/src/session/output.rs | 43 +++++++++++++ crates/goose/src/agents/agent.rs | 3 +- crates/goose/src/agents/reference.rs | 12 +--- crates/goose/src/agents/truncate.rs | 12 +--- 6 files changed, 167 insertions(+), 32 deletions(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 0245a4563..32545e30f 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -1,5 +1,6 @@ use anyhow::Result; use rustyline::Editor; +use std::collections::HashMap; #[derive(Debug)] pub enum InputResult { @@ -10,7 +11,14 @@ pub enum InputResult { ToggleTheme, Retry, ListPrompts, - //UsePrompt(String), + PromptCommand(PromptCommandOptions), +} + +#[derive(Debug)] +pub struct PromptCommandOptions { + pub name: String, + pub info: bool, + pub arguments: HashMap, } pub fn get_input( @@ -52,22 +60,55 @@ pub fn get_input( } fn handle_slash_command(input: &str) -> Option { - let input = input.trim(); - - match input { - "/exit" | "/quit" => Some(InputResult::Exit), - "/?" | "/help" => { + let parts: Vec<&str> = input.trim().split_whitespace().collect(); + match parts.get(0).map(|s| *s) { + Some("/exit") | Some("/quit") => Some(InputResult::Exit), + Some("/?") | Some("/help") => { print_help(); Some(InputResult::Retry) } - "/t" => Some(InputResult::ToggleTheme), - "/prompts" => Some(InputResult::ListPrompts), - s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), - s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), + Some("/t") => Some(InputResult::ToggleTheme), + Some("/prompts") => Some(InputResult::ListPrompts), + Some("/prompt") => parse_prompt_command(&parts[1..]), + Some(s) if s.starts_with("/extension ") => { + Some(InputResult::AddExtension(s[11..].to_string())) + } + Some(s) if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, } } +fn parse_prompt_command(args: &[&str]) -> Option { + if args.is_empty() { + return None; + } + + let mut options = PromptCommandOptions { + name: args[0].to_string(), + info: false, + arguments: HashMap::new(), + }; + + // Parse remaining arguments + let mut i = 1; + while i < args.len() { + match args[i] { + "--info" => { + options.info = true; + } + arg if arg.contains('=') => { + if let Some((key, value)) = arg.split_once('=') { + options.arguments.insert(key.to_string(), value.to_string()); + } + } + _ => return None, // Invalid format + } + i += 1; + } + + Some(InputResult::PromptCommand(options)) +} + fn print_help() { println!( "Available commands: @@ -76,6 +117,7 @@ fn print_help() { /extension - Add a stdio extension (format: ENV1=val1 command args...) /builtin - Add builtin extensions by name (comma-separated) /prompts - List all available prompts by name +/prompt [--info] [key=value...] - Get prompt info or execute a prompt /? or /help - Display this help message Navigation: @@ -135,6 +177,33 @@ mod tests { assert!(handle_slash_command("/unknown").is_none()); } + #[test] + fn test_prompt_command() { + // Test basic prompt info command + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt --info") + { + assert_eq!(opts.name, "test-prompt"); + assert!(opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2") + { + assert_eq!(opts.name, "test-prompt"); + assert!(!opts.info); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string())); + assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string())); + } else { + panic!("Expected PromptCommand"); + } + } + // Test whitespace handling #[test] fn test_whitespace_handling() { diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index d6db761c7..475e10c4d 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -13,6 +13,7 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; + use rand::{distributions::Alphanumeric, Rng}; use std::collections::HashMap; use std::path::PathBuf; @@ -105,7 +106,32 @@ impl Session { } pub async fn list_prompts(&mut self) -> HashMap> { - self.agent.list_extension_prompts().await + let prompts = self.agent.list_extension_prompts().await; + prompts + .into_iter() + .map(|(extension, prompt_list)| { + let names = prompt_list.into_iter().map(|p| p.name).collect(); + (extension, names) + }) + .collect() + } + + pub async fn get_prompt_info(&mut self, name: &str) -> Result> { + let prompts = self.agent.list_extension_prompts().await; + + // Find which extension has this prompt + for (extension, prompt_list) in prompts { + if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) { + return Ok(Some(output::PromptInfo { + name: prompt.name.clone(), + description: prompt.description.clone(), + arguments: prompt.arguments.clone(), + extension: Some(extension), + })); + } + } + + Ok(None) } pub async fn start(&mut self) -> Result<()> { @@ -173,6 +199,18 @@ impl Session { input::InputResult::ListPrompts => { output::render_prompts(&self.list_prompts().await) } + input::InputResult::PromptCommand(opts) => { + if opts.info { + match self.get_prompt_info(&opts.name).await? { + Some(info) => output::render_prompt_info(&info), + None => { + output::render_error(&format!("Prompt '{}' not found", opts.name)) + } + } + } else { + output::render_error("Prompt execution not yet implemented"); + } + } } } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 47ad3248a..93af711db 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -2,6 +2,7 @@ use bat::WrappingMode; use console::style; use goose::config::Config; use goose::message::{Message, MessageContent, ToolRequest, ToolResponse}; +use mcp_core::prompt::PromptArgument; use mcp_core::tool::ToolCall; use serde_json::Value; use std::cell::RefCell; @@ -74,6 +75,14 @@ impl ThinkingIndicator { } } +#[derive(Debug)] +pub struct PromptInfo { + pub name: String, + pub description: Option, + pub arguments: Option>, + pub extension: Option, +} + // Global thinking indicator thread_local! { static THINKING: RefCell = RefCell::new(ThinkingIndicator::default()); @@ -163,6 +172,40 @@ pub fn render_prompts(prompts: &HashMap>) { println!(); } +pub fn render_prompt_info(info: &PromptInfo) { + println!(); + + if let Some(ext) = &info.extension { + println!(" {}: {}", style("Extension").green(), ext); + } + + println!("Prompt: {}", style(&info.name).cyan().bold()); + + if let Some(desc) = &info.description { + println!("\n {}", desc); + } + + if let Some(args) = &info.arguments { + println!("\n Arguments:"); + for arg in args { + let required = arg.required.unwrap_or(false); + let req_str = if required { + style("(required)").red() + } else { + style("(optional)").dim() + }; + + println!( + " {} {} {}", + style(&arg.name).yellow(), + req_str, + arg.description.as_deref().unwrap_or("") + ); + } + } + println!(); +} + pub fn render_extension_success(name: &str) { println!(); println!( diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 589c98f75..1ef94d355 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -8,6 +8,7 @@ use serde_json::Value; use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::ProviderUsage; +use mcp_core::prompt::Prompt; /// Core trait defining the behavior of an Agent #[async_trait] @@ -38,5 +39,5 @@ pub trait Agent: Send + Sync { async fn override_system_prompt(&mut self, template: String); /// Lists all prompts from all extensions - async fn list_extension_prompts(&self) -> HashMap>; + async fn list_extension_prompts(&self) -> HashMap>; } diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index 7f9074a7d..c11bbff4a 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -15,6 +15,7 @@ use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; use indoc::indoc; +use mcp_core::prompt::Prompt; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -196,20 +197,11 @@ impl Agent for ReferenceAgent { capabilities.set_system_prompt_override(template); } - async fn list_extension_prompts(&self) -> HashMap> { + async fn list_extension_prompts(&self) -> HashMap> { let capabilities = self.capabilities.lock().await; capabilities .list_prompts() .await - .map(|prompts| { - prompts - .into_iter() - .map(|(extension, prompt_list)| { - let names = prompt_list.into_iter().map(|p| p.name).collect(); - (extension, names) - }) - .collect() - }) .expect("Failed to list prompts") } } diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index c19bd53c1..ca9644776 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -17,6 +17,7 @@ use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; use indoc::indoc; +use mcp_core::prompt::Prompt; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -304,20 +305,11 @@ impl Agent for TruncateAgent { capabilities.set_system_prompt_override(template); } - async fn list_extension_prompts(&self) -> HashMap> { + async fn list_extension_prompts(&self) -> HashMap> { let capabilities = self.capabilities.lock().await; capabilities .list_prompts() .await - .map(|prompts| { - prompts - .into_iter() - .map(|(extension, prompt_list)| { - let names = prompt_list.into_iter().map(|p| p.name).collect(); - (extension, names) - }) - .collect() - }) .expect("Failed to list prompts") } } From 8025a936187fa29965feca308dee32ea7fc9f9dc Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Fri, 21 Feb 2025 11:20:57 -0800 Subject: [PATCH 09/14] refactor: revert handle_slash_command, match existing patterns fix: cherry-pick conflicts resolved --- crates/goose-cli/src/session/input.rs | 33 +++++++++++++------------ crates/goose-cli/src/session/output.rs | 2 +- crates/goose/src/agents/capabilities.rs | 16 +----------- 3 files changed, 19 insertions(+), 32 deletions(-) diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index 32545e30f..b865f55a1 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -60,39 +60,40 @@ pub fn get_input( } fn handle_slash_command(input: &str) -> Option { - let parts: Vec<&str> = input.trim().split_whitespace().collect(); - match parts.get(0).map(|s| *s) { - Some("/exit") | Some("/quit") => Some(InputResult::Exit), - Some("/?") | Some("/help") => { + let input = input.trim(); + + match input { + "/exit" | "/quit" => Some(InputResult::Exit), + "/?" | "/help" => { print_help(); Some(InputResult::Retry) } - Some("/t") => Some(InputResult::ToggleTheme), - Some("/prompts") => Some(InputResult::ListPrompts), - Some("/prompt") => parse_prompt_command(&parts[1..]), - Some(s) if s.starts_with("/extension ") => { - Some(InputResult::AddExtension(s[11..].to_string())) - } - Some(s) if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), + "/t" => Some(InputResult::ToggleTheme), + "/prompts" => Some(InputResult::ListPrompts), + s if s.starts_with("/prompt ") => parse_prompt_command(&s[8..]), + s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), + s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, } } -fn parse_prompt_command(args: &[&str]) -> Option { - if args.is_empty() { +fn parse_prompt_command(args: &str) -> Option { + let parts: Vec<&str> = args.split_whitespace().collect(); + + if parts.is_empty() { return None; } let mut options = PromptCommandOptions { - name: args[0].to_string(), + name: parts[0].to_string(), info: false, arguments: HashMap::new(), }; // Parse remaining arguments let mut i = 1; - while i < args.len() { - match args[i] { + while i < parts.len() { + match parts[i] { "--info" => { options.info = true; } diff --git a/crates/goose-cli/src/session/output.rs b/crates/goose-cli/src/session/output.rs index 93af711db..525c48575 100644 --- a/crates/goose-cli/src/session/output.rs +++ b/crates/goose-cli/src/session/output.rs @@ -179,7 +179,7 @@ pub fn render_prompt_info(info: &PromptInfo) { println!(" {}: {}", style("Extension").green(), ext); } - println!("Prompt: {}", style(&info.name).cyan().bold()); + println!(" Prompt: {}", style(&info.name).cyan().bold()); if let Some(desc) = &info.description { println!("\n {}", desc); diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index caf3b24ed..fc4762ea4 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -681,21 +681,6 @@ mod tests { Err(Error::NotInitialized) } - async fn list_prompts( - &self, - _next_cursor: Option, - ) -> Result { - Err(Error::NotInitialized) - } - - async fn get_prompt( - &self, - _name: &str, - _arguments: Value, - ) -> Result { - Err(Error::NotInitialized) - } - async fn call_tool(&self, name: &str, _arguments: Value) -> Result { match name { "tool" | "test__tool" => Ok(CallToolResult { @@ -705,6 +690,7 @@ mod tests { _ => Err(Error::NotInitialized), } } + async fn list_prompts( &self, _next_cursor: Option, From 2407051ef25e1cadfa02f0bb429fe688f896bc0b Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 25 Feb 2025 13:57:34 -0800 Subject: [PATCH 10/14] feat: first pass at enabling /prompt support, just rendering the output from the mcp server --- Cargo.lock | 1 + crates/goose-cli/Cargo.toml | 1 + crates/goose-cli/src/session/input.rs | 124 ++++++++++++++++++++---- crates/goose-cli/src/session/mod.rs | 29 +++++- crates/goose/src/agents/agent.rs | 5 + crates/goose/src/agents/capabilities.rs | 19 ++++ crates/goose/src/agents/reference.rs | 25 +++++ crates/goose/src/agents/truncate.rs | 25 +++++ 8 files changed, 211 insertions(+), 18 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index f9675c094..6457a36d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2203,6 +2203,7 @@ dependencies = [ "serde", "serde_json", "serde_yaml", + "shlex", "temp-env", "tempfile", "test-case", diff --git a/crates/goose-cli/Cargo.toml b/crates/goose-cli/Cargo.toml index 41cb85a60..7addc964d 100644 --- a/crates/goose-cli/Cargo.toml +++ b/crates/goose-cli/Cargo.toml @@ -47,6 +47,7 @@ chrono = "0.4" tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] } tracing-appender = "0.2" once_cell = "1.20.2" +shlex = "1.3.0" [target.'cfg(target_os = "windows")'.dependencies] winapi = { version = "0.3", features = ["wincred"] } diff --git a/crates/goose-cli/src/session/input.rs b/crates/goose-cli/src/session/input.rs index b865f55a1..38fd2ea58 100644 --- a/crates/goose-cli/src/session/input.rs +++ b/crates/goose-cli/src/session/input.rs @@ -1,5 +1,6 @@ use anyhow::Result; use rustyline::Editor; +use shlex; use std::collections::HashMap; #[derive(Debug)] @@ -70,7 +71,22 @@ fn handle_slash_command(input: &str) -> Option { } "/t" => Some(InputResult::ToggleTheme), "/prompts" => Some(InputResult::ListPrompts), - s if s.starts_with("/prompt ") => parse_prompt_command(&s[8..]), + s if s.starts_with("/prompt") => { + if s == "/prompt" { + // No arguments case + Some(InputResult::PromptCommand(PromptCommandOptions { + name: String::new(), // Empty name will trigger the error message in the rendering + info: false, + arguments: HashMap::new(), + })) + } else if let Some(stripped) = s.strip_prefix("/prompt ") { + // Has arguments case + parse_prompt_command(stripped) + } else { + // Handle invalid cases like "/promptxyz" + None + } + } s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())), s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())), _ => None, @@ -78,32 +94,37 @@ fn handle_slash_command(input: &str) -> Option { } fn parse_prompt_command(args: &str) -> Option { - let parts: Vec<&str> = args.split_whitespace().collect(); - - if parts.is_empty() { - return None; - } + let parts: Vec = shlex::split(args).unwrap_or_default(); + // set name to empty and error out in the rendering let mut options = PromptCommandOptions { - name: parts[0].to_string(), + name: parts.first().cloned().unwrap_or_default(), info: false, arguments: HashMap::new(), }; + // handle info at any point in the command + if parts.iter().any(|part| part == "--info") { + options.info = true; + } + // Parse remaining arguments let mut i = 1; + while i < parts.len() { - match parts[i] { - "--info" => { - options.info = true; - } - arg if arg.contains('=') => { - if let Some((key, value)) = arg.split_once('=') { - options.arguments.insert(key.to_string(), value.to_string()); - } - } - _ => return None, // Invalid format + let part = &parts[i]; + + // Skip flag arguments + if part == "--info" { + i += 1; + continue; + } + + // Process key=value pairs - removed redundant contains check + if let Some((key, value)) = part.split_once('=') { + options.arguments.insert(key.to_string(), value.to_string()); } + i += 1; } @@ -223,4 +244,73 @@ mod tests { panic!("Expected AddBuiltin"); } } + + // Test prompt with no arguments + #[test] + fn test_prompt_no_args() { + // Test just "/prompt" with no arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") { + assert_eq!(opts.name, ""); + assert!(!opts.info); + assert!(opts.arguments.is_empty()); + } else { + panic!("Expected PromptCommand"); + } + + // Test invalid prompt command + assert!(handle_slash_command("/promptxyz").is_none()); + } + + // Test quoted arguments + #[test] + fn test_quoted_arguments() { + // Test prompt with quoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!( + opts.arguments.get("arg1"), + Some(&"value with spaces".to_string()) + ); + assert_eq!( + opts.arguments.get("arg2"), + Some(&"another value".to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + + // Test prompt with mixed quoted and unquoted arguments + if let Some(InputResult::PromptCommand(opts)) = handle_slash_command( + r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#, + ) { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 2); + assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string())); + assert_eq!( + opts.arguments.get("quoted"), + Some(&r#"value with "nested" quotes"#.to_string()) + ); + } else { + panic!("Expected PromptCommand"); + } + } + + // Test invalid arguments + #[test] + fn test_invalid_arguments() { + // Test prompt with invalid arguments + if let Some(InputResult::PromptCommand(opts)) = + handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#) + { + assert_eq!(opts.name, "test-prompt"); + assert_eq!(opts.arguments.len(), 1); + assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string())); + // Invalid arguments are ignored but logged + } else { + panic!("Expected PromptCommand"); + } + } } diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 475e10c4d..3e142b203 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -13,8 +13,10 @@ use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; +use mcp_core::prompt::PromptMessage; use rand::{distributions::Alphanumeric, Rng}; +use serde_json::Value; use std::collections::HashMap; use std::path::PathBuf; use tokio; @@ -134,6 +136,11 @@ impl Session { Ok(None) } + pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result> { + let result = self.agent.get_prompt(name, arguments).await?; + Ok(result.messages) + } + pub async fn start(&mut self) -> Result<()> { let mut editor = rustyline::Editor::<(), rustyline::history::DefaultHistory>::new()?; @@ -200,6 +207,12 @@ impl Session { output::render_prompts(&self.list_prompts().await) } input::InputResult::PromptCommand(opts) => { + // name is required + if opts.name.is_empty() { + output::render_error("Prompt name argument is required"); + continue; + } + if opts.info { match self.get_prompt_info(&opts.name).await? { Some(info) => output::render_prompt_info(&info), @@ -208,7 +221,21 @@ impl Session { } } } else { - output::render_error("Prompt execution not yet implemented"); + // Convert the arguments HashMap to a Value + let arguments = serde_json::to_value(opts.arguments) + .map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?; + + match self.get_prompt(&opts.name, arguments).await { + Ok(messages) => { + println!( + "{:?}", + serde_json::to_string(&messages) + .unwrap_or("failed to get prompt".to_string()) + ); + continue; + } + Err(e) => output::render_error(&e.to_string()), + } } } } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index 1ef94d355..9ee53522c 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -9,6 +9,7 @@ use super::extension::{ExtensionConfig, ExtensionResult}; use crate::message::Message; use crate::providers::base::ProviderUsage; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; /// Core trait defining the behavior of an Agent #[async_trait] @@ -40,4 +41,8 @@ pub trait Agent: Send + Sync { /// Lists all prompts from all extensions async fn list_extension_prompts(&self) -> HashMap>; + + /// Get a prompt result with the given name and arguments + /// Returns the prompt text that would be used as user input + async fn get_prompt(&self, name: &str, arguments: Value) -> Result; } diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index fc4762ea4..2e95ec706 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -2,6 +2,7 @@ use anyhow::Result; use chrono::{DateTime, TimeZone, Utc}; use futures::stream::{FuturesUnordered, StreamExt}; use mcp_client::McpService; +use mcp_core::protocol::GetPromptResult; use std::collections::{HashMap, HashSet}; use std::sync::Arc; use std::sync::LazyLock; @@ -608,6 +609,24 @@ impl Capabilities { Ok(all_prompts) } + + pub async fn get_prompt( + &self, + extension_name: &str, + name: &str, + arguments: Value, + ) -> Result { + let client = self + .clients + .get(extension_name) + .ok_or_else(|| anyhow::anyhow!("Extension {} not found", extension_name))?; + + let client_guard = client.lock().await; + client_guard + .get_prompt(name, arguments) + .await + .map_err(|e| anyhow::anyhow!("Failed to get prompt: {}", e)) + } } #[cfg(test)] diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index c11bbff4a..87118bef3 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -14,8 +14,10 @@ use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; use crate::register_agent; use crate::token_counter::TokenCounter; +use anyhow::{anyhow, Result}; use indoc::indoc; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -204,6 +206,29 @@ impl Agent for ReferenceAgent { .await .expect("Failed to list prompts") } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("reference", ReferenceAgent); diff --git a/crates/goose/src/agents/truncate.rs b/crates/goose/src/agents/truncate.rs index ca9644776..0d6e28fb5 100644 --- a/crates/goose/src/agents/truncate.rs +++ b/crates/goose/src/agents/truncate.rs @@ -16,8 +16,10 @@ use crate::providers::errors::ProviderError; use crate::register_agent; use crate::token_counter::TokenCounter; use crate::truncate::{truncate_messages, OldestFirstTruncation}; +use anyhow::{anyhow, Result}; use indoc::indoc; use mcp_core::prompt::Prompt; +use mcp_core::protocol::GetPromptResult; use mcp_core::tool::Tool; use serde_json::{json, Value}; @@ -312,6 +314,29 @@ impl Agent for TruncateAgent { .await .expect("Failed to list prompts") } + + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { + let capabilities = self.capabilities.lock().await; + + // First find which extension has this prompt + let prompts = capabilities + .list_prompts() + .await + .map_err(|e| anyhow!("Failed to list prompts: {}", e))?; + + if let Some(extension) = prompts + .iter() + .find(|(_, prompt_list)| prompt_list.iter().any(|p| p.name == name)) + .map(|(extension, _)| extension) + { + return capabilities + .get_prompt(extension, name, arguments) + .await + .map_err(|e| anyhow!("Failed to get prompt: {}", e)); + } + + Err(anyhow!("Prompt '{}' not found", name)) + } } register_agent!("truncate", TruncateAgent); From dd4cddbcb02d0e31633979c550ae9881e5171d28 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Tue, 25 Feb 2025 14:58:39 -0800 Subject: [PATCH 11/14] feat: convert prompt messages to agent messages, and handle prompt in agent loop Add functionality to transform PromptMessageContent to MessageContent with proper handling in the session module and add test coverage. Add the results of GetPrompt to the message conversation and run the agent loop with prompt response. --- crates/goose-cli/src/session/mod.rs | 29 ++++-- crates/goose/src/message.rs | 135 ++++++++++++++++++++++++++++ 2 files changed, 156 insertions(+), 8 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 3e142b203..2b203c6c6 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -11,9 +11,9 @@ use anyhow::Result; use etcetera::choose_app_strategy; use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; -use goose::message::{Message, MessageContent}; +use goose::message::{prompt_content_to_message_content, Message, MessageContent}; use mcp_core::handler::ToolError; -use mcp_core::prompt::PromptMessage; +use mcp_core::prompt::{PromptMessage, PromptMessageRole}; use rand::{distributions::Alphanumeric, Rng}; use serde_json::Value; @@ -227,12 +227,25 @@ impl Session { match self.get_prompt(&opts.name, arguments).await { Ok(messages) => { - println!( - "{:?}", - serde_json::to_string(&messages) - .unwrap_or("failed to get prompt".to_string()) - ); - continue; + // convert the PromptMessages to Messages + for message in messages { + let msg_content = + prompt_content_to_message_content(message.content); + match message.role { + PromptMessageRole::User => { + self.messages + .push(Message::user().with_content(msg_content)); + } + PromptMessageRole::Assistant => { + self.messages.push( + Message::assistant().with_content(msg_content), + ); + } + } + } + output::show_thinking(); + self.process_agent_response().await?; + output::hide_thinking(); } Err(e) => output::render_error(&e.to_string()), } diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 30de253ff..67877b9e0 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -10,9 +10,31 @@ use std::collections::HashSet; use chrono::Utc; use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; +use mcp_core::prompt::PromptMessageContent; +use mcp_core::resource::ResourceContents; use mcp_core::role::Role; use mcp_core::tool::ToolCall; +/// Convert PromptMessageContent to MessageContent +/// +/// This function allows converting from the prompt message content type +/// to the message content type used in the agent. +pub fn prompt_content_to_message_content(content: PromptMessageContent) -> MessageContent { + match content { + PromptMessageContent::Text { text } => MessageContent::text(text), + PromptMessageContent::Image { image } => MessageContent::image(image.data, image.mime_type), + PromptMessageContent::Resource { resource } => { + // For resources, convert to text content with the resource text + match resource.resource { + ResourceContents::TextResourceContents { text, .. } => MessageContent::text(text), + ResourceContents::BlobResourceContents { blob, .. } => { + MessageContent::text(format!("[Binary content: {}]", blob)) + } + } + } + } +} + #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub struct ToolRequest { pub id: String, @@ -248,3 +270,116 @@ impl Message { .all(|c| matches!(c, MessageContent::Text(_))) } } + +#[cfg(test)] +mod tests { + use super::*; + use mcp_core::content::EmbeddedResource; + use mcp_core::prompt::PromptMessageContent; + use mcp_core::resource::ResourceContents; + + #[test] + fn test_prompt_content_to_message_content_text() { + let prompt_content = PromptMessageContent::Text { + text: "Hello, world!".to_string(), + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "Hello, world!"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_prompt_content_to_message_content_image() { + let prompt_content = PromptMessageContent::Image { + image: ImageContent { + data: "base64data".to_string(), + mime_type: "image/jpeg".to_string(), + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Image(image_content) = message_content { + assert_eq!(image_content.data, "base64data"); + assert_eq!(image_content.mime_type, "image/jpeg"); + } else { + panic!("Expected MessageContent::Image"); + } + } + + #[test] + fn test_prompt_content_to_message_content_text_resource() { + let resource = ResourceContents::TextResourceContents { + uri: "file:///test.txt".to_string(), + mime_type: Some("text/plain".to_string()), + text: "Resource content".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "Resource content"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_prompt_content_to_message_content_blob_resource() { + let resource = ResourceContents::BlobResourceContents { + uri: "file:///test.bin".to_string(), + mime_type: Some("application/octet-stream".to_string()), + blob: "binary_data".to_string(), + }; + + let prompt_content = PromptMessageContent::Resource { + resource: EmbeddedResource { + resource, + annotations: None, + }, + }; + + let message_content = prompt_content_to_message_content(prompt_content); + + if let MessageContent::Text(text_content) = message_content { + assert_eq!(text_content.text, "[Binary content: binary_data]"); + } else { + panic!("Expected MessageContent::Text"); + } + } + + #[test] + fn test_message_with_text() { + let message = Message::user().with_text("Hello"); + assert_eq!(message.as_concat_text(), "Hello"); + } + + #[test] + fn test_message_with_tool_request() { + let tool_call = Ok(ToolCall { + name: "test_tool".to_string(), + arguments: serde_json::json!({}), + }); + + let message = Message::assistant().with_tool_request("req1", tool_call); + assert!(message.is_tool_call()); + assert!(!message.is_tool_response()); + + let ids = message.get_tool_ids(); + assert_eq!(ids.len(), 1); + assert!(ids.contains("req1")); + } +} From a1ef0b9e33d01cee3bca83e6d0c8240d3cb960c6 Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Wed, 26 Feb 2025 09:13:28 -0800 Subject: [PATCH 12/14] style: cargo fmt after merge fix: update process_agent_response call --- crates/goose-cli/src/session/mod.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 3acb43a7e..9c26c6290 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -107,7 +107,6 @@ impl Session { Ok(()) } - pub async fn list_prompts(&mut self) -> HashMap> { let prompts = self.agent.list_extension_prompts().await; prompts @@ -258,7 +257,7 @@ impl Session { } } output::show_thinking(); - self.process_agent_response().await?; + self.process_agent_response(true).await?; output::hide_thinking(); } Err(e) => output::render_error(&e.to_string()), From e2ef3d18d57504691e5495f9454b2d7bb461a99b Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 27 Feb 2025 11:24:04 -0800 Subject: [PATCH 13/14] refactor: cleanup PromptMessage to Message using impl From --- crates/goose-cli/src/session/mod.rs | 20 +---- crates/goose/src/message.rs | 116 +++++++++++++++++++++------- 2 files changed, 92 insertions(+), 44 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index f5a36a743..71782f8c3 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -12,9 +12,9 @@ use anyhow::Result; use etcetera::choose_app_strategy; use goose::agents::extension::{Envs, ExtensionConfig}; use goose::agents::Agent; -use goose::message::{prompt_content_to_message_content, Message, MessageContent}; +use goose::message::{Message, MessageContent}; use mcp_core::handler::ToolError; -use mcp_core::prompt::{PromptMessage, PromptMessageRole}; +use mcp_core::prompt::PromptMessage; use rand::{distributions::Alphanumeric, Rng}; use serde_json::Value; @@ -242,20 +242,8 @@ impl Session { match self.get_prompt(&opts.name, arguments).await { Ok(messages) => { // convert the PromptMessages to Messages - for message in messages { - let msg_content = - prompt_content_to_message_content(message.content); - match message.role { - PromptMessageRole::User => { - self.messages - .push(Message::user().with_content(msg_content)); - } - PromptMessageRole::Assistant => { - self.messages.push( - Message::assistant().with_content(msg_content), - ); - } - } + for prompt_message in messages { + self.messages.push(Message::from(prompt_message)); } output::show_thinking(); self.process_agent_response(true).await?; diff --git a/crates/goose/src/message.rs b/crates/goose/src/message.rs index 0ef6f504b..41193a53f 100644 --- a/crates/goose/src/message.rs +++ b/crates/goose/src/message.rs @@ -10,7 +10,7 @@ use std::collections::HashSet; use chrono::Utc; use mcp_core::content::{Content, ImageContent, TextContent}; use mcp_core::handler::ToolResult; -use mcp_core::prompt::PromptMessageContent; +use mcp_core::prompt::{PromptMessage, PromptMessageContent, PromptMessageRole}; use mcp_core::resource::ResourceContents; use mcp_core::role::Role; use mcp_core::tool::ToolCall; @@ -158,23 +158,34 @@ impl From for MessageContent { } } -/// Convert PromptMessageContent to MessageContent -/// -/// This function allows converting from the prompt message content type -/// to the message content type used in the agent. -pub fn prompt_content_to_message_content(content: PromptMessageContent) -> MessageContent { - match content { - PromptMessageContent::Text { text } => MessageContent::text(text), - PromptMessageContent::Image { image } => MessageContent::image(image.data, image.mime_type), - PromptMessageContent::Resource { resource } => { - // For resources, convert to text content with the resource text - match resource.resource { - ResourceContents::TextResourceContents { text, .. } => MessageContent::text(text), - ResourceContents::BlobResourceContents { blob, .. } => { - MessageContent::text(format!("[Binary content: {}]", blob)) +impl From for Message { + fn from(prompt_message: PromptMessage) -> Self { + // Create a new message with the appropriate role + let message = match prompt_message.role { + PromptMessageRole::User => Message::user(), + PromptMessageRole::Assistant => Message::assistant(), + }; + + // Convert and add the content + let content = match prompt_message.content { + PromptMessageContent::Text { text } => MessageContent::text(text), + PromptMessageContent::Image { image } => { + MessageContent::image(image.data, image.mime_type) + } + PromptMessageContent::Resource { resource } => { + // For resources, convert to text content with the resource text + match resource.resource { + ResourceContents::TextResourceContents { text, .. } => { + MessageContent::text(text) + } + ResourceContents::BlobResourceContents { blob, .. } => { + MessageContent::text(format!("[Binary content: {}]", blob)) + } } } - } + }; + + message.with_content(content) } } @@ -447,14 +458,19 @@ mod tests { } #[test] - fn test_prompt_content_to_message_content_text() { + fn test_from_prompt_message_text() { let prompt_content = PromptMessageContent::Text { text: "Hello, world!".to_string(), }; - let message_content = prompt_content_to_message_content(prompt_content); + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); - if let MessageContent::Text(text_content) = message_content { + if let MessageContent::Text(text_content) = &message.content[0] { assert_eq!(text_content.text, "Hello, world!"); } else { panic!("Expected MessageContent::Text"); @@ -462,7 +478,7 @@ mod tests { } #[test] - fn test_prompt_content_to_message_content_image() { + fn test_from_prompt_message_image() { let prompt_content = PromptMessageContent::Image { image: ImageContent { data: "base64data".to_string(), @@ -471,9 +487,14 @@ mod tests { }, }; - let message_content = prompt_content_to_message_content(prompt_content); + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); - if let MessageContent::Image(image_content) = message_content { + if let MessageContent::Image(image_content) = &message.content[0] { assert_eq!(image_content.data, "base64data"); assert_eq!(image_content.mime_type, "image/jpeg"); } else { @@ -482,7 +503,7 @@ mod tests { } #[test] - fn test_prompt_content_to_message_content_text_resource() { + fn test_from_prompt_message_text_resource() { let resource = ResourceContents::TextResourceContents { uri: "file:///test.txt".to_string(), mime_type: Some("text/plain".to_string()), @@ -496,9 +517,14 @@ mod tests { }, }; - let message_content = prompt_content_to_message_content(prompt_content); + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); - if let MessageContent::Text(text_content) = message_content { + if let MessageContent::Text(text_content) = &message.content[0] { assert_eq!(text_content.text, "Resource content"); } else { panic!("Expected MessageContent::Text"); @@ -506,7 +532,7 @@ mod tests { } #[test] - fn test_prompt_content_to_message_content_blob_resource() { + fn test_from_prompt_message_blob_resource() { let resource = ResourceContents::BlobResourceContents { uri: "file:///test.bin".to_string(), mime_type: Some("application/octet-stream".to_string()), @@ -520,15 +546,49 @@ mod tests { }, }; - let message_content = prompt_content_to_message_content(prompt_content); + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: prompt_content, + }; + + let message = Message::from(prompt_message); - if let MessageContent::Text(text_content) = message_content { + if let MessageContent::Text(text_content) = &message.content[0] { assert_eq!(text_content.text, "[Binary content: binary_data]"); } else { panic!("Expected MessageContent::Text"); } } + #[test] + fn test_from_prompt_message() { + // Test user message conversion + let prompt_message = PromptMessage { + role: PromptMessageRole::User, + content: PromptMessageContent::Text { + text: "Hello, world!".to_string(), + }, + }; + + let message = Message::from(prompt_message); + assert_eq!(message.role, Role::User); + assert_eq!(message.content.len(), 1); + assert_eq!(message.as_concat_text(), "Hello, world!"); + + // Test assistant message conversion + let prompt_message = PromptMessage { + role: PromptMessageRole::Assistant, + content: PromptMessageContent::Text { + text: "I can help with that.".to_string(), + }, + }; + + let message = Message::from(prompt_message); + assert_eq!(message.role, Role::Assistant); + assert_eq!(message.content.len(), 1); + assert_eq!(message.as_concat_text(), "I can help with that."); + } + #[test] fn test_message_with_text() { let message = Message::user().with_text("Hello"); From 1c2a15a4e9fd68042b29a2c40e9f51b2729c2b4e Mon Sep 17 00:00:00 2001 From: Kalvin Chau Date: Thu, 27 Feb 2025 14:17:04 -0800 Subject: [PATCH 14/14] feat: enforce PromptMessages that we start with User messages - render user messages to the user's session - enforce user-assistant-user pattern and cleanup with bad patterns --- crates/goose-cli/src/session/mod.rs | 38 ++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/crates/goose-cli/src/session/mod.rs b/crates/goose-cli/src/session/mod.rs index 71782f8c3..b3826d2be 100644 --- a/crates/goose-cli/src/session/mod.rs +++ b/crates/goose-cli/src/session/mod.rs @@ -241,13 +241,39 @@ impl Session { match self.get_prompt(&opts.name, arguments).await { Ok(messages) => { - // convert the PromptMessages to Messages - for prompt_message in messages { - self.messages.push(Message::from(prompt_message)); + let start_len = self.messages.len(); + let mut valid = true; + for (i, prompt_message) in messages.into_iter().enumerate() { + let msg = Message::from(prompt_message); + // ensure we get a User - Assistant - User type pattern + let expected_role = if i % 2 == 0 { + mcp_core::Role::User + } else { + mcp_core::Role::Assistant + }; + + if msg.role != expected_role { + output::render_error(&format!( + "Expected {:?} message at position {}, but found {:?}", + expected_role, i, msg.role + )); + valid = false; + // get rid of everything we added to messages + self.messages.truncate(start_len); + break; + } + + if msg.role == mcp_core::Role::User { + output::render_message(&msg); + } + self.messages.push(msg); + } + + if valid { + output::show_thinking(); + self.process_agent_response(true).await?; + output::hide_thinking(); } - output::show_thinking(); - self.process_agent_response(true).await?; - output::hide_thinking(); } Err(e) => output::render_error(&e.to_string()), }