diff --git a/packages/generate-text/src/index.ts b/packages/generate-text/src/index.ts index bf28b812..c14af482 100644 --- a/packages/generate-text/src/index.ts +++ b/packages/generate-text/src/index.ts @@ -117,7 +117,7 @@ const rawGenerateText: RawGenerateText = async (options: GenerateTextOptions) => } of message.tool_calls) { const tool = (options.tools as Tool[]).find(tool => tool.function.name === toolName)! const parsedArgs = JSON.parse(toolArgs) as Record - const result = await tool.execute(parsedArgs) + const result = await tool.execute(parsedArgs, { abortSignal: options.abortSignal, messages, toolCallId }) toolCalls.push({ args: toolArgs, diff --git a/packages/shared-chat/src/types/tool.ts b/packages/shared-chat/src/types/tool.ts index aea0f06f..922a7baf 100644 --- a/packages/shared-chat/src/types/tool.ts +++ b/packages/shared-chat/src/types/tool.ts @@ -1,5 +1,7 @@ +import type { Message } from './message' + export interface Tool { - execute: (input: unknown) => Promise | string + execute: (input: unknown, options: ToolExecuteOptions) => Promise | string function: { description?: string name: string @@ -8,3 +10,9 @@ export interface Tool { } type: 'function' } + +interface ToolExecuteOptions { + abortSignal?: AbortSignal + messages: Message[] + toolCallId: string +} diff --git a/packages/tool/src/generate-text.ts b/packages/tool/src/generate-text.ts index 22979a30..26ec8464 100644 --- a/packages/tool/src/generate-text.ts +++ b/packages/tool/src/generate-text.ts @@ -1,3 +1,5 @@ +/// + import type { Schema } from '@typeschema/main' import type { ToolResult } from '.'