Skip to content

Commit

Permalink
refactor(generate-text): trampoline function (#25)
Browse files Browse the repository at this point in the history
* refactor(generate-text)!: trampoline function

* fix(generate-text): clean steps

* refactor(generate-text): simplify naming

* refactor(generate-text): simplify naming

* chore(tool): update test

* chore(generate-text): update test
  • Loading branch information
kwaa authored Jan 12, 2025
1 parent 9b9c73a commit d143e32
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 89 deletions.
174 changes: 87 additions & 87 deletions packages/generate-text/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ import {
export interface GenerateTextOptions extends ChatOptions {
/** @default 1 */
maxSteps?: number
/** @internal */
steps?: StepResult[]
/** if you want to enable stream, use `@xsai/stream-{text,object}` */
stream?: never
}
Expand Down Expand Up @@ -66,102 +68,100 @@ export interface StepResult {
usage: GenerateTextResponseUsage
}

export const generateText = async (options: GenerateTextOptions): Promise<GenerateTextResult> => {
let currentStep = 0

let finishReason: FinishReason = 'error'
let text
let usage: GenerateTextResponseUsage = {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0,
}

const steps: StepResult[] = []
const messages: Message[] = options.messages
const toolCalls: ToolCall[] = []
const toolResults: ToolResult[] = []
while (currentStep < (options.maxSteps ?? 1)) {
currentStep += 1

const data: GenerateTextResponse = await chat({
...options,
maxSteps: undefined,
messages,
stream: false,
}).then(res => res.json())

const { finish_reason, message } = data.choices[0]

finishReason = finish_reason
text = message.content
usage = data.usage

const stepResult: StepResult = {
text: message.content,
toolCalls: [],
toolResults: [],
// type: 'initial',
usage,
}

// TODO: fix types
messages.push({ ...message, content: message.content! })

if (message.tool_calls) {
// execute tools
for (const toolCall of message.tool_calls ?? []) {
const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolCall.function.name)!
const parsedArgs: Record<string, any> = JSON.parse(toolCall.function.arguments)
const toolResult = await tool.execute(parsedArgs)
const toolMessage = {
content: toolResult,
role: 'tool',
tool_call_id: toolCall.id,
} satisfies Message
/** @internal */
type RawGenerateTextTrampoline<T> = Promise<(() => RawGenerateTextTrampoline<T>) | T>

/** @internal */
type RawGenerateText = (options: GenerateTextOptions) => RawGenerateTextTrampoline<GenerateTextResult>

/** @internal */
const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) =>
await chat({
...options,
maxSteps: undefined,
messages: options.messages,
steps: undefined,
stream: false,
})
.then(res => res.json() as Promise<GenerateTextResponse>)
.then(async ({ choices, usage }) => {
const messages: Message[] = options.messages
const steps: StepResult[] = options.steps ?? []
const toolCalls: ToolCall[] = []
const toolResults: ToolResult[] = []

const { finish_reason: finishReason, message } = choices[0]

if (message.content || !message.tool_calls || steps.length >= (options.maxSteps ?? 1)) {
const step: StepResult = {
text: message.content,
toolCalls,
toolResults,
usage,
}

messages.push(toolMessage)
steps.push(step)

const toolCallData = {
args: toolCall.function.arguments,
toolCallId: toolCall.id,
toolCallType: toolCall.type,
toolName: toolCall.function.name,
return {
finishReason,
steps,
...step,
}
toolCalls.push(toolCallData)
stepResult.toolCalls.push(toolCallData)
const toolResultData = {
}

messages.push({ ...message, content: message.content! })

for (const {
function: { arguments: toolArgs, name: toolName },
id: toolCallId,
type: toolCallType,
} of message.tool_calls) {
const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolName)!
const parsedArgs: Record<string, unknown> = JSON.parse(toolArgs)
const result = await tool.execute(parsedArgs)

toolCalls.push({
args: toolArgs,
toolCallId,
toolCallType,
toolName,
})

toolResults.push({
args: parsedArgs,
result: toolResult,
toolCallId: toolCall.id,
toolName: toolCall.function.name,
}
toolResults.push(toolResultData)
stepResult.toolResults.push(toolResultData)
result,
toolCallId,
toolName,
})

messages.push({
content: result,
role: 'tool',
tool_call_id: toolCallId,
})
}
steps.push(stepResult)
}
else {
steps.push(stepResult)
return {
finishReason: finish_reason,
steps,

steps.push({
text: message.content,
toolCalls,
toolResults,
usage,
}
}
}

return {
finishReason,
steps,
text,
toolCalls,
toolResults,
usage,
}
})

return async () => await rawGenerateText({
...options,
messages,
steps,
})
})

export const generateText = async (options: GenerateTextOptions): Promise<GenerateTextResult> => {
let result = await rawGenerateText(options)

while (result instanceof Function)
result = await result()

return result
}

export default generateText
16 changes: 16 additions & 0 deletions packages/generate-text/test/__snapshots__/index.test.ts.snap
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html

exports[`@xsai/generate-text > basic 1`] = `
[
{
"text": "YES",
"toolCalls": [],
"toolResults": [],
"usage": {
"completion_tokens": 2,
"prompt_tokens": 46,
"total_tokens": 48,
},
},
]
`;
6 changes: 5 additions & 1 deletion packages/generate-text/test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import { generateText } from '../src'

describe('@xsai/generate-text', () => {
it('basic', async () => {
const { text } = await generateText({
const { finishReason, steps, text, toolCalls, toolResults } = await generateText({
...ollama.chat('llama3.2'),
messages: [
{
Expand All @@ -20,6 +20,10 @@ describe('@xsai/generate-text', () => {
})

expect(text).toStrictEqual('YES')
expect(finishReason).toBe('stop')
expect(toolCalls.length).toBe(0)
expect(toolResults.length).toBe(0)
expect(steps).toMatchSnapshot()
})

// TODO: error handling
Expand Down
11 changes: 10 additions & 1 deletion packages/tool/test/index.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ describe('@xsai/tool', () => {
}),
})

const { text } = await generateText({
const { steps, text } = await generateText({
...ollama.chat('mistral-nemo'),
maxSteps: 2,
messages: [
Expand All @@ -69,5 +69,14 @@ describe('@xsai/tool', () => {
})

expect(text).toMatchSnapshot()

const { toolCalls, toolResults } = steps[0]

expect(toolCalls[0].toolName).toBe('weather')
expect(toolCalls[0].args).toBe('{"location":"San Francisco"}')

expect(toolCalls[0].toolName).toBe('weather')
expect(toolResults[0].args).toStrictEqual({ location: 'San Francisco' })
expect(toolResults[0].result).toBe('{"location":"San Francisco","temperature":42}')
}, 20000)
})

0 comments on commit d143e32

Please sign in to comment.