fix: support for chat.models override base configuration (#60)
zhengxs2018 authored Jun 19, 2024
1 parent 726e842 commit e2070d4
Showing 5 changed files with 69 additions and 43 deletions.
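Note: the point of this change is that each entry in the chat.models setting can now carry its own connection and sampling options, and those per-model values take precedence over the provider-level base configuration. A hypothetical entry is sketched below; title, provider, and model come from IChatModelResource, while the remaining field names are inferred from the options that _newLLM destructures in this diff, not from a documented schema.

// Hypothetical chat.models entry (shape inferred from this commit).
const chatModels = [
  {
    title: 'Qwen (creative)',
    provider: 'tongyi',
    model: 'qwen-turbo',
    temperature: 0.9,          // overrides the provider default
    apiKey: '<per-model-key>', // overrides tongyi.apiKey from base configuration
  },
];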
23 changes: 17 additions & 6 deletions src/base/common/language-models/providers/TongyiProvider.ts
@@ -79,14 +79,25 @@ export class TongyiLanguageModelProvider implements ILanguageModelProvider {
   private _newLLM(options: { [name: string]: any }) {
     const config = this.configService;

+    const {
+      model,
+      temperature,
+      maxTokens,
+      topP,
+      apiKey = config.get<string>('tongyi.apiKey'),
+      enableSearch = config.get('tongyi.enableSearch', true),
+      clientOptions = {},
+    } = options;
+
     return new ChatAlibabaTongyi({
-      alibabaApiKey: config.get<string>('tongyi.apiKey'),
+      alibabaApiKey: apiKey,
       streaming: true,
-      model: this._resolveChatModel(options.model),
-      temperature: options.temperature,
-      maxTokens: options.maxTokens,
-      topP: options.topP,
-      enableSearch: config.get('tongyi.enableSearch'),
+      model: this._resolveChatModel(model),
+      temperature: temperature,
+      maxTokens: maxTokens,
+      topP: topP,
+      enableSearch: enableSearch,
+      ...clientOptions,
     });
   }

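The mechanism above is ordinary destructuring with defaults: a key present on options wins, otherwise the value from base configuration is used. A minimal self-contained sketch of the pattern follows (the Map-backed get() stands in for the extension's real config service):

// Override-with-fallback via destructuring defaults (illustrative only).
const baseConfig = new Map<string, any>([
  ['tongyi.apiKey', 'base-key'],
  ['tongyi.enableSearch', true],
]);

function get<T>(key: string, fallback?: T): T {
  return (baseConfig.has(key) ? baseConfig.get(key) : fallback) as T;
}

function resolveOptions(options: { [name: string]: any }) {
  const {
    apiKey = get<string>('tongyi.apiKey'),
    enableSearch = get('tongyi.enableSearch', true),
  } = options;
  return { apiKey, enableSearch };
}

console.log(resolveOptions({}));                      // { apiKey: 'base-key', enableSearch: true }
console.log(resolveOptions({ apiKey: 'model-key' })); // { apiKey: 'model-key', enableSearch: true }

One caveat: destructuring defaults only apply when a property is undefined, so an entry that explicitly sets apiKey to null or '' is passed through as-is rather than falling back to tongyi.apiKey.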
23 changes: 17 additions & 6 deletions src/base/common/language-models/providers/WenxinProvider.ts
@@ -79,14 +79,25 @@ export class WenxinLanguageModelProvider implements ILanguageModelProvider {
   private _newLLM(options: { [name: string]: any }) {
     const config = this.configService;

+    const {
+      model,
+      temperature,
+      penaltyScore,
+      topP,
+      apiKey = config.get<string>('qianfan.apiKey'),
+      secretKey = config.get('qianfan.secretKey'),
+      clientOptions = {},
+    } = options;
+
     return new ChatBaiduWenxin({
-      baiduApiKey: config.get<string>('qianfan.apiKey'),
-      baiduSecretKey: config.get<string>('qianfan.secretKey'),
+      baiduApiKey: apiKey,
+      baiduSecretKey: secretKey,
       streaming: true,
-      model: this._resolveChatModel(options.model),
-      temperature: options.temperature,
-      topP: options.topP,
-      penaltyScore: options.penaltyScore,
+      model: this._resolveChatModel(model),
+      temperature: temperature,
+      topP: topP,
+      penaltyScore: penaltyScore,
+      ...clientOptions,
     });
   }

22 changes: 16 additions & 6 deletions src/base/common/language-models/providers/anthropicProvider.ts
@@ -78,15 +78,25 @@ export class AnthropicLanguageModelProvider implements ILanguageModelProvider {

   private _newLLM(options: { [name: string]: any }) {
     const config = this.configService;
+    const {
+      model,
+      baseURL = config.get('anthropic.baseURL'),
+      apiKey = config.get('anthropic.apiKey'),
+      temperature,
+      maxTokens,
+      topP,
+      clientOptions = {}
+    } = options;

     return new ChatAnthropic({
-      anthropicApiKey: config.get<string>('anthropic.apiKey'),
-      anthropicApiUrl: config.get<string>('anthropic.baseURL'),
+      anthropicApiKey: apiKey,
+      anthropicApiUrl: baseURL,
       streaming: true,
-      model: this._resolveChatModel(options.model),
-      temperature: options.temperature,
-      maxTokens: options.maxTokens,
-      topP: options.topP,
+      temperature,
+      maxTokens,
+      topP,
+      model: this._resolveChatModel(model),
+      ...clientOptions,
     });
   }

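Two details in this hunk: temperature, maxTokens, and topP use shorthand property syntax (temperature is equivalent to temperature: temperature), and ...clientOptions is spread last, so keys inside clientOptions win over everything resolved before them, including the already-resolved model. In plain terms:

// Later spreads overwrite earlier keys (plain JS/TS object-spread semantics).
const params = {
  model: 'resolved-model',
  temperature: 0.2,
  ...{ model: 'raw-model', temperature: 0.7 }, // stands in for clientOptions
};
console.log(params); // { model: 'raw-model', temperature: 0.7 }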
27 changes: 16 additions & 11 deletions src/base/common/language-models/providers/openaiProvider.ts
@@ -178,24 +178,29 @@ export class OpenAILanguageModelProvider implements ILanguageModelProvider {
   }

   private _newLLM(options: { [name: string]: any }) {
-    const { baseURL, apiKey, project, organization } = options;
-    const configService = this.configService;
-
-    const apiType = configService.get<'openai' | 'azure'>('openai.apiType');
+    const config = this.configService;
+    const {
+      baseURL = config.get('openai.baseURL'),
+      apiKey = config.get('openai.apiKey'),
+      project = config.get('openai.project'),
+      organization = config.get('openai.organization'),
+      deployment = organization,
+      apiType = config.get<'openai' | 'azure'>('openai.apiType'),
+    } = options;

     if (apiType === 'azure') {
       return new AzureOpenAI({
-        baseURL: baseURL || configService.get('openai.baseURL'),
-        apiKey: apiKey || configService.get('openai.apiKey'),
-        deployment: organization || configService.get('openai.organization'),
+        baseURL: baseURL,
+        apiKey: apiKey,
+        deployment: deployment,
       });
     }

     return new OpenAI({
-      baseURL: configService.get('openai.baseURL'),
-      project: project || configService.get('openai.project'),
-      apiKey: apiKey || configService.get('openai.apiKey'),
-      organization: organization || configService.get('openai.organization'),
+      baseURL: baseURL,
+      project: project,
+      apiKey: apiKey,
+      organization: organization,
    });
   }

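Note the deployment = organization default: a later destructuring pattern may reference an earlier binding, so the Azure deployment falls back to the (possibly config-defaulted) organization value, mirroring the old organization || configService.get('openai.organization') behavior while also accepting an explicit deployment key. A small sketch:

// deployment falls back to organization unless given explicitly (illustrative).
function resolveAzure(options: { organization?: string; deployment?: string }) {
  const { organization = 'org-from-config', deployment = organization } = options;
  return deployment;
}

console.log(resolveAzure({}));                              // 'org-from-config'
console.log(resolveAzure({ organization: 'my-org' }));      // 'my-org'
console.log(resolveAzure({ deployment: 'my-deployment' })); // 'my-deployment'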
17 changes: 3 additions & 14 deletions src/editor/views/chat/continue/continueViewProvider.ts
@@ -235,17 +235,7 @@ export class ContinueViewProvider extends AbstractWebviewViewProvider implements

   private async listModels(): Promise<IChatModelResource[]> {
     const resources = this.configService.getConfig<IChatModelResource[]>('chat.models');
-    if (!resources) {
-      return [];
-    }
-
-    return resources.map(res => {
-      return {
-        title: res.title,
-        provider: res.provider,
-        model: res.model,
-      };
-    });
+    return resources ?? [];
   }

   readonly slashCommandsMap: Record<string, string> = {
@@ -324,16 +314,15 @@ export class ContinueViewProvider extends AbstractWebviewViewProvider implements
     const models = await this.listModels();

     const title = event.data.title;
-    const metadata = models.find(m => m.title === title);
+    const resource = models.find(m => m.title === title);

     const completionOptions = event.data.completionOptions;

     await this.lm.chat(
       mapToChatMessages(event.data.messages),
       {
+        ...resource,
         ...completionOptions,
-        provider: metadata?.provider,
-        model: metadata?.model,
       },
       {
         report(fragment) {
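This hunk is the other half of the fix: listModels previously stripped each resource down to { title, provider, model }, so any extra keys on a chat.models entry were dropped before reaching the provider. Now the full resource is returned and spread first, with the per-request completionOptions spread after it so they take precedence. Roughly, with illustrative data:

// Options passed to lm.chat after this change (illustrative data).
const resource = { title: 'GPT-4o', provider: 'openai', model: 'gpt-4o', temperature: 0.3 };
const completionOptions = { temperature: 0.7, maxTokens: 1024 };

const chatOptions = { ...resource, ...completionOptions };
// => { title: 'GPT-4o', provider: 'openai', model: 'gpt-4o',
//      temperature: 0.7, maxTokens: 1024 }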
