Skip to content

Commit

Permalink
feat(embedding): implement ILanguageModelProvider in LocalEmbeddingsP…
Browse files Browse the repository at this point in the history
…rovider

The LocalEmbeddingsProvider class now implements the ILanguageModelProvider interface. This includes the addition of several methods such as provideChatResponse, provideCompletionResponse, provideEmbedDocuments, and provideEmbedQuery. Also, a log message was added to indicate the completion of LanceDB indexing.
  • Loading branch information
phodal committed Jul 5, 2024
1 parent 1d5c96f commit a7a1862
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/base/common/language-models/languageModelsService.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import { OllamaLanguageModelProvider } from './providers/ollamaProvider';
import { OpenAILanguageModelProvider } from './providers/openaiProvider';
import { TongyiLanguageModelProvider } from './providers/TongyiProvider';
import { WenxinLanguageModelProvider } from './providers/WenxinProvider';
import { LocalEmbeddingsProvider } from "../../../code-search/embedding/LocalEmbeddingsProvider";

export interface LanguageModelSelector {
/**
Expand Down Expand Up @@ -39,6 +40,7 @@ export class LanguageModelsService {
['tongyi', new TongyiLanguageModelProvider(configService)],
['ollama', new OllamaLanguageModelProvider(configService)],
['transformers', new HuggingFaceTransformersLanguageModelProvider(configService)],
// ['transformers', LocalEmbeddingsProvider.getInstance()],
]);
}

Expand Down
35 changes: 33 additions & 2 deletions src/code-search/embedding/LocalEmbeddingsProvider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ import { Embedding } from './_base/Embedding';
import { EmbeddingsProvider } from './_base/EmbeddingsProvider';
import { mean_pooling, mergedTensor, normalize_, reshape, tensorData } from './_base/EmbeddingUtils';
import { logger } from 'base/common/log/log';
import type { ILanguageModelProvider } from "base/common/language-models/languageModels";
import { CancellationToken } from 'vscode';

// @ts-expect-error
const ortPromise = import('onnxruntime-node');
Expand All @@ -20,7 +22,7 @@ const InferenceSessionCreate = (...args: any[]) => {
*
* @deprecated Please use LanguageModelsService instead.
*/
export class LocalEmbeddingsProvider implements EmbeddingsProvider {
export class LocalEmbeddingsProvider implements EmbeddingsProvider, ILanguageModelProvider {
id: string = 'local';
env: any;
tokenizer: any;
Expand All @@ -30,12 +32,25 @@ export class LocalEmbeddingsProvider implements EmbeddingsProvider {

private static instance: LocalEmbeddingsProvider;

private constructor() {}
private constructor() {
}

identifier: string = 'local';

async provideChatResponse(): Promise<never> {
throw new Error('This method is not implemented');
}

async provideCompletionResponse(): Promise<never> {
throw new Error('This method is not implemented');
}

static getInstance(): LocalEmbeddingsProvider {
if (!LocalEmbeddingsProvider.instance) {
LocalEmbeddingsProvider.instance = new LocalEmbeddingsProvider();
LocalEmbeddingsProvider.instance.init();
}

return LocalEmbeddingsProvider.instance;
}

Expand Down Expand Up @@ -64,6 +79,22 @@ export class LocalEmbeddingsProvider implements EmbeddingsProvider {
logger.appendLine("'hello' text's first 10 values" + value[0].slice(0, 10).join(', '));
}

async provideEmbedDocuments(
texts: string[],
options: { [name: string]: any },
token?: CancellationToken,
): Promise<number[][]> {
return this.embed(texts);
}

async provideEmbedQuery(
input: string | string[],
options: { [name: string]: any },
token?: CancellationToken,
): Promise<number[]> {
return (await this.embed(Array.isArray(input) ? input : [input]))[0];
}

async embed(chunks: string[]): Promise<Embedding[]> {
if (chunks.length === 0) {
return [];
Expand Down
1 change: 1 addition & 0 deletions src/code-search/indexing/LanceDbIndex.ts
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ export class LanceDbIndex implements CodebaseIndex {
}

markComplete(results.del, IndexResultType.Delete);
logger.appendLine('Completed LanceDB Indexing');
yield { progress: 1, desc: 'Completed Calculating Embeddings' };
}

Expand Down
2 changes: 2 additions & 0 deletions src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ import { LanguageModelsService } from './base/common/language-models/languageMod
import { CommandsService } from './commands/commandsService';
import { ChatViewService } from './editor/views/chat/chatViewService';

(globalThis as any).self = globalThis;

export async function activate(context: ExtensionContext) {
const instantiationService = new InstantiationService();

Expand Down

0 comments on commit a7a1862

Please sign in to comment.