From c399cb8006b19b4f7ac5a5368ffccdc6d259dcfa Mon Sep 17 00:00:00 2001 From: vi-presidio <194114285+vj-presidio@users.noreply.github.com> Date: Tue, 4 Feb 2025 17:53:31 +0530 Subject: [PATCH] feat: Add Experimental Ollama embedding provider support * feat: Add Ollama embedding provider support * feat: Fetch and display Ollama embedding models in UI --- package-lock.json | 45 ++++++++++ package.json | 7 +- src/core/webview/ClineProvider.ts | 54 +++++++++++ src/embedding/index.ts | 6 +- src/embedding/providers/ollama.ts | 32 +++++++ .../code-prep/FindFilesToEditAgent.ts | 3 +- .../code-prep/VectorizeCodeAgent.ts | 5 +- src/integrations/code-prep/helper.ts | 11 +++ src/shared/ExtensionMessage.ts | 2 + src/shared/WebviewMessage.ts | 1 + src/shared/embeddings.ts | 8 +- .../components/settings/EmbeddingOptions.tsx | 89 ++++++++++++++++++- 12 files changed, 253 insertions(+), 10 deletions(-) create mode 100644 src/embedding/providers/ollama.ts diff --git a/package-lock.json b/package-lock.json index 9b3a507..7fa9ccc 100644 --- a/package-lock.json +++ b/package-lock.json @@ -16,6 +16,7 @@ "@langchain/aws": "^0.1.1", "@langchain/community": "^0.3.11", "@langchain/core": "^0.3.17", + "@langchain/ollama": "^0.1.5", "@langchain/openai": "^0.3.12", "@langchain/textsplitters": "^0.1.0", "@mistralai/mistralai": "^1.3.6", @@ -7283,6 +7284,35 @@ "zod": "^3.24.1" } }, + "node_modules/@langchain/ollama": { + "version": "0.1.5", + "resolved": "https://registry.npmjs.org/@langchain/ollama/-/ollama-0.1.5.tgz", + "integrity": "sha512-S2tF94uIJtXavekKg10LvTV+jIelOIrubaCnje8BopfiNOVcnzsSulUL4JH0wvdxMZq0vbE4/i9RwC2q9ivOmA==", + "license": "MIT", + "dependencies": { + "ollama": "^0.5.9", + "uuid": "^10.0.0" + }, + "engines": { + "node": ">=18" + }, + "peerDependencies": { + "@langchain/core": ">=0.2.21 <0.4.0" + } + }, + "node_modules/@langchain/ollama/node_modules/uuid": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/uuid/-/uuid-10.0.0.tgz", + "integrity": "sha512-8XkAphELsDnEGrDxUOHB3RGvXz6TeuYSGEZBOjtTtPm2lwhGBjLgOzLHB63IUWfBpNucQjND6d3AOudO+H3RWQ==", + "funding": [ + "https://github.com/sponsors/broofa", + "https://github.com/sponsors/ctavan" + ], + "license": "MIT", + "bin": { + "uuid": "dist/bin/uuid" + } + }, "node_modules/@langchain/openai": { "version": "0.3.17", "resolved": "https://registry.npmjs.org/@langchain/openai/-/openai-0.3.17.tgz", @@ -15031,6 +15061,15 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/ollama": { + "version": "0.5.12", + "resolved": "https://registry.npmjs.org/ollama/-/ollama-0.5.12.tgz", + "integrity": "sha512-flVH1fn1c9NF7VV3bW9kSu0E+bYc40b4DxL/gS2Debhao35osJFRDiPOj9sIWTMvcyj78Paw1OuhfIe7uhDWfQ==", + "license": "MIT", + "dependencies": { + "whatwg-fetch": "^3.6.20" + } + }, "node_modules/once": { "version": "1.4.0", "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz", @@ -17900,6 +17939,12 @@ "node": ">=18" } }, + "node_modules/whatwg-fetch": { + "version": "3.6.20", + "resolved": "https://registry.npmjs.org/whatwg-fetch/-/whatwg-fetch-3.6.20.tgz", + "integrity": "sha512-EqhiFU6daOA8kpjOWTL0olhVOF3i7OrFzSYiGsEMB8GcXS+RrzauAERX65xMeNWVqxA6HXH2m69Z9LaKKdisfg==", + "license": "MIT" + }, "node_modules/whatwg-mimetype": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/whatwg-mimetype/-/whatwg-mimetype-4.0.0.tgz", diff --git a/package.json b/package.json index 4f9cde5..c279895 100644 --- a/package.json +++ b/package.json @@ -228,13 +228,14 @@ "@anthropic-ai/sdk": "^0.26.0", "@anthropic-ai/vertex-sdk": "^0.4.1", 
"@google/generative-ai": "^0.18.0", - "@mistralai/mistralai": "^1.3.6", - "@modelcontextprotocol/sdk": "^1.0.1", "@langchain/aws": "^0.1.1", "@langchain/community": "^0.3.11", "@langchain/core": "^0.3.17", + "@langchain/ollama": "^0.1.5", "@langchain/openai": "^0.3.12", "@langchain/textsplitters": "^0.1.0", + "@mistralai/mistralai": "^1.3.6", + "@modelcontextprotocol/sdk": "^1.0.1", "@types/clone-deep": "^4.0.4", "@types/get-folder-size": "^3.0.4", "@types/pdf-parse": "^1.1.4", @@ -247,8 +248,8 @@ "default-shell": "^2.2.0", "delay": "^6.0.0", "diff": "^5.2.0", - "faiss-node": "^0.5.1", "execa": "^9.5.2", + "faiss-node": "^0.5.1", "fast-deep-equal": "^3.1.3", "firebase": "^11.2.0", "get-folder-size": "^5.0.0", diff --git a/src/core/webview/ClineProvider.ts b/src/core/webview/ClineProvider.ts index 58bfeb4..d37b129 100644 --- a/src/core/webview/ClineProvider.ts +++ b/src/core/webview/ClineProvider.ts @@ -197,6 +197,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { "embeddingAzureOpenAIApiInstanceName", "embeddingAzureOpenAIApiEmbeddingsDeploymentName", "embeddingAzureOpenAIApiVersion", + "embeddingOllamaBaseUrl", + "embeddingOllamaModelId", ] return customGlobalKeys.includes(key) } @@ -989,6 +991,13 @@ export class ClineProvider implements vscode.WebviewViewProvider { ollamaModels, }) break + case "requestOllamaEmbeddingModels": + const ollamaEmbeddingModels = await this.getOllamaEmbeddingModels(message.text) + this.postMessageToWebview({ + type: "ollamaEmbeddingModels", + ollamaEmbeddingModels, + }) + break case "requestLmStudioModels": const lmStudioModels = await this.getLmStudioModels(message.text) this.postMessageToWebview({ @@ -1125,6 +1134,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { azureOpenAIApiInstanceName, azureOpenAIApiEmbeddingsDeploymentName, azureOpenAIApiVersion, + ollamaBaseUrl, + ollamaModelId, } = message.embeddingConfiguration // Update Global State @@ -1139,6 +1150,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { "embeddingAzureOpenAIApiEmbeddingsDeploymentName", azureOpenAIApiEmbeddingsDeploymentName, ) + await this.customUpdateState("embeddingOllamaBaseUrl", ollamaBaseUrl) + await this.customUpdateState("embeddingOllamaModelId", ollamaModelId) // Update Secrets await this.customStoreSecret("embeddingAwsAccessKey", awsAccessKey) await this.customStoreSecret("embeddingAwsSecretKey", awsSecretKey) @@ -1365,6 +1378,41 @@ export class ClineProvider implements vscode.WebviewViewProvider { // Ollama + async getOllamaEmbeddingModels(baseUrl?: string) { + try { + if (!baseUrl) { + baseUrl = "http://localhost:11434" + } + if (!URL.canParse(baseUrl)) { + return [] + } + const response = await axios.get(`${baseUrl}/api/tags`) + const modelsArray = response.data?.models?.map((model: any) => model.name) || [] + const models = [...new Set(modelsArray)] + // TODO: Currently OLLAM local API doen't support diffrentiate between embedding and chat models + // so we are only considering models that have the following inclusion, as OLLAMA release new + // models this list has to be updated, or we have to wait for OLLAMA to support this natively. + // And diretctly fetching from the Public remote API is not also avaialble. 
+ // https://ollama.com/search?c=embedding + const PUBLIC_KNOWN_MODELS = [ + "nomic-embed-text", + "mxbai-embed-large", + "snowflake-arctic-embed", + "bge-m3", + "all-minilm", + "bge-large", + "snowflake-arctic-embed2", + "paraphrase-multilingual", + "granite-embedding", + ] + return models.filter((model: string) => + PUBLIC_KNOWN_MODELS.some((known) => model.toLowerCase().includes(known.toLowerCase())), + ) + } catch (error) { + return [] + } + } + async getOllamaModels(baseUrl?: string) { try { if (!baseUrl) { @@ -1841,6 +1889,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { azureOpenAIApiEmbeddingsDeploymentName, azureOpenAIApiVersion, isEmbeddingConfigurationValid, + embeddingOllamaBaseUrl, + embeddingOllamaModelId, ] = await Promise.all([ this.customGetState("apiProvider") as Promise, this.customGetState("apiModelId") as Promise, @@ -1897,6 +1947,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { this.customGetState("embeddingAzureOpenAIApiEmbeddingsDeploymentName") as Promise, this.customGetState("embeddingAzureOpenAIApiVersion") as Promise, this.customGetState("isEmbeddingConfigurationValid") as Promise, + this.customGetState("embeddingOllamaBaseUrl") as Promise, + this.customGetState("embeddingOllamaModelId") as Promise, ]) let apiProvider: ApiProvider @@ -1967,6 +2019,8 @@ export class ClineProvider implements vscode.WebviewViewProvider { azureOpenAIApiEmbeddingsDeploymentName, azureOpenAIApiVersion, isEmbeddingConfigurationValid, + ollamaBaseUrl: embeddingOllamaBaseUrl, + ollamaModelId: embeddingOllamaModelId, }, lastShownAnnouncementId, customInstructions, diff --git a/src/embedding/index.ts b/src/embedding/index.ts index 862ddc3..2283f2a 100644 --- a/src/embedding/index.ts +++ b/src/embedding/index.ts @@ -4,9 +4,11 @@ import { OpenAiNativeEmbeddingHandler } from "./providers/openai-native" import { BedrockEmbeddings } from "@langchain/aws" import { AzureOpenAIEmbeddings, OpenAIEmbeddings } from "@langchain/openai" import { EmbeddingConfiguration } from "../shared/embeddings" +import type { OllamaEmbeddings } from "@langchain/ollama" +import { OllamaEmbeddingHandler } from "./providers/ollama" export interface EmbeddingHandler { - getClient(): BedrockEmbeddings | OpenAIEmbeddings | AzureOpenAIEmbeddings + getClient(): BedrockEmbeddings | OpenAIEmbeddings | AzureOpenAIEmbeddings | OllamaEmbeddings validateAPIKey(): Promise } @@ -19,6 +21,8 @@ export function buildEmbeddingHandler(configuration: EmbeddingConfiguration): Em return new OpenAiEmbeddingHandler(options) case "openai-native": return new OpenAiNativeEmbeddingHandler(options) + case "ollama": + return new OllamaEmbeddingHandler(options) default: throw new Error(`Unsupported embedding provider: ${provider}`) } diff --git a/src/embedding/providers/ollama.ts b/src/embedding/providers/ollama.ts new file mode 100644 index 0000000..ce3b947 --- /dev/null +++ b/src/embedding/providers/ollama.ts @@ -0,0 +1,32 @@ +import { BedrockEmbeddings } from "@langchain/aws" +import { OpenAIEmbeddings, AzureOpenAIEmbeddings } from "@langchain/openai" +import { EmbeddingHandler } from "../" +import { EmbeddingHandlerOptions } from "../../shared/embeddings" +import { OllamaEmbeddings } from "@langchain/ollama" + +export class OllamaEmbeddingHandler implements EmbeddingHandler { + private options: EmbeddingHandlerOptions + private client: OllamaEmbeddings + + constructor(options: EmbeddingHandlerOptions) { + this.options = options + this.client = new OllamaEmbeddings({ + model: 
this.options.ollamaModelId, + baseUrl: this.options.ollamaBaseUrl || "http://localhost:11434", + }) + } + + getClient() { + return this.client + } + + async validateAPIKey(): Promise { + try { + await this.client.embedQuery("Test") + return true + } catch (error) { + console.error("Error validating Ollama credentials: ", error) + return false + } + } +} diff --git a/src/integrations/code-prep/FindFilesToEditAgent.ts b/src/integrations/code-prep/FindFilesToEditAgent.ts index 88d25a8..bca49a2 100644 --- a/src/integrations/code-prep/FindFilesToEditAgent.ts +++ b/src/integrations/code-prep/FindFilesToEditAgent.ts @@ -15,12 +15,13 @@ import { basename, join } from "node:path" import { buildApiHandler } from "../../api" import { ensureFaissPlatformDeps } from "../../utils/faiss" import { EmbeddingConfiguration } from "../../shared/embeddings" +import { OllamaEmbeddings } from "@langchain/ollama" export class FindFilesToEditAgent { private srcFolder: string private llmApiConfig: ApiConfiguration private embeddingConfig: EmbeddingConfiguration - private embeddings: OpenAIEmbeddings | BedrockEmbeddings + private embeddings: OpenAIEmbeddings | BedrockEmbeddings | OllamaEmbeddings private vectorStore: FaissStore private task: string private buildContextOptions: HaiBuildContextOptions diff --git a/src/integrations/code-prep/VectorizeCodeAgent.ts b/src/integrations/code-prep/VectorizeCodeAgent.ts index 800c5ac..ba1d4b2 100644 --- a/src/integrations/code-prep/VectorizeCodeAgent.ts +++ b/src/integrations/code-prep/VectorizeCodeAgent.ts @@ -14,12 +14,13 @@ import { createHash } from "node:crypto" import { HaiBuildDefaults } from "../../shared/haiDefaults" import { EmbeddingConfiguration } from "../../shared/embeddings" import { fileExists } from "../../utils/runtime-downloader" +import { OllamaEmbeddings } from "@langchain/ollama" export class VectorizeCodeAgent extends EventEmitter { private srcFolder: string private abortController = new AbortController() - private embeddings: OpenAIEmbeddings | BedrockEmbeddings + private embeddings: OpenAIEmbeddings | BedrockEmbeddings | OllamaEmbeddings private vectorStore: FaissStore private buildContextOptions: HaiBuildContextOptions private contextDir: string @@ -225,7 +226,7 @@ export class VectorizeCodeAgent extends EventEmitter { const fileName = basename(codeFilePath) const textSplitter = new RecursiveCharacterTextSplitter({ - chunkSize: 8191, + chunkSize: this.embeddingConfig.provider !== "ollama" ? 
8191 : 512, chunkOverlap: 0, }) const texts = await textSplitter.splitText(fileContent) diff --git a/src/integrations/code-prep/helper.ts b/src/integrations/code-prep/helper.ts index fdb7405..48a7adf 100644 --- a/src/integrations/code-prep/helper.ts +++ b/src/integrations/code-prep/helper.ts @@ -7,6 +7,7 @@ import { azureOpenAIApiVersion, EmbeddingConfiguration } from "../../shared/embe // @ts-ignore import walk from "ignore-walk" import ignore from "ignore" +import { OllamaEmbeddings } from "@langchain/ollama" /** * Recursively retrieves all code files from a given source folder, @@ -84,6 +85,16 @@ export function getEmbeddings(embeddingConfig: EmbeddingConfiguration) { }) } + case "ollama": { + if (!embeddingConfig.ollamaModelId) { + throw new Error("Ollama model ID is required") + } + return new OllamaEmbeddings({ + model: embeddingConfig.ollamaModelId, + baseUrl: embeddingConfig.ollamaBaseUrl || "http://localhost:11434", + }) + } + default: throw new Error(`Unsupported embedding provider: ${embeddingConfig.provider}`) } diff --git a/src/shared/ExtensionMessage.ts b/src/shared/ExtensionMessage.ts index fed49aa..7eb20c9 100644 --- a/src/shared/ExtensionMessage.ts +++ b/src/shared/ExtensionMessage.ts @@ -33,6 +33,7 @@ export interface ExtensionMessage { | "llmConfigValidation" | "embeddingConfigValidation" | "existingFiles" + | "ollamaEmbeddingModels" text?: string bool?: boolean action?: @@ -49,6 +50,7 @@ export interface ExtensionMessage { state?: ExtensionState images?: string[] ollamaModels?: string[] + ollamaEmbeddingModels?: string[] lmStudioModels?: string[] vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[] filePaths?: string[] diff --git a/src/shared/WebviewMessage.ts b/src/shared/WebviewMessage.ts index cf11b4c..dd6a40d 100644 --- a/src/shared/WebviewMessage.ts +++ b/src/shared/WebviewMessage.ts @@ -53,6 +53,7 @@ export interface WebviewMessage { | "uploadInstruction" | "deleteInstruction" | "fileInstructions" + | "requestOllamaEmbeddingModels" text?: string disabled?: boolean askResponse?: ClineAskResponse diff --git a/src/shared/embeddings.ts b/src/shared/embeddings.ts index ee3cea4..c7f9653 100644 --- a/src/shared/embeddings.ts +++ b/src/shared/embeddings.ts @@ -1,4 +1,4 @@ -export type EmbeddingProvider = "bedrock" | "openai-native" | "openai" +export type EmbeddingProvider = "bedrock" | "openai-native" | "openai" | "ollama" export interface EmbeddingHandlerOptions { modelId?: string @@ -17,6 +17,8 @@ export interface EmbeddingHandlerOptions { azureOpenAIApiEmbeddingsDeploymentName?: string azureOpenAIApiVersion?: string maxRetries?: number + ollamaBaseUrl?: string + ollamaModelId?: string } export type EmbeddingConfiguration = EmbeddingHandlerOptions & { @@ -74,6 +76,7 @@ export const embeddingProviderModels = { bedrock: bedrockEmbeddingModels, "openai-native": openAiNativeEmbeddingModels, openai: {}, + ollama: {}, } as const export const defaultEmbeddingConfigs: Record = { @@ -86,6 +89,9 @@ export const defaultEmbeddingConfigs: Record(null) const [isLoading, setIsLoading] = useState(false) const [validateEmbedding, setValidateEmbedding] = useState(undefined) + const [ollamaModels, setOllamaModels] = useState([]) useEffect(() => { if (!apiConfiguration || !buildContextOptions?.useSyncWithApi) return @@ -114,6 +115,32 @@ const EmbeddingOptions = ({ showModelOptions, showModelError = true, errorMessag return normalizeEmbeddingConfiguration(embeddingConfiguration) }, [embeddingConfiguration]) + // Poll ollama models + const 
requestLocalModels = useCallback(() => { + if (selectedProvider === "ollama") { + vscode.postMessage({ + type: "requestOllamaEmbeddingModels", + text: apiConfiguration?.ollamaBaseUrl, + }) + } + }, [selectedProvider, apiConfiguration?.ollamaBaseUrl]) + useEffect(() => { + if (selectedProvider === "ollama") { + requestLocalModels() + } + }, [selectedProvider, requestLocalModels]) + + useInterval(requestLocalModels, selectedProvider === "ollama" ? 2000 : null) + + const handleMessage = useCallback((event: MessageEvent) => { + const message: ExtensionMessage = event.data + if (message.type === "ollamaEmbeddingModels" && message.ollamaEmbeddingModels) { + setOllamaModels(message.ollamaEmbeddingModels) + } + }, []) + + useEvent("message", handleMessage) + useEffect(() => { setEmbeddingConfiguration({ ...embeddingConfiguration, @@ -141,10 +168,11 @@ const EmbeddingOptions = ({ showModelOptions, showModelError = true, errorMessag value={selectedProvider} onChange={handleInputChange("provider")} disabled={isLoading} - style={{ minWidth: 130, position: "relative" }}> + style={{ minWidth: 130, position: "relative", width: "100%" }}> AWS Bedrock OpenAI OpenAI Compatible + Ollama (experimental) @@ -297,6 +325,63 @@ const EmbeddingOptions = ({ showModelOptions, showModelError = true, errorMessag )} + {selectedProvider === "ollama" && ( +
+ + Base URL (optional) + +
+ + + Select a model... + {ollamaModels.map((modelId) => ( + + {modelId} + + ))} + +
+

+ Ollama allows you to run models locally on your computer. For instructions on how to get started, see their + quickstart guide. + You can find the list of supported embedding models{" "} + here. +

+
+ )} + {selectedProvider && Object.keys(availableModels).length > 0 && (