Skip to content

Commit

Permalink
Allow setting default LLM args on instance level
Browse files Browse the repository at this point in the history
Closes #20
  • Loading branch information
yamalight committed Oct 30, 2024
1 parent d76c46b commit 686c604
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 12 deletions.
8 changes: 4 additions & 4 deletions packages/litlytics/engine/runPrompt.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import type { CoreMessage, CoreTool } from 'ai';
import type { CoreMessage } from 'ai';
import type { LLMProviders } from '../litlytics';
import { executeOnLLM } from '../llm/llm';
import type { LLMModel, LLMProvider, LLMRequest } from '../llm/types';
import type { LLMArgs, LLMModel, LLMProvider, LLMRequest } from '../llm/types';

export interface RunPromptFromMessagesArgs {
provider: LLMProviders;
key: string;
model: LLMModel;
messages: CoreMessage[];
args?: Record<string, CoreTool>;
args?: LLMArgs;
}
export const runPromptFromMessages = async ({
provider,
Expand Down Expand Up @@ -36,7 +36,7 @@ export interface RunPromptArgs {
model: LLMModel;
system: string;
user: string;
args?: Record<string, CoreTool>;
args?: LLMArgs;
}
export const runPrompt = async ({
provider,
Expand Down
17 changes: 14 additions & 3 deletions packages/litlytics/litlytics.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
import { runStep, type RunStepArgs } from './engine/runStep';
import { runLLMStep, type RunLLMStepArgs } from './engine/step/runLLMStep';
import { testPipelineStep } from './engine/testStep';
import type { LLMModel, LLMProvider } from './llm/types';
import type { LLMArgs, LLMModel, LLMProvider } from './llm/types';
import { OUTPUT_ID } from './output/Output';
import {
pipelineFromText,
Expand Down Expand Up @@ -38,6 +38,7 @@ export { modelCosts } from './llm/costs';
export {
LLMModelsList,
LLMProvidersList,
type LLMArgs,
type LLMModel,
type LLMProvider,
} from './llm/types';
Expand Down Expand Up @@ -70,6 +71,7 @@ export class LitLytics {
// model config
provider?: LLMProviders;
model?: LLMModel;
llmArgs?: LLMArgs;
#llmKey?: string;

// pipeline
Expand All @@ -82,14 +84,17 @@ export class LitLytics {
provider,
model,
key,
llmArgs,
}: {
provider: LLMProviders;
model: LLMModel;
key: string;
llmArgs?: LLMArgs;
}) {
this.provider = provider;
this.model = model;
this.#llmKey = key;
this.llmArgs = llmArgs;
}

/**
Expand Down Expand Up @@ -179,7 +184,10 @@ export class LitLytics {
key: this.#llmKey,
model: this.model,
messages,
args,
args: {
...args,
...this.llmArgs,
},
});
};

Expand All @@ -202,7 +210,10 @@ export class LitLytics {
model: this.model,
system,
user,
args,
args: {
...args,
...this.llmArgs,
},
});
};

Expand Down
6 changes: 4 additions & 2 deletions packages/litlytics/llm/types.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import type { CoreMessage, CoreTool } from 'ai';
import type { CoreMessage, generateText } from 'ai';

export const LLMProvidersList = [
'openai',
Expand Down Expand Up @@ -40,10 +40,12 @@ export const LLMModelsList = {
export type LLMModel =
(typeof LLMModelsList)[keyof typeof LLMModelsList][number];

export type LLMArgs = Partial<Parameters<typeof generateText>[0]>;

export interface LLMRequest {
provider: LLMProvider;
key: string;
model: LLMModel;
messages: CoreMessage[];
modelArgs?: Record<string, CoreTool>;
modelArgs?: LLMArgs;
}
5 changes: 3 additions & 2 deletions packages/litlytics/step/Step.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import type { CoreTool, LanguageModelUsage } from 'ai';
import type { LanguageModelUsage } from 'ai';
import type { Doc } from '../doc/Document';
import type { LLMArgs } from '../llm/types';

export interface StepResult {
stepId: string;
Expand Down Expand Up @@ -46,7 +47,7 @@ export interface ProcessingStep extends BaseStep {
input?: StepInput;
// llm
prompt?: string;
llmArgs?: Record<string, CoreTool>;
llmArgs?: LLMArgs;
// code
code?: string;
codeExplanation?: string;
Expand Down
55 changes: 54 additions & 1 deletion packages/litlytics/test/litlytics.test.ts
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import type { LanguageModelUsage } from 'ai';
import { expect, test } from 'vitest';
import { expect, test, vi } from 'vitest';
import * as run from '../engine/runPrompt';
import {
LitLytics,
OUTPUT_ID,
type Doc,
type LLMArgs,
type Pipeline,
type StepInput,
} from '../litlytics';
Expand Down Expand Up @@ -229,3 +231,54 @@ test('should generate suggested tasks for current pipeline', async () => {
const newNonTestDoc = litlytics.docs.find((d) => d.id === docNonTest.id);
expect(newNonTestDoc?.summary).toBeUndefined();
});

test('should pass llm args when running prompt', async () => {
  // Instance-level LLM defaults that should be forwarded to every prompt call.
  const testArgs: LLMArgs = {
    temperature: 0.5,
    maxTokens: 1000,
  };
  const litlytics = new LitLytics({
    provider: 'openai',
    model: 'test',
    key: 'test',
    llmArgs: testArgs,
  });
  litlytics.pipeline.pipelineDescription = 'test description';

  const testResult = `Step name: Generate Title and Description
Step type: llm
Step input: doc
Step description: Generate an Etsy product title and description based on the provided document describing the product.
---
Step name: Check for Copyrighted Terms
Step type: llm
Step input: result
Step description: Analyze the generated title and description for possible copyrighted terms and suggest edits.
`;

  // Mock prompt replies; assertions inside the mock verify that the
  // instance-level llmArgs (and the pipeline description) reach runPrompt.
  const spy = vi
    .spyOn(run, 'runPrompt')
    .mockImplementation(
      async ({
        user,
        args,
      }: {
        system: string;
        user: string;
        args?: LLMArgs;
      }) => {
        expect(args).toEqual(testArgs);
        expect(user).toEqual('test description');
        return { result: testResult, usage: {} as LanguageModelUsage };
      }
    );
  // run generation
  await litlytics.generatePipeline();
  // check that spy was called
  expect(spy).toHaveBeenCalled();
  // cleanup: restore the original runPrompt implementation. mockClear()
  // would only reset call history and leave the mock installed, leaking
  // the stubbed implementation into tests that run afterwards.
  spy.mockRestore();
});

0 comments on commit 686c604

Please sign in to comment.