diff --git a/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs new file mode 100644 index 000000000000..f598ebbf7c46 --- /dev/null +++ b/dotnet/sample/AutoGen.BasicSamples/Example17_ReActAgent.cs @@ -0,0 +1,187 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Example17_ReActAgent.cs + +using AutoGen.Core; +using AutoGen.OpenAI; +using AutoGen.OpenAI.Extension; +using Azure.AI.OpenAI; + +namespace AutoGen.BasicSample; + +public class OpenAIReActAgent : IAgent +{ + private readonly OpenAIClient _client; + private readonly string modelName = "gpt-3.5-turbo"; + private readonly FunctionContract[] tools; + private readonly Dictionary>> toolExecutors = new(); + private readonly IAgent reasoner; + private readonly IAgent actor; + private readonly IAgent helper; + private readonly int maxSteps = 10; + + private const string ReActPrompt = @"Answer the following questions as best you can. +You can invoke the following tools: +{tools} + +Use the following format: + +Question: the input question you must answer +Thought: you should always think about what to do +Tool: the tool to invoke +Tool Input: the input to the tool +Observation: the invoke result of the tool +... (this process can repeat multiple times) + +Once you have the final answer, provide the final answer in the following format: +Thought: I now know the final answer +Final Answer: the final answer to the original input question + +Begin! +Question: {input}"; + + public OpenAIReActAgent(OpenAIClient client, string modelName, string name, FunctionContract[] tools, Dictionary>> toolExecutors) + { + _client = client; + this.Name = name; + this.modelName = modelName; + this.tools = tools; + this.toolExecutors = toolExecutors; + this.reasoner = CreateReasoner(); + this.actor = CreateActor(); + this.helper = new OpenAIChatAgent(client, "helper", modelName) + .RegisterMessageConnector(); + } + + public string Name { get; } + + public async Task GenerateReplyAsync(IEnumerable messages, GenerateReplyOptions? options = null, CancellationToken cancellationToken = default) + { + // step 1: extract the input question + var userQuestion = await helper.SendAsync("Extract the question from chat history", chatHistory: messages); + if (userQuestion.GetContent() is not string question) + { + return new TextMessage(Role.Assistant, "I couldn't find a question in the chat history. Please ask a question.", from: Name); + } + var reactPrompt = CreateReActPrompt(question); + var promptMessage = new TextMessage(Role.User, reactPrompt); + var chatHistory = new List() { promptMessage }; + + // step 2: ReAct + for (int i = 0; i != this.maxSteps; i++) + { + // reasoning + var reasoning = await reasoner.SendAsync(chatHistory: chatHistory); + if (reasoning.GetContent() is not string reasoningContent) + { + return new TextMessage(Role.Assistant, "I couldn't find a reasoning in the chat history. Please provide a reasoning.", from: Name); + } + if (reasoningContent.Contains("I now know the final answer")) + { + return new TextMessage(Role.Assistant, reasoningContent, from: Name); + } + + chatHistory.Add(reasoning); + + // action + var action = await actor.SendAsync(reasoning); + chatHistory.Add(action); + } + + // fail to find the final answer + // return the summary of the chat history + var summary = await helper.SendAsync("Summarize the chat history and find out what's missing", chatHistory: chatHistory); + summary.From = Name; + + return summary; + } + + private string CreateReActPrompt(string input) + { + var toolPrompt = tools.Select(t => $"{t.Name}: {t.Description}").Aggregate((a, b) => $"{a}\n{b}"); + var prompt = ReActPrompt.Replace("{tools}", toolPrompt); + prompt = prompt.Replace("{input}", input); + return prompt; + } + + private IAgent CreateReasoner() + { + return new OpenAIChatAgent( + openAIClient: _client, + modelName: modelName, + name: "reasoner") + .RegisterMessageConnector() + .RegisterPrintMessage(); + } + + private IAgent CreateActor() + { + var functionCallMiddleware = new FunctionCallMiddleware(tools, toolExecutors); + return new OpenAIChatAgent( + openAIClient: _client, + modelName: modelName, + name: "actor") + .RegisterMessageConnector() + .RegisterMiddleware(functionCallMiddleware) + .RegisterPrintMessage(); + } +} + +public partial class Tools +{ + /// + /// Get weather report for a specific place on a specific date + /// + /// city + /// date as DD/MM/YYYY + [Function] + public async Task WeatherReport(string city, string date) + { + return $"Weather report for {city} on {date} is sunny"; + } + + /// + /// Get current localization + /// + [Function] + public async Task GetLocalization(string dummy) + { + return $"Paris"; + } + + /// + /// Get current date as DD/MM/YYYY + /// + [Function] + public async Task GetDateToday(string dummy) + { + return $"27/05/2024"; + } +} + +public class Example17_ReActAgent +{ + public static async Task RunAsync() + { + var openAIKey = Environment.GetEnvironmentVariable("OPENAI_API_KEY") ?? throw new Exception("Please set OPENAI_API_KEY environment variable."); + var modelName = "gpt-4-turbo"; + var tools = new Tools(); + var openAIClient = new OpenAIClient(openAIKey); + var reactAgent = new OpenAIReActAgent( + client: openAIClient, + modelName: modelName, + name: "react-agent", + tools: [tools.GetLocalizationFunctionContract, tools.GetDateTodayFunctionContract, tools.WeatherReportFunctionContract], + toolExecutors: new Dictionary>> + { + { tools.GetLocalizationFunctionContract.Name, tools.GetLocalizationWrapper }, + { tools.GetDateTodayFunctionContract.Name, tools.GetDateTodayWrapper }, + { tools.WeatherReportFunctionContract.Name, tools.WeatherReportWrapper }, + } + ) + .RegisterPrintMessage(); + + var message = new TextMessage(Role.User, "What is the weather here", from: "user"); + + var response = await reactAgent.SendAsync(message); + } +} diff --git a/dotnet/sample/AutoGen.BasicSamples/Program.cs b/dotnet/sample/AutoGen.BasicSamples/Program.cs index 11b5127ade0d..b48e2be4aa16 100644 --- a/dotnet/sample/AutoGen.BasicSamples/Program.cs +++ b/dotnet/sample/AutoGen.BasicSamples/Program.cs @@ -3,4 +3,4 @@ using AutoGen.BasicSample; Console.ReadLine(); -await Example16_OpenAIChatAgent_ConnectToThirdPartyBackend.RunAsync(); +await Example17_ReActAgent.RunAsync(); diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs index e56db112eb70..40adbdcde47c 100644 --- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs +++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.cs @@ -121,7 +121,8 @@ public virtual string TransformText() this.Write("\",\r\n"); } if (functionContract.Parameters != null) { - this.Write(" Parameters = new []\r\n {\r\n"); + this.Write(" Parameters = new global::AutoGen.Core.FunctionParameterContract[]" + + "\r\n {\r\n"); foreach (var parameter in functionContract.Parameters) { this.Write(" new FunctionParameterContract\r\n {\r\n"); if (parameter.Name != null) { diff --git a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt index 526dfe400cea..0d1b221c35c8 100644 --- a/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt +++ b/dotnet/src/AutoGen.SourceGenerator/Template/FunctionCallTemplate.tt @@ -72,7 +72,7 @@ namespace <#=NameSpace#> ReturnDescription = @"<#=functionContract.ReturnDescription#>", <#}#> <#if (functionContract.Parameters != null) {#> - Parameters = new [] + Parameters = new global::AutoGen.Core.FunctionParameterContract[] { <#foreach (var parameter in functionContract.Parameters) {#> new FunctionParameterContract @@ -110,6 +110,6 @@ namespace <#=NameSpace#> <#+ public string NameSpace {get; set;} public string ClassName {get; set;} -public IEnumerable FunctionContracts {get; set;} +public IEnumerable FunctionContracts {get; set;} public bool IsStatic {get; set;} = false; #> \ No newline at end of file diff --git a/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt b/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt index feab4ebd6078..0439febc52c7 100644 --- a/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt +++ b/dotnet/test/AutoGen.SourceGenerator.Tests/ApprovalTests/FunctionCallTemplateTests.TestFunctionCallTemplate.approved.txt @@ -42,7 +42,7 @@ namespace AutoGen.SourceGenerator.Tests Name = @"AddAsync", Description = @"Add two numbers.", ReturnType = typeof(System.Threading.Tasks.Task`1[System.String]), - Parameters = new [] + Parameters = new global::AutoGen.Core.FunctionParameterContract[] { new FunctionParameterContract {