Skip to content

Commit

Permalink
Support all Ollama model options (#2448)
Browse files Browse the repository at this point in the history
* support all ollama api/generate params

* changed casing of options to be more consistent

* fixed inconsistent case

* add missing model options to ollama chat model
  • Loading branch information
Basti-an authored Aug 31, 2023
1 parent 3a6aa88 commit 2ef694d
Show file tree
Hide file tree
Showing 3 changed files with 171 additions and 0 deletions.
68 changes: 68 additions & 0 deletions langchain/src/chat_models/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -33,22 +33,48 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput {

// Endpoint of the Ollama server. A single trailing "/" is stripped by the
// constructor, so the stored value never ends with a slash.
baseUrl = "http://localhost:11434";

// --- Ollama model options -------------------------------------------------
// Each optional field below is forwarded to the Ollama API under the
// snake_case form of its name (e.g. `numCtx` -> `num_ctx`) and stays
// `undefined` when not supplied. NOTE(review): the semantics of each option
// are defined by Ollama's parameter documentation — confirm there before
// relying on a particular meaning.

embeddingOnly?: boolean;

f16KV?: boolean;

frequencyPenalty?: number;

logitsAll?: boolean;

lowVram?: boolean;

mainGpu?: number;

mirostat?: number;

mirostatEta?: number;

mirostatTau?: number;

numBatch?: number;

numCtx?: number;

numGpu?: number;

numGqa?: number;

numKeep?: number;

numThread?: number;

penalizeNewline?: boolean;

presencePenalty?: number;

repeatLastN?: number;

repeatPenalty?: number;

ropeFrequencyBase?: number;

ropeFrequencyScale?: number;

temperature?: number;

// Default stop sequences; a per-call `options.stop` takes precedence when
// the request parameters are built.
stop?: string[];
Expand All @@ -59,25 +85,50 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput {

topP?: number;

typicalP?: number;

// Sent to the API as `use_mlock` / `use_mmap` respectively.
useMLock?: boolean;

useMMap?: boolean;

vocabOnly?: boolean;

/**
 * Creates the chat-model wrapper.
 *
 * Copies the model name, the normalized server base URL, and every
 * supported Ollama model option from `fields` onto the instance. Options
 * absent from `fields` are stored as `undefined`.
 */
constructor(fields: OllamaInput & BaseChatModelParams) {
  super(fields);
  this.model = fields.model ?? this.model;
  // Canonicalize the endpoint by dropping a single trailing slash.
  const endpoint = fields.baseUrl ?? this.baseUrl;
  this.baseUrl = endpoint.endsWith("/") ? endpoint.slice(0, -1) : endpoint;
  // Bulk-copy the per-request model options; semantically identical to one
  // assignment statement per field (undefined values are copied as-is).
  Object.assign(this, {
    embeddingOnly: fields.embeddingOnly,
    f16KV: fields.f16KV,
    frequencyPenalty: fields.frequencyPenalty,
    logitsAll: fields.logitsAll,
    lowVram: fields.lowVram,
    mainGpu: fields.mainGpu,
    mirostat: fields.mirostat,
    mirostatEta: fields.mirostatEta,
    mirostatTau: fields.mirostatTau,
    numBatch: fields.numBatch,
    numCtx: fields.numCtx,
    numGpu: fields.numGpu,
    numGqa: fields.numGqa,
    numKeep: fields.numKeep,
    numThread: fields.numThread,
    penalizeNewline: fields.penalizeNewline,
    presencePenalty: fields.presencePenalty,
    repeatLastN: fields.repeatLastN,
    repeatPenalty: fields.repeatPenalty,
    ropeFrequencyBase: fields.ropeFrequencyBase,
    ropeFrequencyScale: fields.ropeFrequencyScale,
    temperature: fields.temperature,
    stop: fields.stop,
    tfsZ: fields.tfsZ,
    topK: fields.topK,
    topP: fields.topP,
    typicalP: fields.typicalP,
    useMLock: fields.useMLock,
    useMMap: fields.useMMap,
    vocabOnly: fields.vocabOnly,
  });
}

_llmType() {
Expand All @@ -94,19 +145,36 @@ export class ChatOllama extends SimpleChatModel implements OllamaInput {
return {
model: this.model,
options: {
embedding_only: this.embeddingOnly,
f16_kv: this.f16KV,
frequency_penalty: this.frequencyPenalty,
logits_all: this.logitsAll,
low_vram: this.lowVram,
main_gpu: this.mainGpu,
mirostat: this.mirostat,
mirostat_eta: this.mirostatEta,
mirostat_tau: this.mirostatTau,
num_batch: this.numBatch,
num_ctx: this.numCtx,
num_gpu: this.numGpu,
num_gqa: this.numGqa,
num_keep: this.numKeep,
num_thread: this.numThread,
penalize_newline: this.penalizeNewline,
presence_penalty: this.presencePenalty,
repeat_last_n: this.repeatLastN,
repeat_penalty: this.repeatPenalty,
rope_frequency_base: this.ropeFrequencyBase,
rope_frequency_scale: this.ropeFrequencyScale,
temperature: this.temperature,
stop: options?.stop ?? this.stop,
tfs_z: this.tfsZ,
top_k: this.topK,
top_p: this.topP,
typical_p: this.typicalP,
use_mlock: this.useMLock,
use_mmap: this.useMMap,
vocab_only: this.vocabOnly,
},
};
}
Expand Down
69 changes: 69 additions & 0 deletions langchain/src/llms/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,48 @@ export class Ollama extends LLM implements OllamaInput {

// Endpoint of the Ollama server. A single trailing "/" is removed by the
// constructor, so the stored value never ends with a slash.
baseUrl = "http://localhost:11434";

// --- Ollama model options -------------------------------------------------
// Every optional field below maps one-to-one to an Ollama API option: it is
// sent under the snake_case form of its name (e.g. `repeatPenalty` ->
// `repeat_penalty`) and stays `undefined` when not supplied.
// NOTE(review): option semantics come from Ollama's parameter docs —
// confirm there before relying on a particular meaning.

embeddingOnly?: boolean;

f16KV?: boolean;

frequencyPenalty?: number;

logitsAll?: boolean;

lowVram?: boolean;

mainGpu?: number;

mirostat?: number;

mirostatEta?: number;

mirostatTau?: number;

numBatch?: number;

numCtx?: number;

numGpu?: number;

numGqa?: number;

numKeep?: number;

numThread?: number;

penalizeNewline?: boolean;

presencePenalty?: number;

repeatLastN?: number;

repeatPenalty?: number;

ropeFrequencyBase?: number;

ropeFrequencyScale?: number;

temperature?: number;

// Default stop sequences; a per-call `options.stop` takes precedence when
// the request parameters are built.
stop?: string[];
Expand All @@ -50,25 +76,51 @@ export class Ollama extends LLM implements OllamaInput {

topP?: number;

typicalP?: number;

// Sent to the API as `use_mlock` / `use_mmap` respectively.
useMLock?: boolean;

useMMap?: boolean;

vocabOnly?: boolean;

/**
 * Creates the LLM wrapper.
 *
 * Captures the model name, the normalized server base URL, and all
 * supported Ollama model options from `fields`. Options not present in
 * `fields` end up stored as `undefined`.
 */
constructor(fields: OllamaInput & BaseLLMParams) {
  super(fields);
  this.model = fields.model ?? this.model;
  // Normalize the server URL: remove at most one trailing "/".
  this.baseUrl = (fields.baseUrl ?? this.baseUrl).replace(/\/$/, "");

  // Pull the model options out of `fields`, then copy them onto the
  // instance in one step (undefined values are copied unchanged).
  const {
    embeddingOnly, f16KV, frequencyPenalty, logitsAll, lowVram, mainGpu,
    mirostat, mirostatEta, mirostatTau, numBatch, numCtx, numGpu, numGqa,
    numKeep, numThread, penalizeNewline, presencePenalty, repeatLastN,
    repeatPenalty, ropeFrequencyBase, ropeFrequencyScale, temperature,
    stop, tfsZ, topK, topP, typicalP, useMLock, useMMap, vocabOnly,
  } = fields;
  Object.assign(this, {
    embeddingOnly, f16KV, frequencyPenalty, logitsAll, lowVram, mainGpu,
    mirostat, mirostatEta, mirostatTau, numBatch, numCtx, numGpu, numGqa,
    numKeep, numThread, penalizeNewline, presencePenalty, repeatLastN,
    repeatPenalty, ropeFrequencyBase, ropeFrequencyScale, temperature,
    stop, tfsZ, topK, topP, typicalP, useMLock, useMMap, vocabOnly,
  });
}

_llmType() {
Expand All @@ -79,19 +131,36 @@ export class Ollama extends LLM implements OllamaInput {
return {
model: this.model,
options: {
embedding_only: this.embeddingOnly,
f16_kv: this.f16KV,
frequency_penalty: this.frequencyPenalty,
logits_all: this.logitsAll,
low_vram: this.lowVram,
main_gpu: this.mainGpu,
mirostat: this.mirostat,
mirostat_eta: this.mirostatEta,
mirostat_tau: this.mirostatTau,
num_batch: this.numBatch,
num_ctx: this.numCtx,
num_gpu: this.numGpu,
num_gqa: this.numGqa,
num_keep: this.numKeep,
num_thread: this.numThread,
penalize_newline: this.penalizeNewline,
presence_penalty: this.presencePenalty,
repeat_last_n: this.repeatLastN,
repeat_penalty: this.repeatPenalty,
rope_frequency_base: this.ropeFrequencyBase,
rope_frequency_scale: this.ropeFrequencyScale,
temperature: this.temperature,
stop: options?.stop ?? this.stop,
tfs_z: this.tfsZ,
top_k: this.topK,
top_p: this.topP,
typical_p: this.typicalP,
use_mlock: this.useMLock,
use_mmap: this.useMMap,
vocab_only: this.vocabOnly,
},
};
}
Expand Down
34 changes: 34 additions & 0 deletions langchain/src/util/ollama.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,40 +2,74 @@ import { BaseLanguageModelCallOptions } from "../base_language/index.js";
import { IterableReadableStream } from "./stream.js";

/**
 * Constructor options accepted by the Ollama LLM and chat-model wrappers.
 *
 * Everything except `model` and `baseUrl` is an Ollama model option; the
 * camelCase names here are translated to the API's snake_case keys (see
 * `OllamaRequestParams.options`) when a request payload is built.
 */
export interface OllamaInput {
  /** Name of the Ollama model to run. */
  model?: string;
  /** Base URL of the Ollama server, e.g. "http://localhost:11434". */
  baseUrl?: string;

  embeddingOnly?: boolean;
  f16KV?: boolean;
  frequencyPenalty?: number;
  logitsAll?: boolean;
  lowVram?: boolean;
  mainGpu?: number;
  mirostat?: number;
  mirostatEta?: number;
  mirostatTau?: number;
  numBatch?: number;
  numCtx?: number;
  numGpu?: number;
  numGqa?: number;
  numKeep?: number;
  numThread?: number;
  penalizeNewline?: boolean;
  presencePenalty?: number;
  repeatLastN?: number;
  repeatPenalty?: number;
  ropeFrequencyBase?: number;
  ropeFrequencyScale?: number;
  temperature?: number;
  stop?: string[];
  tfsZ?: number;
  topK?: number;
  topP?: number;
  typicalP?: number;
  useMLock?: boolean;
  useMMap?: boolean;
  vocabOnly?: boolean;
}

/**
 * Shape of the request payload sent to the Ollama API.
 *
 * `options` uses the API's snake_case keys; the wrapper classes build it
 * from their camelCase fields of the same names, so the two lists must stay
 * in sync. Do not rename or reorder members here without checking the
 * mapping code in the wrappers.
 */
export interface OllamaRequestParams {
  model: string;
  prompt: string;
  options: {
    embedding_only?: boolean;
    f16_kv?: boolean;
    frequency_penalty?: number;
    logits_all?: boolean;
    low_vram?: boolean;
    main_gpu?: number;
    mirostat?: number;
    mirostat_eta?: number;
    mirostat_tau?: number;
    num_batch?: number;
    num_ctx?: number;
    num_gpu?: number;
    num_gqa?: number;
    num_keep?: number;
    num_thread?: number;
    penalize_newline?: boolean;
    presence_penalty?: number;
    repeat_last_n?: number;
    repeat_penalty?: number;
    rope_frequency_base?: number;
    rope_frequency_scale?: number;
    temperature?: number;
    stop?: string[];
    tfs_z?: number;
    top_k?: number;
    top_p?: number;
    typical_p?: number;
    use_mlock?: boolean;
    use_mmap?: boolean;
    vocab_only?: boolean;
  };
}

Expand Down

1 comment on commit 2ef694d

@vercel
Copy link

@vercel vercel bot commented on 2ef694d Aug 31, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please sign in to comment.