Skip to content

Commit

Permalink
feat(generate-text): add onStepFinish() (#26)
Browse files Browse the repository at this point in the history
* feat(generate-text): add onStepFinish()

* chore(generate-text): update test
  • Loading branch information
kwaa authored Jan 15, 2025
1 parent c59ab2c commit 2bc1733
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 11 additions & 2 deletions packages/generate-text/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
export interface GenerateTextOptions extends ChatOptions {
/** @default 1 */
maxSteps?: number
onStepFinish?: (step: StepResult) => Promise<void> | void
/** @internal */
steps?: StepResult[]
/** if you want to enable stream, use `@xsai/stream-{text,object}` */
Expand Down Expand Up @@ -102,6 +103,9 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) =>

steps.push(step)

if (options.onStepFinish)
await options.onStepFinish(step)

return {
finishReason,
steps,
Expand Down Expand Up @@ -141,12 +145,17 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) =>
})
}

steps.push({
const step: StepResult = {
text: message.content,
toolCalls,
toolResults,
usage,
})
}

steps.push(step)

if (options.onStepFinish)
await options.onStepFinish(step)

return async () => await rawGenerateText({
...options,
Expand Down
9 changes: 8 additions & 1 deletion packages/generate-text/test/index.test.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import { ollama } from '@xsai/providers'
import { describe, expect, it } from 'vitest'

import { generateText } from '../src'
import { generateText, type StepResult } from '../src'

describe('@xsai/generate-text', () => {
it('basic', async () => {
let step: StepResult | undefined

const { finishReason, steps, text, toolCalls, toolResults } = await generateText({
...ollama.chat('llama3.2'),
messages: [
Expand All @@ -17,13 +19,18 @@ describe('@xsai/generate-text', () => {
role: 'user',
},
],
onStepFinish: (result) => {
step = result
},
})

expect(text).toStrictEqual('YES')
expect(finishReason).toBe('stop')
expect(toolCalls.length).toBe(0)
expect(toolResults.length).toBe(0)
expect(steps).toMatchSnapshot()

expect(steps[0]).toStrictEqual(step)
})

// TODO: error handling
Expand Down

0 comments on commit 2bc1733

Please sign in to comment.