feat: Add Experimental Ollama embedding provider support
* feat: Add Ollama embedding provider support
* feat: Fetch and display Ollama embedding models in UI
vj-presidio committed Feb 4, 2025
1 parent 2d63907 commit c399cb8
Showing 12 changed files with 253 additions and 10 deletions.
45 changes: 45 additions & 0 deletions package-lock.json

Some generated files are not rendered by default.

7 changes: 4 additions & 3 deletions package.json
@@ -228,13 +228,14 @@
"@anthropic-ai/sdk": "^0.26.0",
"@anthropic-ai/vertex-sdk": "^0.4.1",
"@google/generative-ai": "^0.18.0",
"@mistralai/mistralai": "^1.3.6",
"@modelcontextprotocol/sdk": "^1.0.1",
"@langchain/aws": "^0.1.1",
"@langchain/community": "^0.3.11",
"@langchain/core": "^0.3.17",
"@langchain/ollama": "^0.1.5",
"@langchain/openai": "^0.3.12",
"@langchain/textsplitters": "^0.1.0",
"@mistralai/mistralai": "^1.3.6",
"@modelcontextprotocol/sdk": "^1.0.1",
"@types/clone-deep": "^4.0.4",
"@types/get-folder-size": "^3.0.4",
"@types/pdf-parse": "^1.1.4",
@@ -247,8 +248,8 @@
"default-shell": "^2.2.0",
"delay": "^6.0.0",
"diff": "^5.2.0",
"faiss-node": "^0.5.1",
"execa": "^9.5.2",
"faiss-node": "^0.5.1",
"fast-deep-equal": "^3.1.3",
"firebase": "^11.2.0",
"get-folder-size": "^5.0.0",
54 changes: 54 additions & 0 deletions src/core/webview/ClineProvider.ts
@@ -197,6 +197,8 @@ export class ClineProvider implements vscode.WebviewViewProvider {
"embeddingAzureOpenAIApiInstanceName",
"embeddingAzureOpenAIApiEmbeddingsDeploymentName",
"embeddingAzureOpenAIApiVersion",
"embeddingOllamaBaseUrl",
"embeddingOllamaModelId",
]
return customGlobalKeys.includes(key)
}
@@ -989,6 +991,13 @@
ollamaModels,
})
break
case "requestOllamaEmbeddingModels":
const ollamaEmbeddingModels = await this.getOllamaEmbeddingModels(message.text)
this.postMessageToWebview({
type: "ollamaEmbeddingModels",
ollamaEmbeddingModels,
})
break
case "requestLmStudioModels":
const lmStudioModels = await this.getLmStudioModels(message.text)
this.postMessageToWebview({
@@ -1125,6 +1134,8 @@
azureOpenAIApiInstanceName,
azureOpenAIApiEmbeddingsDeploymentName,
azureOpenAIApiVersion,
ollamaBaseUrl,
ollamaModelId,
} = message.embeddingConfiguration

// Update Global State
@@ -1139,6 +1150,8 @@
"embeddingAzureOpenAIApiEmbeddingsDeploymentName",
azureOpenAIApiEmbeddingsDeploymentName,
)
await this.customUpdateState("embeddingOllamaBaseUrl", ollamaBaseUrl)
await this.customUpdateState("embeddingOllamaModelId", ollamaModelId)
// Update Secrets
await this.customStoreSecret("embeddingAwsAccessKey", awsAccessKey)
await this.customStoreSecret("embeddingAwsSecretKey", awsSecretKey)
@@ -1365,6 +1378,41 @@

// Ollama

async getOllamaEmbeddingModels(baseUrl?: string) {
try {
if (!baseUrl) {
baseUrl = "http://localhost:11434"
}
if (!URL.canParse(baseUrl)) {
return []
}
const response = await axios.get(`${baseUrl}/api/tags`)
const modelsArray = response.data?.models?.map((model: any) => model.name) || []
const models = [...new Set<string>(modelsArray)]
// TODO: Currently the Ollama local API doesn't differentiate between embedding and chat models,
// so we only consider models whose names match the known list below. As Ollama releases new
// embedding models this list has to be updated, or we have to wait for Ollama to support this natively.
// Directly fetching the list from the public remote API isn't available either.
// https://ollama.com/search?c=embedding
const PUBLIC_KNOWN_MODELS = [
"nomic-embed-text",
"mxbai-embed-large",
"snowflake-arctic-embed",
"bge-m3",
"all-minilm",
"bge-large",
"snowflake-arctic-embed2",
"paraphrase-multilingual",
"granite-embedding",
]
return models.filter((model: string) =>
PUBLIC_KNOWN_MODELS.some((known) => model.toLowerCase().includes(known.toLowerCase())),
)
} catch (error) {
return []
}
}

async getOllamaModels(baseUrl?: string) {
try {
if (!baseUrl) {
@@ -1841,6 +1889,8 @@
azureOpenAIApiEmbeddingsDeploymentName,
azureOpenAIApiVersion,
isEmbeddingConfigurationValid,
embeddingOllamaBaseUrl,
embeddingOllamaModelId,
] = await Promise.all([
this.customGetState("apiProvider") as Promise<ApiProvider | undefined>,
this.customGetState("apiModelId") as Promise<string | undefined>,
@@ -1897,6 +1947,8 @@
this.customGetState("embeddingAzureOpenAIApiEmbeddingsDeploymentName") as Promise<string | undefined>,
this.customGetState("embeddingAzureOpenAIApiVersion") as Promise<string | undefined>,
this.customGetState("isEmbeddingConfigurationValid") as Promise<boolean | undefined>,
this.customGetState("embeddingOllamaBaseUrl") as Promise<string | undefined>,
this.customGetState("embeddingOllamaModelId") as Promise<string | undefined>,
])

let apiProvider: ApiProvider
@@ -1967,6 +2019,8 @@
azureOpenAIApiEmbeddingsDeploymentName,
azureOpenAIApiVersion,
isEmbeddingConfigurationValid,
ollamaBaseUrl: embeddingOllamaBaseUrl,
ollamaModelId: embeddingOllamaModelId,
},
lastShownAnnouncementId,
customInstructions,
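For context, the new "requestOllamaEmbeddingModels" / "ollamaEmbeddingModels" message pair (declared in src/shared/WebviewMessage.ts and src/shared/ExtensionMessage.ts below) lets the webview populate a model picker with a round trip like this sketch — the message type strings come from this commit, but the surrounding wiring is illustrative:

// Sketch: webview-side round trip for discovering Ollama embedding models.
// `acquireVsCodeApi()` is the standard handle available inside a VS Code webview.
const vscode = acquireVsCodeApi()

// Handled by the new "requestOllamaEmbeddingModels" case in ClineProvider;
// `text` carries the optional base URL.
vscode.postMessage({ type: "requestOllamaEmbeddingModels", text: "http://localhost:11434" })

// The provider answers with an "ollamaEmbeddingModels" ExtensionMessage.
window.addEventListener("message", (event) => {
	const message = event.data
	if (message.type === "ollamaEmbeddingModels") {
		console.log("Ollama embedding models:", message.ollamaEmbeddingModels ?? [])
	}
})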
6 changes: 5 additions & 1 deletion src/embedding/index.ts
@@ -4,9 +4,11 @@ import { OpenAiNativeEmbeddingHandler } from "./providers/openai-native"
import { BedrockEmbeddings } from "@langchain/aws"
import { AzureOpenAIEmbeddings, OpenAIEmbeddings } from "@langchain/openai"
import { EmbeddingConfiguration } from "../shared/embeddings"
import type { OllamaEmbeddings } from "@langchain/ollama"
import { OllamaEmbeddingHandler } from "./providers/ollama"

export interface EmbeddingHandler {
getClient(): BedrockEmbeddings | OpenAIEmbeddings | AzureOpenAIEmbeddings
getClient(): BedrockEmbeddings | OpenAIEmbeddings | AzureOpenAIEmbeddings | OllamaEmbeddings
validateAPIKey(): Promise<boolean>
}

@@ -19,6 +21,8 @@ export function buildEmbeddingHandler(configuration: EmbeddingConfiguration): EmbeddingHandler {
return new OpenAiEmbeddingHandler(options)
case "openai-native":
return new OpenAiNativeEmbeddingHandler(options)
case "ollama":
return new OllamaEmbeddingHandler(options)
default:
throw new Error(`Unsupported embedding provider: ${provider}`)
}
32 changes: 32 additions & 0 deletions src/embedding/providers/ollama.ts
@@ -0,0 +1,32 @@
import { BedrockEmbeddings } from "@langchain/aws"
import { OpenAIEmbeddings, AzureOpenAIEmbeddings } from "@langchain/openai"
import { EmbeddingHandler } from "../"
import { EmbeddingHandlerOptions } from "../../shared/embeddings"
import { OllamaEmbeddings } from "@langchain/ollama"

export class OllamaEmbeddingHandler implements EmbeddingHandler {
private options: EmbeddingHandlerOptions
private client: OllamaEmbeddings

constructor(options: EmbeddingHandlerOptions) {
this.options = options
this.client = new OllamaEmbeddings({
model: this.options.ollamaModelId,
baseUrl: this.options.ollamaBaseUrl || "http://localhost:11434",
})
}

getClient() {
return this.client
}

async validateAPIKey(): Promise<boolean> {
try {
await this.client.embedQuery("Test")
return true
} catch (error) {
console.error("Error validating Ollama credentials: ", error)
return false
}
}
}
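A minimal end-to-end sketch of the new handler, assuming a local Ollama daemon and that `ollama pull nomic-embed-text` has been run (the model name is illustrative):

// Sketch: building and exercising the Ollama embedding handler.
import { buildEmbeddingHandler } from "../"

async function demo() {
	const handler = buildEmbeddingHandler({
		provider: "ollama",
		ollamaModelId: "nomic-embed-text", // illustrative
		ollamaBaseUrl: "http://localhost:11434",
	})
	// Ollama needs no API key, so validateAPIKey() amounts to a reachability
	// check: it issues a test embedQuery against the daemon.
	if (await handler.validateAPIKey()) {
		const vector = await handler.getClient().embedQuery("hello world")
		console.log(vector.length) // embedding dimension, e.g. 768 for nomic-embed-text
	}
}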
3 changes: 2 additions & 1 deletion src/integrations/code-prep/FindFilesToEditAgent.ts
@@ -15,12 +15,13 @@ import { basename, join } from "node:path"
import { buildApiHandler } from "../../api"
import { ensureFaissPlatformDeps } from "../../utils/faiss"
import { EmbeddingConfiguration } from "../../shared/embeddings"
import { OllamaEmbeddings } from "@langchain/ollama"

export class FindFilesToEditAgent {
private srcFolder: string
private llmApiConfig: ApiConfiguration
private embeddingConfig: EmbeddingConfiguration
private embeddings: OpenAIEmbeddings | BedrockEmbeddings
private embeddings: OpenAIEmbeddings | BedrockEmbeddings | OllamaEmbeddings
private vectorStore: FaissStore
private task: string
private buildContextOptions: HaiBuildContextOptions
5 changes: 3 additions & 2 deletions src/integrations/code-prep/VectorizeCodeAgent.ts
@@ -14,12 +14,13 @@ import { createHash } from "node:crypto"
import { HaiBuildDefaults } from "../../shared/haiDefaults"
import { EmbeddingConfiguration } from "../../shared/embeddings"
import { fileExists } from "../../utils/runtime-downloader"
import { OllamaEmbeddings } from "@langchain/ollama"

export class VectorizeCodeAgent extends EventEmitter {
private srcFolder: string
private abortController = new AbortController()

private embeddings: OpenAIEmbeddings | BedrockEmbeddings
private embeddings: OpenAIEmbeddings | BedrockEmbeddings | OllamaEmbeddings
private vectorStore: FaissStore
private buildContextOptions: HaiBuildContextOptions
private contextDir: string
@@ -225,7 +226,7 @@ export class VectorizeCodeAgent extends EventEmitter {
const fileName = basename(codeFilePath)

const textSplitter = new RecursiveCharacterTextSplitter({
chunkSize: 8191,
chunkSize: this.embeddingConfig.provider !== "ollama" ? 8191 : 512,
chunkOverlap: 0,
})
const texts = await textSplitter.splitText(fileContent)
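The splitter change above caps chunks at 512 characters for Ollama rather than 8,191 — plausibly because OpenAI's embedding endpoint accepts up to 8,191 tokens while several of the locally served models (all-minilm, for instance) handle only a few hundred, so smaller chunks keep inputs inside the model's context window. A minimal sketch of the splitter behavior:

// Sketch: RecursiveCharacterTextSplitter counts chunkSize in characters by default.
import { RecursiveCharacterTextSplitter } from "@langchain/textsplitters"

async function demo() {
	const splitter = new RecursiveCharacterTextSplitter({ chunkSize: 512, chunkOverlap: 0 })
	const fileContent = "export const add = (a: number, b: number) => a + b\n".repeat(100)
	const chunks = await splitter.splitText(fileContent)
	console.log(chunks.every((chunk) => chunk.length <= 512)) // true
}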
11 changes: 11 additions & 0 deletions src/integrations/code-prep/helper.ts
@@ -7,6 +7,7 @@ import { azureOpenAIApiVersion, EmbeddingConfiguration } from "../../shared/embeddings"
// @ts-ignore
import walk from "ignore-walk"
import ignore from "ignore"
import { OllamaEmbeddings } from "@langchain/ollama"

/**
* Recursively retrieves all code files from a given source folder,
@@ -84,6 +85,16 @@ export function getEmbeddings(embeddingConfig: EmbeddingConfiguration) {
})
}

case "ollama": {
if (!embeddingConfig.ollamaModelId) {
throw new Error("Ollama model ID is required")
}
return new OllamaEmbeddings({
model: embeddingConfig.ollamaModelId,
baseUrl: embeddingConfig.ollamaBaseUrl || "http://localhost:11434",
})
}

default:
throw new Error(`Unsupported embedding provider: ${embeddingConfig.provider}`)
}
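A short sketch of the new "ollama" branch of getEmbeddings in use, assuming no fields of EmbeddingConfiguration beyond the ones shown are required (the model name is illustrative):

// Sketch: omitting ollamaModelId throws "Ollama model ID is required";
// the base URL falls back to http://localhost:11434 when not set.
import { getEmbeddings } from "./helper"

async function demo() {
	const embeddings = getEmbeddings({
		provider: "ollama",
		ollamaModelId: "nomic-embed-text", // illustrative
	})
	const vectors = await embeddings.embedDocuments(["first chunk", "second chunk"])
	console.log(vectors.length) // 2 — one vector per input chunk
}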
2 changes: 2 additions & 0 deletions src/shared/ExtensionMessage.ts
@@ -33,6 +33,7 @@ export interface ExtensionMessage {
| "llmConfigValidation"
| "embeddingConfigValidation"
| "existingFiles"
| "ollamaEmbeddingModels"
text?: string
bool?: boolean
action?:
@@ -49,6 +50,7 @@
state?: ExtensionState
images?: string[]
ollamaModels?: string[]
ollamaEmbeddingModels?: string[]
lmStudioModels?: string[]
vsCodeLmModels?: { vendor?: string; family?: string; version?: string; id?: string }[]
filePaths?: string[]
1 change: 1 addition & 0 deletions src/shared/WebviewMessage.ts
@@ -53,6 +53,7 @@ export interface WebviewMessage {
| "uploadInstruction"
| "deleteInstruction"
| "fileInstructions"
| "requestOllamaEmbeddingModels"
text?: string
disabled?: boolean
askResponse?: ClineAskResponse
8 changes: 7 additions & 1 deletion src/shared/embeddings.ts
@@ -1,4 +1,4 @@
export type EmbeddingProvider = "bedrock" | "openai-native" | "openai"
export type EmbeddingProvider = "bedrock" | "openai-native" | "openai" | "ollama"

export interface EmbeddingHandlerOptions {
modelId?: string
@@ -17,6 +17,8 @@ export interface EmbeddingHandlerOptions {
azureOpenAIApiEmbeddingsDeploymentName?: string
azureOpenAIApiVersion?: string
maxRetries?: number
ollamaBaseUrl?: string
ollamaModelId?: string
}

export type EmbeddingConfiguration = EmbeddingHandlerOptions & {
@@ -74,6 +76,7 @@ export const embeddingProviderModels = {
bedrock: bedrockEmbeddingModels,
"openai-native": openAiNativeEmbeddingModels,
openai: {},
ollama: {},
} as const

export const defaultEmbeddingConfigs: Record<EmbeddingProvider, { defaultModel: string }> = {
@@ -86,6 +89,9 @@ export const defaultEmbeddingConfigs: Record<EmbeddingProvider, { defaultModel: string }> = {
openai: {
defaultModel: "",
},
ollama: {
defaultModel: "",
},
}

export const azureOpenAIApiVersion = "2023-05-15"
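Note that ollama, like openai, contributes an empty model map and an empty defaultModel: there is no fixed catalog, because the available models are discovered at runtime from the local daemon via getOllamaEmbeddingModels. A sketch of the Ollama-specific fields (values illustrative; ollamaBaseUrl may be omitted to use the default):

// Sketch: the Ollama-specific slice of an EmbeddingConfiguration.
import type { EmbeddingConfiguration } from "./embeddings"

const config = {
	provider: "ollama",
	ollamaBaseUrl: "http://localhost:11434",
	ollamaModelId: "mxbai-embed-large", // picked from the models reported by /api/tags
} satisfies Partial<EmbeddingConfiguration>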