Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(cli): add mcp prompt support via slash commands #1323

Merged
merged 16 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
16 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/goose-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ chrono = "0.4"
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt", "json", "time"] }
tracing-appender = "0.2"
once_cell = "1.20.2"
shlex = "1.3.0"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "0.3", features = ["wincred"] }
Expand Down
164 changes: 164 additions & 0 deletions crates/goose-cli/src/session/input.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use anyhow::Result;
use rustyline::Editor;
use shlex;
use std::collections::HashMap;

#[derive(Debug)]
pub enum InputResult {
Expand All @@ -9,6 +11,15 @@ pub enum InputResult {
AddBuiltin(String),
ToggleTheme,
Retry,
ListPrompts,
PromptCommand(PromptCommandOptions),
}

#[derive(Debug)]
pub struct PromptCommandOptions {
pub name: String,
pub info: bool,
pub arguments: HashMap<String, String>,
}

pub fn get_input(
Expand Down Expand Up @@ -59,19 +70,76 @@ fn handle_slash_command(input: &str) -> Option<InputResult> {
Some(InputResult::Retry)
}
"/t" => Some(InputResult::ToggleTheme),
"/prompts" => Some(InputResult::ListPrompts),
s if s.starts_with("/prompt") => {
if s == "/prompt" {
// No arguments case
Some(InputResult::PromptCommand(PromptCommandOptions {
name: String::new(), // Empty name will trigger the error message in the rendering
info: false,
arguments: HashMap::new(),
}))
} else if let Some(stripped) = s.strip_prefix("/prompt ") {
// Has arguments case
parse_prompt_command(stripped)
} else {
// Handle invalid cases like "/promptxyz"
None
}
}
s if s.starts_with("/extension ") => Some(InputResult::AddExtension(s[11..].to_string())),
s if s.starts_with("/builtin ") => Some(InputResult::AddBuiltin(s[9..].to_string())),
_ => None,
}
}

fn parse_prompt_command(args: &str) -> Option<InputResult> {
let parts: Vec<String> = shlex::split(args).unwrap_or_default();

// set name to empty and error out in the rendering
let mut options = PromptCommandOptions {
name: parts.first().cloned().unwrap_or_default(),
info: false,
arguments: HashMap::new(),
};

// handle info at any point in the command
if parts.iter().any(|part| part == "--info") {
options.info = true;
}

// Parse remaining arguments
let mut i = 1;

while i < parts.len() {
let part = &parts[i];

// Skip flag arguments
if part == "--info" {
i += 1;
continue;
}

// Process key=value pairs - removed redundant contains check
if let Some((key, value)) = part.split_once('=') {
options.arguments.insert(key.to_string(), value.to_string());
}

i += 1;
}

Some(InputResult::PromptCommand(options))
}

fn print_help() {
println!(
"Available commands:
/exit or /quit - Exit the session
/t - Toggle Light/Dark/Ansi theme
/extension <command> - Add a stdio extension (format: ENV1=val1 command args...)
/builtin <names> - Add builtin extensions by name (comma-separated)
/prompts - List all available prompts by name
/prompt <name> [--info] [key=value...] - Get prompt info or execute a prompt
/? or /help - Display this help message

Navigation:
Expand Down Expand Up @@ -131,6 +199,33 @@ mod tests {
assert!(handle_slash_command("/unknown").is_none());
}

#[test]
fn test_prompt_command() {
// Test basic prompt info command
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command("/prompt test-prompt --info")
{
assert_eq!(opts.name, "test-prompt");
assert!(opts.info);
assert!(opts.arguments.is_empty());
} else {
panic!("Expected PromptCommand");
}

// Test prompt with arguments
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command("/prompt test-prompt arg1=val1 arg2=val2")
{
assert_eq!(opts.name, "test-prompt");
assert!(!opts.info);
assert_eq!(opts.arguments.len(), 2);
assert_eq!(opts.arguments.get("arg1"), Some(&"val1".to_string()));
assert_eq!(opts.arguments.get("arg2"), Some(&"val2".to_string()));
} else {
panic!("Expected PromptCommand");
}
}

// Test whitespace handling
#[test]
fn test_whitespace_handling() {
Expand All @@ -149,4 +244,73 @@ mod tests {
panic!("Expected AddBuiltin");
}
}

// Test prompt with no arguments
#[test]
fn test_prompt_no_args() {
// Test just "/prompt" with no arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command("/prompt") {
assert_eq!(opts.name, "");
assert!(!opts.info);
assert!(opts.arguments.is_empty());
} else {
panic!("Expected PromptCommand");
}

// Test invalid prompt command
assert!(handle_slash_command("/promptxyz").is_none());
}

// Test quoted arguments
#[test]
fn test_quoted_arguments() {
// Test prompt with quoted arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
r#"/prompt test-prompt arg1="value with spaces" arg2="another value""#,
) {
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 2);
assert_eq!(
opts.arguments.get("arg1"),
Some(&"value with spaces".to_string())
);
assert_eq!(
opts.arguments.get("arg2"),
Some(&"another value".to_string())
);
} else {
panic!("Expected PromptCommand");
}

// Test prompt with mixed quoted and unquoted arguments
if let Some(InputResult::PromptCommand(opts)) = handle_slash_command(
r#"/prompt test-prompt simple=value quoted="value with \"nested\" quotes""#,
) {
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 2);
assert_eq!(opts.arguments.get("simple"), Some(&"value".to_string()));
assert_eq!(
opts.arguments.get("quoted"),
Some(&r#"value with "nested" quotes"#.to_string())
);
} else {
panic!("Expected PromptCommand");
}
}

// Test invalid arguments
#[test]
fn test_invalid_arguments() {
// Test prompt with invalid arguments
if let Some(InputResult::PromptCommand(opts)) =
handle_slash_command(r#"/prompt test-prompt valid=value invalid_arg another_invalid"#)
{
assert_eq!(opts.name, "test-prompt");
assert_eq!(opts.arguments.len(), 1);
assert_eq!(opts.arguments.get("valid"), Some(&"value".to_string()));
// Invalid arguments are ignored but logged
} else {
panic!("Expected PromptCommand");
}
}
}
100 changes: 100 additions & 0 deletions crates/goose-cli/src/session/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ use goose::agents::extension::{Envs, ExtensionConfig};
use goose::agents::Agent;
use goose::message::{Message, MessageContent};
use mcp_core::handler::ToolError;
use mcp_core::prompt::PromptMessage;

use rand::{distributions::Alphanumeric, Rng};
use serde_json::Value;
use std::collections::HashMap;
use std::path::PathBuf;
use tokio;

Expand Down Expand Up @@ -104,6 +108,40 @@ impl Session {
Ok(())
}

pub async fn list_prompts(&mut self) -> HashMap<String, Vec<String>> {
let prompts = self.agent.list_extension_prompts().await;
prompts
.into_iter()
.map(|(extension, prompt_list)| {
let names = prompt_list.into_iter().map(|p| p.name).collect();
(extension, names)
})
.collect()
}

pub async fn get_prompt_info(&mut self, name: &str) -> Result<Option<output::PromptInfo>> {
let prompts = self.agent.list_extension_prompts().await;

// Find which extension has this prompt
for (extension, prompt_list) in prompts {
if let Some(prompt) = prompt_list.iter().find(|p| p.name == name) {
return Ok(Some(output::PromptInfo {
name: prompt.name.clone(),
description: prompt.description.clone(),
arguments: prompt.arguments.clone(),
extension: Some(extension),
}));
}
}

Ok(None)
}

pub async fn get_prompt(&mut self, name: &str, arguments: Value) -> Result<Vec<PromptMessage>> {
Copy link
Collaborator

@salman1993 salman1993 Feb 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we invoked a PromptMessage that starts with assistant role - i think that would error because the CLI assumes its user msg and we also need to alternate between user & assistant. can you check this? if so, we might wanna enforce some checks.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added enforcement for user first, and that the messages alternate user & assistant

let result = self.agent.get_prompt(name, arguments).await?;
Ok(result.messages)
}

/// Process a single message and get the response
async fn process_message(&mut self, message: String) -> Result<()> {
self.messages.push(Message::user().with_text(&message));
Expand Down Expand Up @@ -179,6 +217,68 @@ impl Session {
continue;
}
input::InputResult::Retry => continue,
input::InputResult::ListPrompts => {
output::render_prompts(&self.list_prompts().await)
}
input::InputResult::PromptCommand(opts) => {
// name is required
if opts.name.is_empty() {
output::render_error("Prompt name argument is required");
continue;
}

if opts.info {
match self.get_prompt_info(&opts.name).await? {
Some(info) => output::render_prompt_info(&info),
None => {
output::render_error(&format!("Prompt '{}' not found", opts.name))
}
}
} else {
// Convert the arguments HashMap to a Value
let arguments = serde_json::to_value(opts.arguments)
.map_err(|e| anyhow::anyhow!("Failed to serialize arguments: {}", e))?;

match self.get_prompt(&opts.name, arguments).await {
Ok(messages) => {
let start_len = self.messages.len();
let mut valid = true;
for (i, prompt_message) in messages.into_iter().enumerate() {
let msg = Message::from(prompt_message);
// ensure we get a User - Assistant - User type pattern
let expected_role = if i % 2 == 0 {
mcp_core::Role::User
} else {
mcp_core::Role::Assistant
};

if msg.role != expected_role {
output::render_error(&format!(
"Expected {:?} message at position {}, but found {:?}",
expected_role, i, msg.role
));
valid = false;
// get rid of everything we added to messages
self.messages.truncate(start_len);
break;
}

if msg.role == mcp_core::Role::User {
output::render_message(&msg);
}
self.messages.push(msg);
}

if valid {
output::show_thinking();
self.process_agent_response(true).await?;
output::hide_thinking();
}
}
Err(e) => output::render_error(&e.to_string()),
}
}
}
}
}

Expand Down
Loading
Loading