diff --git a/.changeset/slimy-lions-visit.md b/.changeset/slimy-lions-visit.md new file mode 100644 index 000000000000..fee4e71e99b0 --- /dev/null +++ b/.changeset/slimy-lions-visit.md @@ -0,0 +1,5 @@ +--- +'@ai-sdk/amazon-bedrock': patch +--- + +feat (provider/amazon-bedrock): add support for cache points diff --git a/content/providers/01-ai-sdk-providers/08-amazon-bedrock.mdx b/content/providers/01-ai-sdk-providers/08-amazon-bedrock.mdx index 5b62630a35f7..d36f736994ba 100644 --- a/content/providers/01-ai-sdk-providers/08-amazon-bedrock.mdx +++ b/content/providers/01-ai-sdk-providers/08-amazon-bedrock.mdx @@ -219,6 +219,103 @@ if (result.providerMetadata?.bedrock.trace) { See the [Amazon Bedrock Guardrails documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/guardrails.html) for more information. +### Cache Points + +<Note> + Amazon Bedrock prompt caching is currently in preview release. To request + access, visit the [Amazon Bedrock prompt caching + page](https://aws.amazon.com/bedrock/prompt-caching/). +</Note> + +In messages, you can use the `providerOptions` property to set cache points. Set the `bedrock` property in the `providerOptions` object to `{ cachePoint: { type: 'default' } }` to create a cache point. + +Cache usage information is returned in the `providerMetadata` object. See examples below. + +<Note> + Cache points have model-specific token minimums and limits. For example, + Claude 3.5 Sonnet v2 requires at least 1,024 tokens for a cache point and + allows up to 4 cache points. See the [Amazon Bedrock prompt caching + documentation](https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html) + for details on supported models, regions, and limits. +</Note> + +```ts +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; + +const cyberpunkAnalysis = + '... literary analysis of cyberpunk themes and concepts ...'; + +const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'system', + content: `You are an expert on William Gibson's cyberpunk literature and themes. You have access to the following academic analysis: ${cyberpunkAnalysis}`, + providerOptions: { + bedrock: { cachePoint: { type: 'default' } }, + }, + }, + { + role: 'user', + content: + 'What are the key cyberpunk themes that Gibson explores in Neuromancer?', + }, + ], +}); + +console.log(result.text); +console.log(result.providerMetadata?.bedrock?.usage); +// Shows cache read/write token usage, e.g.: +// { +// cacheReadInputTokens: 1337, +// cacheWriteInputTokens: 42, +// } +``` + +Cache points also work with streaming responses: + +```ts +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; + +const cyberpunkAnalysis = + '... literary analysis of cyberpunk themes and concepts ...'; + +const result = streamText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'assistant', + content: [ + { type: 'text', text: 'You are an expert on cyberpunk literature.'
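+ // the cache point set via providerOptions below caches this message's entire content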
}, + { type: 'text', text: `Academic analysis: ${cyberpunkAnalysis}` }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { + role: 'user', + content: + 'How does Gibson explore the relationship between humanity and technology?', + }, + ], +}); + +for await (const textPart of result.textStream) { + process.stdout.write(textPart); +} + +console.log( + 'Cache token usage:', + (await result.providerMetadata)?.bedrock?.usage, +); +// Shows cache read/write token usage, e.g.: +// { +// cacheReadInputTokens: 1337, +// cacheWriteInputTokens: 42, +// } +``` + ### Model Capabilities | Model | Image Input | Object Generation | Tool Usage | Tool Streaming | diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-assistant.ts b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-assistant.ts new file mode 100644 index 000000000000..48499bc1f479 --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-assistant.ts @@ -0,0 +1,46 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'assistant', + content: [ + { + type: 'text', + text: 'You are a JavaScript expert.', + }, + { + type: 'text', + text: `Error message: ${errorMessage}`, + }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { + role: 'user', + content: [ + { + type: 'text', + text: 'Explain the error message.', + }, + ], + }, + ], + }); + + console.log(result.text); + console.log(); + console.log('Token usage:', result.usage); + console.log('Cache token usage:', result.providerMetadata?.bedrock?.usage); + console.log('Finish reason:', result.finishReason); + console.log('Response headers:', result.response.headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-system.ts b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-system.ts new file mode 100644 index 000000000000..5e40d3e720e7 --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-system.ts @@ -0,0 +1,35 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + maxTokens: 512, + messages: [ + { + role: 'system', + content: `You are a helpful assistant. 
You may be asked about ${errorMessage}.`, + providerOptions: { + bedrock: { cachePoint: { type: 'default' } }, + }, + }, + { + role: 'user', + content: `Explain the error message`, + }, + ], + }); + + console.log(result.text); + console.log(); + console.log('Token usage:', result.usage); + console.log('Cache token usage:', result.providerMetadata?.bedrock?.usage); + console.log('Finish reason:', result.finishReason); + console.log('Response headers:', result.response.headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-tool-call.ts b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-tool-call.ts new file mode 100644 index 000000000000..ddb31c46a567 --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-tool-call.ts @@ -0,0 +1,168 @@ +import { generateText, tool } from 'ai'; +import 'dotenv/config'; +import { z } from 'zod'; +import { bedrock } from '@ai-sdk/amazon-bedrock'; + +const weatherTool = tool({ + description: 'Get the weather in a location', + parameters: z.object({ + location: z.string().describe('The location to get the weather for'), + }), + // location below is inferred to be a string: + execute: async ({ location }) => ({ + location, + temperature: weatherData[location], + }), +}); + +const weatherData: Record<string, number> = { + 'New York': 72.4, + 'Los Angeles': 84.2, + Chicago: 68.9, + Houston: 89.7, + Phoenix: 95.6, + Philadelphia: 71.3, + 'San Antonio': 88.4, + 'San Diego': 76.8, + Dallas: 86.5, + 'San Jose': 75.2, + Austin: 87.9, + Jacksonville: 83.6, + 'Fort Worth': 85.7, + Columbus: 69.8, + 'San Francisco': 68.4, + Charlotte: 77.3, + Indianapolis: 70.6, + Seattle: 65.9, + Denver: 71.8, + 'Washington DC': 74.5, + Boston: 69.7, + 'El Paso': 91.2, + Detroit: 67.8, + Nashville: 78.4, + Portland: 66.7, + Memphis: 81.3, + 'Oklahoma City': 82.9, + 'Las Vegas': 93.4, + Louisville: 75.6, + Baltimore: 73.8, + Milwaukee: 66.5, + Albuquerque: 84.7, + Tucson: 92.3, + Fresno: 87.2, + Sacramento: 82.5, + Mesa: 94.8, + 'Kansas City': 77.9, + Atlanta: 80.6, + Miami: 88.3, + Raleigh: 76.4, + Omaha: 73.5, + 'Colorado Springs': 70.2, + 'Long Beach': 79.8, + 'Virginia Beach': 78.1, + Oakland: 71.4, + Minneapolis: 65.8, + Tulsa: 81.7, + Arlington: 85.3, + Tampa: 86.9, + 'New Orleans': 84.5, + Wichita: 79.4, + Cleveland: 68.7, + Bakersfield: 88.6, + Aurora: 72.3, + Anaheim: 81.5, + Honolulu: 84.9, + 'Santa Ana': 80.7, + Riverside: 89.2, + 'Corpus Christi': 87.6, + Lexington: 74.8, + Henderson: 92.7, + Stockton: 83.9, + 'Saint Paul': 66.2, + Cincinnati: 72.9, + Pittsburgh: 70.4, + Greensboro: 75.9, + Anchorage: 52.3, + Plano: 84.8, + Lincoln: 74.2, + Orlando: 85.7, + Irvine: 78.9, + Newark: 71.6, + Toledo: 69.3, + Durham: 77.1, + 'Chula Vista': 77.4, + 'Fort Wayne': 71.2, + 'Jersey City': 72.7, + 'St.
Petersburg': 85.4, + Laredo: 90.8, + Madison: 67.3, + Chandler: 93.6, + Buffalo: 66.8, + Lubbock: 83.2, + Scottsdale: 94.1, + Reno: 76.5, + Glendale: 92.8, + Gilbert: 93.9, + 'Winston-Salem': 76.2, + Irving: 85.1, + Hialeah: 87.8, + Garland: 84.6, + Fremont: 73.9, + Boise: 75.3, + Richmond: 76.7, + 'Baton Rouge': 83.7, + Spokane: 67.4, + 'Des Moines': 72.1, + Tacoma: 66.3, + 'San Bernardino': 88.1, + Modesto: 84.3, + Fontana: 87.4, + 'Santa Clarita': 82.6, + Birmingham: 81.9, +}; + +async function main() { + const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + tools: { + weather: weatherTool, + }, + prompt: 'What is the weather in San Francisco?', + // TODO: need a way to set cachePoint on `tools`. + providerOptions: { + bedrock: { + cachePoint: { + type: 'default', + }, + }, + }, + }); + + // typed tool calls: + for (const toolCall of result.toolCalls) { + switch (toolCall.toolName) { + case 'weather': { + toolCall.args.location; // string + break; + } + } + } + + // typed tool results for tools with execute method: + for (const toolResult of result.toolResults) { + switch (toolResult.toolName) { + case 'weather': { + toolResult.args.location; // string + toolResult.result.location; // string + toolResult.result.temperature; // number + break; + } + } + } + + console.log(result.text); + console.log(JSON.stringify(result.toolCalls, null, 2)); + console.log(JSON.stringify(result.providerMetadata, null, 2)); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user-image.ts b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user-image.ts new file mode 100644 index 000000000000..99073fed15be --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user-image.ts @@ -0,0 +1,36 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +async function main() { + const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'user', + content: [ + { type: 'image', image: fs.readFileSync('./data/comic-cat.png') }, + { + type: 'text', + text: 'What is in this image?', + }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + ], + }); + + console.log(result.text); + console.log(); + console.log('Token usage:', result.usage); + // TODO: no cache token usage for some reason + // https://docs.aws.amazon.com/bedrock/latest/userguide/prompt-caching.html + // the only delta is some of the lead-in to passing the message bytes, and + // perhaps the size of the image.
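+ // (cache points also require a model-specific minimum token count; this short image prompt may simply fall below it)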
+ console.log('Cache token usage:', result.providerMetadata?.bedrock?.usage); + console.log('Finish reason:', result.finishReason); + console.log('Response headers:', result.response.headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user.ts b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user.ts new file mode 100644 index 000000000000..318373e8c70e --- /dev/null +++ b/examples/ai-core/src/generate-text/amazon-bedrock-cache-point-user.ts @@ -0,0 +1,42 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { generateText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = await generateText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: `I was dreaming last night and I dreamt of an error message: ${errorMessage}`, + }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { + role: 'user', + content: [ + { + type: 'text', + text: 'Explain the error message.', + }, + ], + }, + ], + }); + + console.log(result.text); + console.log(); + console.log('Token usage:', result.usage); + console.log('Cache token usage:', result.providerMetadata?.bedrock?.usage); + console.log('Finish reason:', result.finishReason); + console.log('Response headers:', result.response.headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-assistant.ts b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-assistant.ts new file mode 100644 index 000000000000..f4e99e23eb7b --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-assistant.ts @@ -0,0 +1,52 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = streamText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'assistant', + content: [ + { + type: 'text', + text: 'You are a JavaScript expert.', + }, + { + type: 'text', + text: `Error message: ${errorMessage}`, + }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { + role: 'user', + content: [ + { + type: 'text', + text: 'Explain the error message.', + }, + ], + }, + ], + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log( + 'Cache token usage:', + (await result.providerMetadata)?.bedrock?.usage, + ); + console.log('Finish reason:', await result.finishReason); + console.log('Response headers:', (await result.response).headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-image.ts b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-image.ts new file mode 100644 index 000000000000..135aff4e49d4 --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-image.ts @@ -0,0 +1,33 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +async function main() { + const result = streamText({ + model: 
bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + maxTokens: 512, + messages: [ + { + role: 'user', + content: [ + { type: 'text', text: 'Describe the image in detail.' }, + { type: 'image', image: fs.readFileSync('./data/comic-cat.png') }, + ], + }, + ], + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log( + 'Cache token usage:', + (await result.providerMetadata)?.bedrock?.usage, + ); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-system.ts b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-system.ts new file mode 100644 index 000000000000..fc19402bd661 --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-system.ts @@ -0,0 +1,40 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = streamText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'system', + content: `You are a helpful assistant. You may be asked about ${errorMessage}.`, + providerOptions: { + bedrock: { cachePoint: { type: 'default' } }, + }, + }, + { + role: 'user', + content: `Explain the error message`, + }, + ], + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log( + 'Cache token usage:', + (await result.providerMetadata)?.bedrock?.usage, + ); + console.log('Finish reason:', await result.finishReason); + console.log('Response headers:', (await result.response).headers); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-tool-call.ts b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-tool-call.ts new file mode 100644 index 000000000000..080fd419b434 --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-tool-call.ts @@ -0,0 +1,202 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { + streamText, + tool, + CoreMessage, + ToolCallPart, + ToolResultPart, +} from 'ai'; +import 'dotenv/config'; +import { z } from 'zod'; + +const messages: CoreMessage[] = []; + +const weatherTool = tool({ + description: 'Get the weather in a location', + parameters: z.object({ + location: z.string().describe('The location to get the weather for'), + }), + // location below is inferred to be a string: + execute: async ({ location }) => ({ + location, + temperature: weatherData[location], + }), +}); + +const weatherData: Record<string, number> = { + 'New York': 72.4, + 'Los Angeles': 84.2, + Chicago: 68.9, + Houston: 89.7, + Phoenix: 95.6, + Philadelphia: 71.3, + 'San Antonio': 88.4, + 'San Diego': 76.8, + Dallas: 86.5, + 'San Jose': 75.2, + Austin: 87.9, + Jacksonville: 83.6, + 'Fort Worth': 85.7, + Columbus: 69.8, + 'San Francisco': 68.4, + Charlotte: 77.3, + Indianapolis: 70.6, + Seattle: 65.9, + Denver: 71.8, + 'Washington DC': 74.5, + Boston: 69.7, + 'El Paso': 91.2, + Detroit: 67.8, + Nashville: 78.4, + Portland: 66.7, + Memphis: 81.3, + 'Oklahoma City': 82.9, + 'Las Vegas': 93.4, + Louisville: 75.6, + Baltimore: 73.8, + Milwaukee: 66.5, + Albuquerque: 84.7, + Tucson: 92.3, + Fresno: 87.2, + Sacramento: 82.5, + Mesa: 94.8, + 'Kansas City': 77.9, + Atlanta:
80.6, + Miami: 88.3, + Raleigh: 76.4, + Omaha: 73.5, + 'Colorado Springs': 70.2, + 'Long Beach': 79.8, + 'Virginia Beach': 78.1, + Oakland: 71.4, + Minneapolis: 65.8, + Tulsa: 81.7, + Arlington: 85.3, + Tampa: 86.9, + 'New Orleans': 84.5, + Wichita: 79.4, + Cleveland: 68.7, + Bakersfield: 88.6, + Aurora: 72.3, + Anaheim: 81.5, + Honolulu: 84.9, + 'Santa Ana': 80.7, + Riverside: 89.2, + 'Corpus Christi': 87.6, + Lexington: 74.8, + Henderson: 92.7, + Stockton: 83.9, + 'Saint Paul': 66.2, + Cincinnati: 72.9, + Pittsburgh: 70.4, + Greensboro: 75.9, + Anchorage: 52.3, + Plano: 84.8, + Lincoln: 74.2, + Orlando: 85.7, + Irvine: 78.9, + Newark: 71.6, + Toledo: 69.3, + Durham: 77.1, + 'Chula Vista': 77.4, + 'Fort Wayne': 71.2, + 'Jersey City': 72.7, + 'St. Petersburg': 85.4, + Laredo: 90.8, + Madison: 67.3, + Chandler: 93.6, + Buffalo: 66.8, + Lubbock: 83.2, + Scottsdale: 94.1, + Reno: 76.5, + Glendale: 92.8, + Gilbert: 93.9, + 'Winston-Salem': 76.2, + Irving: 85.1, + Hialeah: 87.8, + Garland: 84.6, + Fremont: 73.9, + Boise: 75.3, + Richmond: 76.7, + 'Baton Rouge': 83.7, + Spokane: 67.4, + 'Des Moines': 72.1, + Tacoma: 66.3, + 'San Bernardino': 88.1, + Modesto: 84.3, + Fontana: 87.4, + 'Santa Clarita': 82.6, + Birmingham: 81.9, +}; + +async function main() { + let toolResponseAvailable = false; + + const result = streamText({ + model: bedrock('anthropic.claude-3-haiku-20240307-v1:0'), + maxTokens: 512, + tools: { + weather: weatherTool, + }, + toolChoice: 'required', + prompt: 'What is the weather in San Francisco?', + // TODO: need a way to set cachePoint on `tools`. + providerOptions: { + bedrock: { + cachePoint: { + type: 'default', + }, + }, + }, + }); + + let fullResponse = ''; + const toolCalls: ToolCallPart[] = []; + const toolResponses: ToolResultPart[] = []; + + for await (const delta of result.fullStream) { + switch (delta.type) { + case 'text-delta': { + fullResponse += delta.textDelta; + process.stdout.write(delta.textDelta); + break; + } + + case 'tool-call': { + toolCalls.push(delta); + + process.stdout.write( + `\nTool call: '${delta.toolName}' ${JSON.stringify(delta.args)}`, + ); + break; + } + + case 'tool-result': { + toolResponses.push(delta); + + process.stdout.write( + `\nTool response: '${delta.toolName}' ${JSON.stringify( + delta.result, + )}`, + ); + break; + } + } + } + process.stdout.write('\n\n'); + + messages.push({ + role: 'assistant', + content: [{ type: 'text', text: fullResponse }, ...toolCalls], + }); + + if (toolResponses.length > 0) { + messages.push({ role: 'tool', content: toolResponses }); + } + + toolResponseAvailable = toolCalls.length > 0; + console.log('Messages:', messages[0].content); + console.log(JSON.stringify(result.providerMetadata, null, 2)); +} + +main().catch(console.error); diff --git a/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-user.ts b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-user.ts new file mode 100644 index 000000000000..2bfe4a3078b5 --- /dev/null +++ b/examples/ai-core/src/stream-text/amazon-bedrock-cache-point-user.ts @@ -0,0 +1,48 @@ +import { bedrock } from '@ai-sdk/amazon-bedrock'; +import { streamText } from 'ai'; +import 'dotenv/config'; +import fs from 'node:fs'; + +const errorMessage = fs.readFileSync('data/error-message.txt', 'utf8'); + +async function main() { + const result = streamText({ + model: bedrock('anthropic.claude-3-5-sonnet-20241022-v2:0'), + messages: [ + { + role: 'user', + content: [ + { + type: 'text', + text: `I was dreaming last night and I dreamt of an error message: 
${errorMessage}`, + }, + ], + providerOptions: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { + role: 'user', + content: [ + { + type: 'text', + text: 'Explain the error message.', + }, + ], + }, + ], + }); + + for await (const textPart of result.textStream) { + process.stdout.write(textPart); + } + + console.log(); + console.log('Token usage:', await result.usage); + console.log( + 'Cache token usage:', + (await result.providerMetadata)?.bedrock?.usage, + ); + console.log('Finish reason:', await result.finishReason); + console.log('Response headers:', (await result.response).headers); +} + +main().catch(console.error); diff --git a/packages/amazon-bedrock/src/bedrock-api-types.ts b/packages/amazon-bedrock/src/bedrock-api-types.ts index 3ed1f8a1d716..904c6d4ded74 100644 --- a/packages/amazon-bedrock/src/bedrock-api-types.ts +++ b/packages/amazon-bedrock/src/bedrock-api-types.ts @@ -2,11 +2,8 @@ import { JSONObject } from '@ai-sdk/provider'; import { Resolvable } from '@ai-sdk/provider-utils'; export interface BedrockConverseInput { - system?: Array<{ text: string }>; - messages: Array<{ - role: string; - content: Array<BedrockContentBlock>; - }>; + system?: BedrockSystemMessages; + messages: BedrockMessages; toolConfig?: BedrockToolConfiguration; inferenceConfig?: { maxTokens?: number; @@ -21,6 +18,29 @@ export interface BedrockConverseInput { | undefined; } +export type BedrockSystemMessages = Array<BedrockSystemContentBlock>; + +export type BedrockMessages = Array< + BedrockAssistantMessage | BedrockUserMessage +>; + +export interface BedrockAssistantMessage { + role: 'assistant'; + content: Array<BedrockContentBlock>; +} + +export interface BedrockUserMessage { + role: 'user'; + content: Array<BedrockContentBlock>; +} + +export const BEDROCK_CACHE_POINT = { + cachePoint: { type: 'default' }, +} as const; + +export type BedrockCachePoint = { cachePoint: { type: 'default' } }; +export type BedrockSystemContentBlock = { text: string } | BedrockCachePoint; + export interface BedrockGuardrailConfiguration { guardrails?: Array<{ name: string; @@ -44,7 +64,7 @@ export interface BedrockTool { } export interface BedrockToolConfiguration { - tools?: BedrockTool[]; + tools?: Array<BedrockTool | BedrockCachePoint>; toolChoice?: | { tool: { name: string } } | { auto: {} } @@ -118,4 +138,5 @@ export type BedrockContentBlock = | BedrockImageBlock | BedrockTextBlock | BedrockToolResultBlock - | BedrockToolUseBlock; + | BedrockToolUseBlock + | BedrockCachePoint; diff --git a/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts b/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts index a6c0062b727d..f1d9ba6d6989 100644 --- a/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts +++ b/packages/amazon-bedrock/src/bedrock-chat-language-model.test.ts @@ -870,6 +870,99 @@ describe('doStream', () => { const body = await server.calls[0].requestBody; expect(body).toMatchObject({ foo: 'bar' }); }); + + it('should include cache token usage in providerMetadata', async () => { + setupMockEventStreamHandler(); + server.urls[streamUrl].response = { + type: 'stream-chunks', + chunks: [ + JSON.stringify({ + contentBlockDelta: { + contentBlockIndex: 0, + delta: { text: 'Hello' }, + }, + }) + '\n', + JSON.stringify({ + metadata: { + usage: { + inputTokens: 4, + outputTokens: 34, + totalTokens: 38, + cacheReadInputTokens: 2, + cacheWriteInputTokens: 3, + }, + }, + }) + '\n', + JSON.stringify({ + messageStop: { + stopReason: 'stop_sequence', + }, + }) + '\n', + ], + }; + + const { stream } = await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + +
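// the finish part of the stream should surface the cache token counts via providerMetadata: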
expect(await convertReadableStreamToArray(stream)).toStrictEqual([ + { type: 'text-delta', textDelta: 'Hello' }, + { + type: 'finish', + finishReason: 'stop', + usage: { promptTokens: 4, completionTokens: 34 }, + providerMetadata: { + bedrock: { + usage: { + cacheReadInputTokens: 2, + cacheWriteInputTokens: 3, + }, + }, + }, + }, + ]); + }); + + it('should handle system messages with cache points', async () => { + setupMockEventStreamHandler(); + server.urls[streamUrl].response = { + type: 'stream-chunks', + chunks: [ + JSON.stringify({ + contentBlockDelta: { + contentBlockIndex: 0, + delta: { text: 'Hello' }, + }, + }) + '\n', + JSON.stringify({ + messageStop: { + stopReason: 'stop_sequence', + }, + }) + '\n', + ], + }; + + await model.doStream({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: [ + { + role: 'system', + content: 'System Prompt', + providerMetadata: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + }); + + const requestBody = await server.calls[0].requestBody; + expect(requestBody).toMatchObject({ + system: [{ text: 'System Prompt' }, { cachePoint: { type: 'default' } }], + messages: [{ role: 'user', content: [{ text: 'Hello' }] }], + }); + }); }); describe('doGenerate', () => { @@ -880,6 +973,8 @@ describe('doGenerate', () => { inputTokens: 4, outputTokens: 34, totalTokens: 38, + cacheReadInputTokens: undefined, + cacheWriteInputTokens: undefined, }, stopReason = 'stop_sequence', trace, @@ -894,6 +989,8 @@ describe('doGenerate', () => { inputTokens: number; outputTokens: number; totalTokens: number; + cacheReadInputTokens?: number; + cacheWriteInputTokens?: number; }; stopReason?: string; trace?: typeof mockTrace; @@ -1281,4 +1378,59 @@ describe('doGenerate', () => { const body = await server.calls[0].requestBody; expect(body).toMatchObject({ foo: 'bar' }); }); + + it('should include cache token usage in providerMetadata', async () => { + prepareJsonResponse({ + content: 'Testing', + usage: { + inputTokens: 4, + outputTokens: 34, + totalTokens: 38, + cacheReadInputTokens: 2, + cacheWriteInputTokens: 3, + }, + }); + + const response = await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: TEST_PROMPT, + }); + + expect(response.providerMetadata).toEqual({ + bedrock: { + usage: { + cacheReadInputTokens: 2, + cacheWriteInputTokens: 3, + }, + }, + }); + expect(response.usage).toEqual({ + promptTokens: 4, + completionTokens: 34, + }); + }); + + it('should handle system messages with cache points', async () => { + prepareJsonResponse({}); + + await model.doGenerate({ + inputFormat: 'prompt', + mode: { type: 'regular' }, + prompt: [ + { + role: 'system', + content: 'System Prompt', + providerMetadata: { bedrock: { cachePoint: { type: 'default' } } }, + }, + { role: 'user', content: [{ type: 'text', text: 'Hello' }] }, + ], + }); + + const requestBody = await server.calls[0].requestBody; + expect(requestBody).toMatchObject({ + system: [{ text: 'System Prompt' }, { cachePoint: { type: 'default' } }], + messages: [{ role: 'user', content: [{ text: 'Hello' }] }], + }); + }); }); diff --git a/packages/amazon-bedrock/src/bedrock-chat-language-model.ts b/packages/amazon-bedrock/src/bedrock-chat-language-model.ts index a19660deb111..e15d2973d14d 100644 --- a/packages/amazon-bedrock/src/bedrock-chat-language-model.ts +++ b/packages/amazon-bedrock/src/bedrock-chat-language-model.ts @@ -21,6 +21,7 @@ import { BedrockConverseInput, BedrockStopReason, 
BEDROCK_STOP_REASONS, + BEDROCK_CACHE_POINT, } from './bedrock-api-types'; import { BedrockChatModelId, @@ -120,7 +121,7 @@ export class BedrockChatLanguageModel implements LanguageModelV1 { }; const baseArgs: BedrockConverseInput = { - system: system ? [{ text: system }] : undefined, + system, additionalModelRequestFields: this.settings.additionalModelRequestFields, ...(Object.keys(inferenceConfig).length > 0 && { inferenceConfig, @@ -203,9 +204,24 @@ export class BedrockChatLanguageModel implements LanguageModelV1 { const { messages: rawPrompt, ...rawSettings } = args; - const providerMetadata = response.trace - ? { bedrock: { trace: response.trace as JSONObject } } - : undefined; + const providerMetadata = + response.trace || response.usage + ? { + bedrock: { + ...(response.trace && typeof response.trace === 'object' + ? { trace: response.trace as JSONObject } + : {}), + ...(response.usage && { + usage: { + cacheReadInputTokens: + response.usage?.cacheReadInputTokens ?? Number.NaN, + cacheWriteInputTokens: + response.usage?.cacheWriteInputTokens ?? Number.NaN, + }, + }), + }, + } + : undefined; return { text: @@ -327,10 +343,32 @@ export class BedrockChatLanguageModel implements LanguageModelV1 { value.metadata.usage?.outputTokens ?? Number.NaN, }; - if (value.metadata.trace) { + const cacheUsage = + value.metadata.usage?.cacheReadInputTokens != null || + value.metadata.usage?.cacheWriteInputTokens != null + ? { + usage: { + cacheReadInputTokens: + value.metadata.usage?.cacheReadInputTokens ?? + Number.NaN, + cacheWriteInputTokens: + value.metadata.usage?.cacheWriteInputTokens ?? + Number.NaN, + }, + } + : undefined; + + const trace = value.metadata.trace + ? { + trace: value.metadata.trace as JSONObject, + } + : undefined; + + if (cacheUsage || trace) { providerMetadata = { bedrock: { - trace: value.metadata.trace as JSONObject, + ...cacheUsage, + ...trace, }, }; } @@ -455,6 +493,8 @@ const BedrockResponseSchema = z.object({ inputTokens: z.number(), outputTokens: z.number(), totalTokens: z.number(), + cacheReadInputTokens: z.number().nullish(), + cacheWriteInputTokens: z.number().nullish(), }), }); @@ -499,6 +539,8 @@ const BedrockStreamSchema = z.object({ trace: z.unknown().nullish(), usage: z .object({ + cacheReadInputTokens: z.number().nullish(), + cacheWriteInputTokens: z.number().nullish(), inputTokens: z.number(), outputTokens: z.number(), }) diff --git a/packages/amazon-bedrock/src/bedrock-chat-prompt.ts b/packages/amazon-bedrock/src/bedrock-chat-prompt.ts deleted file mode 100644 index c792f67517e1..000000000000 --- a/packages/amazon-bedrock/src/bedrock-chat-prompt.ts +++ /dev/null @@ -1,20 +0,0 @@ -import { BedrockContentBlock } from './bedrock-api-types'; - -export type BedrockMessagesPrompt = { - system?: string; - messages: BedrockMessages; -}; - -export type BedrockMessages = Array; - -export type BedrockMessage = BedrockUserMessage | BedrockAssistantMessage; - -export interface BedrockUserMessage { - role: 'user'; - content: Array; -} - -export interface BedrockAssistantMessage { - role: 'assistant'; - content: Array; -} diff --git a/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.test.ts b/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.test.ts index 10a115ef4e85..4bb12ae1b420 100644 --- a/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.test.ts +++ b/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.test.ts @@ -7,7 +7,7 @@ describe('system messages', () => { { role: 'system', content: 'World' }, ]); - 
expect(system).toEqual('Hello\nWorld'); + expect(system).toEqual([{ text: 'Hello' }, { text: 'World' }]); }); it('should throw an error if a system message is provided after a non-system message', async () => { @@ -18,6 +18,21 @@ describe('system messages', () => { ]), ).toThrowError(); }); + + it('should set isSystemCachePoint when system message has cache point', async () => { + const result = convertToBedrockChatMessages([ + { + role: 'system', + content: 'Hello', + providerMetadata: { bedrock: { cachePoint: { type: 'default' } } }, + }, + ]); + + expect(result).toEqual({ + system: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + messages: [], + }); + }); }); describe('user messages', () => { @@ -76,7 +91,27 @@ describe('user messages', () => { }, ]); - expect(system).toEqual('Hello'); + expect(system).toEqual([{ text: 'Hello' }]); + }); + + it('should add cache point to user message content when specified', async () => { + const result = convertToBedrockChatMessages([ + { + role: 'user', + content: [{ type: 'text', text: 'Hello' }], + providerMetadata: { bedrock: { cachePoint: { type: 'default' } } }, + }, + ]); + + expect(result).toEqual({ + messages: [ + { + role: 'user', + content: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + }, + ], + system: [], + }); }); }); @@ -104,7 +139,7 @@ describe('assistant messages', () => { content: [{ text: 'assistant content' }], }, ], - system: undefined, + system: [], }); }); @@ -134,7 +169,7 @@ describe('assistant messages', () => { content: [{ text: 'assistant ' }, { text: 'content' }], }, ], - system: undefined, + system: [], }); }); @@ -169,7 +204,7 @@ describe('assistant messages', () => { content: [{ text: 'user content 2' }], }, ], - system: undefined, + system: [], }); }); @@ -189,7 +224,27 @@ describe('assistant messages', () => { content: [{ text: 'Hello' }, { text: 'World' }, { text: '!' 
}], }, ], - system: undefined, + system: [], + }); + }); + + it('should add cache point to assistant message content when specified', async () => { + const result = convertToBedrockChatMessages([ + { + role: 'assistant', + content: [{ type: 'text', text: 'Hello' }], + providerMetadata: { bedrock: { cachePoint: { type: 'default' } } }, + }, + ]); + + expect(result).toEqual({ + messages: [ + { + role: 'assistant', + content: [{ text: 'Hello' }, { cachePoint: { type: 'default' } }], + }, + ], + system: [], }); }); }); diff --git a/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.ts b/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.ts index 4b29a5e905da..a7ecc2edaa35 100644 --- a/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.ts +++ b/packages/amazon-bedrock/src/convert-to-bedrock-chat-messages.ts @@ -1,29 +1,41 @@ +import { + BEDROCK_CACHE_POINT, + BedrockAssistantMessage, + BedrockCachePoint, + BedrockDocumentFormat, + BedrockImageFormat, + BedrockMessages, + BedrockSystemMessages, + BedrockUserMessage, +} from './bedrock-api-types'; import { JSONObject, LanguageModelV1Message, LanguageModelV1Prompt, + LanguageModelV1ProviderMetadata, UnsupportedFunctionalityError, } from '@ai-sdk/provider'; import { - createIdGenerator, convertUint8ArrayToBase64, + createIdGenerator, } from '@ai-sdk/provider-utils'; -import { BedrockDocumentFormat, BedrockImageFormat } from './bedrock-api-types'; -import { - BedrockAssistantMessage, - BedrockMessagesPrompt, - BedrockUserMessage, -} from './bedrock-chat-prompt'; const generateFileId = createIdGenerator({ prefix: 'file', size: 16 }); -export function convertToBedrockChatMessages( - prompt: LanguageModelV1Prompt, -): BedrockMessagesPrompt { +function getCachePoint( + providerMetadata: LanguageModelV1ProviderMetadata | undefined, +): BedrockCachePoint | undefined { + return providerMetadata?.bedrock?.cachePoint as BedrockCachePoint | undefined; +} + +export function convertToBedrockChatMessages(prompt: LanguageModelV1Prompt): { + system: BedrockSystemMessages; + messages: BedrockMessages; +} { const blocks = groupIntoBlocks(prompt); - let system: string | undefined = undefined; - const messages: BedrockMessagesPrompt['messages'] = []; + let system: BedrockSystemMessages = []; + const messages: BedrockMessages = []; for (let i = 0; i < blocks.length; i++) { const block = blocks[i]; @@ -39,7 +51,12 @@ export function convertToBedrockChatMessages( }); } - system = block.messages.map(({ content }) => content).join('\n'); + for (const message of block.messages) { + system.push({ text: message.content }); + if (getCachePoint(message.providerMetadata)) { + system.push(BEDROCK_CACHE_POINT); + } + } break; } @@ -48,7 +65,7 @@ export function convertToBedrockChatMessages( const bedrockContent: BedrockUserMessage['content'] = []; for (const message of block.messages) { - const { role, content } = message; + const { role, content, providerMetadata } = message; switch (role) { case 'user': { for (let j = 0; j < content.length; j++) { @@ -130,6 +147,10 @@ export function convertToBedrockChatMessages( throw new Error(`Unsupported role: ${_exhaustiveCheck}`); } } + + if (getCachePoint(providerMetadata)) { + bedrockContent.push(BEDROCK_CACHE_POINT); + } } messages.push({ role: 'user', content: bedrockContent }); @@ -176,6 +197,9 @@ export function convertToBedrockChatMessages( } } } + if (getCachePoint(message.providerMetadata)) { + bedrockContent.push(BEDROCK_CACHE_POINT); + } } messages.push({ role: 'assistant', content: bedrockContent 
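// any message-level cache point was already appended to bedrockContent above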
}); @@ -190,10 +214,7 @@ export function convertToBedrockChatMessages( } } - return { - system, - messages, - }; + return { system, messages }; } type SystemBlock = {