From 2ef694dfc87d316bb8d72be85ec1910ccf379258 Mon Sep 17 00:00:00 2001 From: Sebastian Wiendlocha Date: Thu, 31 Aug 2023 03:37:12 +0200 Subject: [PATCH] Support all Ollama model options (#2448) * support all ollama api/generate params * changed casing of options to be more consistent * fixed inconsistent case * add missing model options to ollama chat model --- langchain/src/chat_models/ollama.ts | 68 ++++++++++++++++++++++++++++ langchain/src/llms/ollama.ts | 69 +++++++++++++++++++++++++++++ langchain/src/util/ollama.ts | 34 ++++++++++++++ 3 files changed, 171 insertions(+) diff --git a/langchain/src/chat_models/ollama.ts b/langchain/src/chat_models/ollama.ts index 530e70222428..797d02882f6b 100644 --- a/langchain/src/chat_models/ollama.ts +++ b/langchain/src/chat_models/ollama.ts @@ -33,22 +33,48 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput { baseUrl = "http://localhost:11434"; + embeddingOnly?: boolean; + + f16KV?: boolean; + + frequencyPenalty?: number; + + logitsAll?: boolean; + + lowVram?: boolean; + + mainGpu?: number; + mirostat?: number; mirostatEta?: number; mirostatTau?: number; + numBatch?: number; + numCtx?: number; numGpu?: number; + numGqa?: number; + + numKeep?: number; + numThread?: number; + penalizeNewline?: boolean; + + presencePenalty?: number; + repeatLastN?: number; repeatPenalty?: number; + ropeFrequencyBase?: number; + + ropeFrequencyScale?: number; + temperature?: number; stop?: string[]; @@ -59,25 +85,50 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput { topP?: number; + typicalP?: number; + + useMLock?: boolean; + + useMMap?: boolean; + + vocabOnly?: boolean; + constructor(fields: OllamaInput & BaseChatModelParams) { super(fields); this.model = fields.model ?? this.model; this.baseUrl = fields.baseUrl?.endsWith("/") ? fields.baseUrl.slice(0, -1) : fields.baseUrl ?? 
this.baseUrl; + this.embeddingOnly = fields.embeddingOnly; + this.f16KV = fields.f16KV; + this.frequencyPenalty = fields.frequencyPenalty; + this.logitsAll = fields.logitsAll; + this.lowVram = fields.lowVram; + this.mainGpu = fields.mainGpu; this.mirostat = fields.mirostat; this.mirostatEta = fields.mirostatEta; this.mirostatTau = fields.mirostatTau; + this.numBatch = fields.numBatch; this.numCtx = fields.numCtx; this.numGpu = fields.numGpu; + this.numGqa = fields.numGqa; + this.numKeep = fields.numKeep; this.numThread = fields.numThread; + this.penalizeNewline = fields.penalizeNewline; + this.presencePenalty = fields.presencePenalty; this.repeatLastN = fields.repeatLastN; this.repeatPenalty = fields.repeatPenalty; + this.ropeFrequencyBase = fields.ropeFrequencyBase; + this.ropeFrequencyScale = fields.ropeFrequencyScale; this.temperature = fields.temperature; this.stop = fields.stop; this.tfsZ = fields.tfsZ; this.topK = fields.topK; this.topP = fields.topP; + this.typicalP = fields.typicalP; + this.useMLock = fields.useMLock; + this.useMMap = fields.useMMap; + this.vocabOnly = fields.vocabOnly; } _llmType() { @@ -94,19 +145,36 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput { return { model: this.model, options: { + embedding_only: this.embeddingOnly, + f16_kv: this.f16KV, + frequency_penalty: this.frequencyPenalty, + logits_all: this.logitsAll, + low_vram: this.lowVram, + main_gpu: this.mainGpu, mirostat: this.mirostat, mirostat_eta: this.mirostatEta, mirostat_tau: this.mirostatTau, + num_batch: this.numBatch, num_ctx: this.numCtx, num_gpu: this.numGpu, + num_gqa: this.numGqa, + num_keep: this.numKeep, num_thread: this.numThread, + penalize_newline: this.penalizeNewline, + presence_penalty: this.presencePenalty, repeat_last_n: this.repeatLastN, repeat_penalty: this.repeatPenalty, + rope_frequency_base: this.ropeFrequencyBase, + rope_frequency_scale: this.ropeFrequencyScale, temperature: this.temperature, stop: options?.stop ?? 
this.stop, tfs_z: this.tfsZ, top_k: this.topK, top_p: this.topP, + typical_p: this.typicalP, + use_mlock: this.useMLock, + use_mmap: this.useMMap, + vocab_only: this.vocabOnly, }, }; } diff --git a/langchain/src/llms/ollama.ts b/langchain/src/llms/ollama.ts index f928993a9648..bbc28a0b6dde 100644 --- a/langchain/src/llms/ollama.ts +++ b/langchain/src/llms/ollama.ts @@ -24,22 +24,48 @@ export class Ollama extends LLM implements OllamaInput { baseUrl = "http://localhost:11434"; + embeddingOnly?: boolean; + + f16KV?: boolean; + + frequencyPenalty?: number; + + logitsAll?: boolean; + + lowVram?: boolean; + + mainGpu?: number; + mirostat?: number; mirostatEta?: number; mirostatTau?: number; + numBatch?: number; + numCtx?: number; numGpu?: number; + numGqa?: number; + + numKeep?: number; + numThread?: number; + penalizeNewline?: boolean; + + presencePenalty?: number; + repeatLastN?: number; repeatPenalty?: number; + ropeFrequencyBase?: number; + + ropeFrequencyScale?: number; + temperature?: number; stop?: string[]; @@ -50,25 +76,51 @@ export class Ollama extends LLM implements OllamaInput { topP?: number; + typicalP?: number; + + useMLock?: boolean; + + useMMap?: boolean; + + vocabOnly?: boolean; + constructor(fields: OllamaInput & BaseLLMParams) { super(fields); this.model = fields.model ?? this.model; this.baseUrl = fields.baseUrl?.endsWith("/") ? fields.baseUrl.slice(0, -1) : fields.baseUrl ?? 
this.baseUrl; + + this.embeddingOnly = fields.embeddingOnly; + this.f16KV = fields.f16KV; + this.frequencyPenalty = fields.frequencyPenalty; + this.logitsAll = fields.logitsAll; + this.lowVram = fields.lowVram; + this.mainGpu = fields.mainGpu; this.mirostat = fields.mirostat; this.mirostatEta = fields.mirostatEta; this.mirostatTau = fields.mirostatTau; + this.numBatch = fields.numBatch; this.numCtx = fields.numCtx; this.numGpu = fields.numGpu; + this.numGqa = fields.numGqa; + this.numKeep = fields.numKeep; this.numThread = fields.numThread; + this.penalizeNewline = fields.penalizeNewline; + this.presencePenalty = fields.presencePenalty; this.repeatLastN = fields.repeatLastN; this.repeatPenalty = fields.repeatPenalty; + this.ropeFrequencyBase = fields.ropeFrequencyBase; + this.ropeFrequencyScale = fields.ropeFrequencyScale; this.temperature = fields.temperature; this.stop = fields.stop; this.tfsZ = fields.tfsZ; this.topK = fields.topK; this.topP = fields.topP; + this.typicalP = fields.typicalP; + this.useMLock = fields.useMLock; + this.useMMap = fields.useMMap; + this.vocabOnly = fields.vocabOnly; } _llmType() { @@ -79,19 +131,36 @@ export class Ollama extends LLM implements OllamaInput { return { model: this.model, options: { + embedding_only: this.embeddingOnly, + f16_kv: this.f16KV, + frequency_penalty: this.frequencyPenalty, + logits_all: this.logitsAll, + low_vram: this.lowVram, + main_gpu: this.mainGpu, mirostat: this.mirostat, mirostat_eta: this.mirostatEta, mirostat_tau: this.mirostatTau, + num_batch: this.numBatch, num_ctx: this.numCtx, num_gpu: this.numGpu, + num_gqa: this.numGqa, + num_keep: this.numKeep, num_thread: this.numThread, + penalize_newline: this.penalizeNewline, + presence_penalty: this.presencePenalty, repeat_last_n: this.repeatLastN, repeat_penalty: this.repeatPenalty, + rope_frequency_base: this.ropeFrequencyBase, + rope_frequency_scale: this.ropeFrequencyScale, temperature: this.temperature, stop: options?.stop ?? 
this.stop, tfs_z: this.tfsZ, top_k: this.topK, top_p: this.topP, + typical_p: this.typicalP, + use_mlock: this.useMLock, + use_mmap: this.useMMap, + vocab_only: this.vocabOnly, }, }; } diff --git a/langchain/src/util/ollama.ts b/langchain/src/util/ollama.ts index 620c8f2fca3f..bce0b830317f 100644 --- a/langchain/src/util/ollama.ts +++ b/langchain/src/util/ollama.ts @@ -2,40 +2,74 @@ import { BaseLanguageModelCallOptions } from "../base_language/index.js"; import { IterableReadableStream } from "./stream.js"; export interface OllamaInput { + embeddingOnly?: boolean; + f16KV?: boolean; + frequencyPenalty?: number; + logitsAll?: boolean; + lowVram?: boolean; + mainGpu?: number; model?: string; baseUrl?: string; mirostat?: number; mirostatEta?: number; mirostatTau?: number; + numBatch?: number; numCtx?: number; numGpu?: number; + numGqa?: number; + numKeep?: number; numThread?: number; + penalizeNewline?: boolean; + presencePenalty?: number; repeatLastN?: number; repeatPenalty?: number; + ropeFrequencyBase?: number; + ropeFrequencyScale?: number; temperature?: number; stop?: string[]; tfsZ?: number; topK?: number; topP?: number; + typicalP?: number; + useMLock?: boolean; + useMMap?: boolean; + vocabOnly?: boolean; } export interface OllamaRequestParams { model: string; prompt: string; options: { + embedding_only?: boolean; + f16_kv?: boolean; + frequency_penalty?: number; + logits_all?: boolean; + low_vram?: boolean; + main_gpu?: number; mirostat?: number; mirostat_eta?: number; mirostat_tau?: number; + num_batch?: number; num_ctx?: number; num_gpu?: number; + num_gqa?: number; + num_keep?: number; num_thread?: number; + penalize_newline?: boolean; + presence_penalty?: number; repeat_last_n?: number; repeat_penalty?: number; + rope_frequency_base?: number; + rope_frequency_scale?: number; temperature?: number; stop?: string[]; tfs_z?: number; top_k?: number; top_p?: number; + typical_p?: number; + use_mlock?: boolean; + use_mmap?: boolean; + vocab_only?: boolean; }; }