diff --git a/src/core/sliding-window/__tests__/sliding-window.test.ts b/src/core/sliding-window/__tests__/sliding-window.test.ts
new file mode 100644
index 0000000000..182dea67f5
--- /dev/null
+++ b/src/core/sliding-window/__tests__/sliding-window.test.ts
@@ -0,0 +1,130 @@
+// npx jest src/core/sliding-window/__tests__/sliding-window.test.ts
+
+import { Anthropic } from "@anthropic-ai/sdk"
+
+import { ModelInfo } from "../../../shared/api"
+import { truncateConversation, truncateConversationIfNeeded } from "../index"
+
+describe("truncateConversation", () => {
+	it("should retain the first message", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "First message" },
+			{ role: "assistant", content: "Second message" },
+			{ role: "user", content: "Third message" },
+		]
+
+		const result = truncateConversation(messages, 0.5)
+
+		// With 2 messages after the first, a 0.5 fraction means removing 1 message.
+		// But 1 is odd, so it rounds down to 0 (to make it even).
+		expect(result.length).toBe(3) // First message + 2 remaining messages
+		expect(result[0]).toEqual(messages[0])
+		expect(result[1]).toEqual(messages[1])
+		expect(result[2]).toEqual(messages[2])
+	})
+
+	it("should remove the specified fraction of messages (rounded to an even number)", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "First message" },
+			{ role: "assistant", content: "Second message" },
+			{ role: "user", content: "Third message" },
+			{ role: "assistant", content: "Fourth message" },
+			{ role: "user", content: "Fifth message" },
+		]
+
+		// 4 messages excluding the first, 0.5 fraction = 2 messages to remove.
+		// 2 is already even, so no rounding is needed.
+		const result = truncateConversation(messages, 0.5)
+
+		expect(result.length).toBe(3)
+		expect(result[0]).toEqual(messages[0])
+		expect(result[1]).toEqual(messages[3])
+		expect(result[2]).toEqual(messages[4])
+	})
+
+	it("should round down to an even number of messages to remove", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "First message" },
+			{ role: "assistant", content: "Second message" },
+			{ role: "user", content: "Third message" },
+			{ role: "assistant", content: "Fourth message" },
+			{ role: "user", content: "Fifth message" },
+			{ role: "assistant", content: "Sixth message" },
+			{ role: "user", content: "Seventh message" },
+		]
+
+		// 6 messages excluding the first, 0.3 fraction = 1.8 messages to remove.
+		// 1.8 rounds down to 1, then down to 0 to make it even.
+		const result = truncateConversation(messages, 0.3)
+
+		expect(result.length).toBe(7) // No messages removed
+		expect(result).toEqual(messages)
+	})
+
+	it("should handle the edge case fracToRemove = 0", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "First message" },
+			{ role: "assistant", content: "Second message" },
+			{ role: "user", content: "Third message" },
+		]
+
+		const result = truncateConversation(messages, 0)
+
+		expect(result).toEqual(messages)
+	})
+
+	it("should handle the edge case fracToRemove = 1", () => {
+		const messages: Anthropic.Messages.MessageParam[] = [
+			{ role: "user", content: "First message" },
+			{ role: "assistant", content: "Second message" },
+			{ role: "user", content: "Third message" },
+			{ role: "assistant", content: "Fourth message" },
+		]
+
+		// 3 messages excluding the first, a 1.0 fraction = 3 messages to remove.
+		// But 3 is odd, so it rounds down to 2 to make it even.
+		const result = truncateConversation(messages, 1)
+
+		expect(result.length).toBe(2)
+		expect(result[0]).toEqual(messages[0])
+		expect(result[1]).toEqual(messages[3])
+	})
+})
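+
+// A sketch of the rounding rule exercised above, assuming the behavior the
+// comments describe: the raw count is floor((messages.length - 1) * fracToRemove),
+// then rounded down to the nearest even number, presumably so the user/assistant
+// alternation of the retained messages stays intact. For example, 4 messages
+// with fracToRemove = 1 give floor(3 * 1) = 3, which is odd, so 2 are removed.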
+
+describe("truncateConversationIfNeeded", () => {
+	const createModelInfo = (contextWindow: number, supportsPromptCache: boolean, maxTokens?: number): ModelInfo => ({
+		contextWindow,
+		supportsPromptCache,
+		maxTokens,
+	})
+
+	const messages: Anthropic.Messages.MessageParam[] = [
+		{ role: "user", content: "First message" },
+		{ role: "assistant", content: "Second message" },
+		{ role: "user", content: "Third message" },
+		{ role: "assistant", content: "Fourth message" },
+		{ role: "user", content: "Fifth message" },
+	]
+
+	it("should not truncate if tokens are below the threshold for prompt caching models", () => {
+		const modelInfo = createModelInfo(200000, true, 50000)
+		const totalTokens = 100000 // Below threshold
+		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
+		expect(result).toEqual(messages)
+	})
+
+	it("should not truncate if tokens are below the threshold for non-prompt caching models", () => {
+		const modelInfo = createModelInfo(200000, false)
+		const totalTokens = 100000 // Below threshold
+		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
+		expect(result).toEqual(messages)
+	})
+
+	it("should use 80% of the context window as the threshold if it is greater than (contextWindow - buffer)", () => {
+		const modelInfo = createModelInfo(50000, true) // Small context window
+		const totalTokens = 40001 // Above the 80% threshold (40000)
+		const mockResult = [messages[0], messages[3], messages[4]]
+		const result = truncateConversationIfNeeded(messages, totalTokens, modelInfo)
+		expect(result).toEqual(mockResult)
+	})
+})
diff --git a/src/core/sliding-window/index.ts b/src/core/sliding-window/index.ts
index ee4a1543e7..d213f069f1 100644
--- a/src/core/sliding-window/index.ts
+++ b/src/core/sliding-window/index.ts
@@ -1,4 +1,5 @@
 import { Anthropic } from "@anthropic-ai/sdk"
+
 import { ModelInfo } from "../../shared/api"
 
 /**
@@ -55,13 +56,15 @@ export function truncateConversationIfNeeded(
 /**
  * Calculates the maximum allowed tokens for models that support prompt caching.
  *
- * The maximum is computed as the greater of (contextWindow - 40000) and 80% of the contextWindow.
+ * The maximum is computed as the greater of (contextWindow - buffer) and 80% of the contextWindow.
  *
  * @param {ModelInfo} modelInfo - The model information containing the context window size.
  * @returns {number} The maximum number of tokens allowed for prompt caching models.
  */
 function getMaxTokensForPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.max(modelInfo.contextWindow - 40_000, modelInfo.contextWindow * 0.8)
+	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
+	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
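+	// Worked example with the values from the prompt-caching test above:
+	// contextWindow 200_000 and maxTokens 50_000 give buffer = max(40_000, 50_000)
+	// = 50_000, so the cap is max(150_000, 160_000) = 160_000 (the 80% term wins).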
+	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
 }
 
 /**
@@ -83,7 +86,9 @@ function getTruncFractionForPromptCachingModels(modelInfo: ModelInfo): number {
  * @returns {number} The maximum number of tokens allowed for non-prompt caching models.
  */
 function getMaxTokensForNonPromptCachingModels(modelInfo: ModelInfo): number {
-	return Math.max(modelInfo.contextWindow - 40_000, modelInfo.contextWindow * 0.8)
+	// The buffer needs to be at least as large as `modelInfo.maxTokens`.
+	const buffer = modelInfo.maxTokens ? Math.max(40_000, modelInfo.maxTokens) : 40_000
+	return Math.max(modelInfo.contextWindow - buffer, modelInfo.contextWindow * 0.8)
 }
 
 /**
diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts
index b4819d9683..12beecb976 100644
--- a/src/core/webview/ClineProvider.ts
+++ b/src/core/webview/ClineProvider.ts
@@ -2126,6 +2126,7 @@ export class ClineProvider implements vscode.WebviewViewProvider {
 		switch (rawModel.id) {
 			case "anthropic/claude-3.7-sonnet":
+			case "anthropic/claude-3.7-sonnet:beta":
 			case "anthropic/claude-3.5-sonnet":
 			case "anthropic/claude-3.5-sonnet:beta":
 				// NOTE: this needs to be synced with api.ts/openrouter default model info.