-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathagent.ts
138 lines (116 loc) · 4.56 KB
/
agent.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import { Annotation, StateGraph } from "@langchain/langgraph";
import { AIMessage, BaseMessage, HumanMessage } from "@langchain/core/messages";
import "dotenv/config";
import { MongoClient } from "mongodb";
import { tool } from "@langchain/core/tools";
import { ToolNode } from "@langchain/langgraph/prebuilt";
import { MongoDBAtlasVectorSearch } from "@langchain/mongodb";
import { ChatMistralAI, MistralAIEmbeddings } from "@langchain/mistralai";
import {
ChatPromptTemplate,
MessagesPlaceholder,
} from "@langchain/core/prompts";
import { z } from "zod";
import { MongoDBSaver } from "@langchain/langgraph-checkpoint-mongodb";
/**
 * Runs one turn of the HR chatbot agent and returns the model's final reply.
 *
 * Builds a two-node LangGraph workflow ("agent" <-> "tools") around a Mistral
 * chat model that may call a single `employee_lookup` vector-search tool, and
 * compiles it with a MongoDB checkpointer so the graph state is saved under
 * the given `thread_id`.
 *
 * @param client - MongoDB client used both for the `hr_database.employees`
 *   collection and for checkpoint storage.
 * @param query - The user's message for this turn.
 * @param thread_id - Conversation identifier handed to the checkpointer via
 *   `configurable.thread_id`.
 * @returns The `content` of the last message in the final graph state.
 */
export async function callAgent(
  client: MongoClient,
  query: string,
  thread_id: string
) {
  const dbName = "hr_database";
  const collection = client.db(dbName).collection("employees");

  // Graph state: a single message list; whatever a node returns under
  // `messages` is appended to the existing list.
  const GraphState = Annotation.Root({
    messages: Annotation<BaseMessage[]>({
      reducer: (x, y) => x.concat(y),
    }),
  });

  // Built once per call instead of once per tool invocation: nothing in this
  // configuration depends on the tool's arguments.
  const vectorStore = new MongoDBAtlasVectorSearch(new MistralAIEmbeddings(), {
    collection,
    indexName: "vector_index",
    textKey: "embedding_text",
    embeddingKey: "embedding",
  });

  // Tool exposed to the model: similarity search over the employees
  // collection, returned as a JSON string of scored results.
  const employeeLookupTool = tool(
    // `query` is renamed locally so it does not shadow callAgent's `query`.
    async ({ query: searchQuery, n = 10 }) => {
      console.log("Employee lookup tool called");
      const matches = await vectorStore.similaritySearchWithScore(
        searchQuery,
        n
      );
      return JSON.stringify(matches);
    },
    {
      name: "employee_lookup",
      description: "Gathers employee details from the HR database",
      schema: z.object({
        query: z.string().describe("The search query"),
        n: z
          .number()
          .optional()
          .default(10)
          .describe("Number of results to return"),
      }),
    }
  );

  const tools = [employeeLookupTool];
  // The state typing is extracted via `GraphState.State`.
  const toolNode = new ToolNode<typeof GraphState.State>(tools);

  const model = new ChatMistralAI({
    model: "mistral-small-latest",
    temperature: 0,
  }).bindTools(tools);

  // The prompt template and the tool-name list are the same on every model
  // turn, so they are prepared once here; only the timestamp is per-turn.
  const prompt = ChatPromptTemplate.fromMessages([
    [
      "system",
      `You are a helpful AI assistant, collaborating with other assistants. Use the provided tools to progress towards answering the question. If you are unable to fully answer, that's OK, another assistant with different tools will help where you left off. Execute what you can to make progress. If you or any of the other assistants have the final answer or deliverable, prefix your response with FINAL ANSWER so the team knows to stop. You have access to the following tools: {tool_names}.\n{system_message}\nCurrent time: {time}.`,
    ],
    new MessagesPlaceholder("messages"),
  ]);
  const toolNames = tools.map((t) => t.name).join(", ");

  // Router evaluated after the "agent" node: go to "tools" when the model
  // asked for a tool call, otherwise end the run and reply to the user.
  function shouldContinue(state: typeof GraphState.State) {
    const messages = state.messages;
    // This router is only attached to the "agent" node, whose output is the
    // model's response, so the last message is treated as an AIMessage.
    const lastMessage = messages[messages.length - 1] as AIMessage;
    if (lastMessage.tool_calls?.length) {
      return "tools";
    }
    return "__end__";
  }

  // "agent" node: formats the prompt with the conversation so far and asks
  // the model for the next message.
  async function callModel(state: typeof GraphState.State) {
    const formattedPrompt = await prompt.formatMessages({
      system_message: "You are helpful HR Chatbot Agent.",
      time: new Date().toISOString(),
      tool_names: toolNames,
      messages: state.messages,
    });
    const result = await model.invoke(formattedPrompt);
    return { messages: [result] };
  }

  // Wire the graph: start -> agent -> (tools -> agent)* -> end.
  const workflow = new StateGraph(GraphState)
    .addNode("agent", callModel)
    .addNode("tools", toolNode)
    .addEdge("__start__", "agent")
    .addConditionalEdges("agent", shouldContinue)
    .addEdge("tools", "agent");

  // MongoDB-backed memory so state persists between graph runs.
  const checkpointer = new MongoDBSaver({ client, dbName });
  // Compiling yields a LangChain Runnable; the checkpointer is attached here.
  const app = workflow.compile({ checkpointer });

  const finalState = await app.invoke(
    {
      messages: [new HumanMessage(query)],
    },
    { recursionLimit: 15, configurable: { thread_id } }
  );

  // Read the final reply once, then log and return the same value.
  const finalContent =
    finalState.messages[finalState.messages.length - 1].content;
  console.log(finalContent);
  return finalContent;
}