🔧 Fix: Resolve Anthropic Client Issues 🧠 (#1226)
* fix: correct preset title for Anthropic endpoint

* fix(Settings/Anthropic): show correct default value for LLM temperature

* fix(AnthropicClient): use `getModelMaxTokens` to get the correct LLM max context tokens, correctly set default temperature to 1, use only 2 params for class constructor, use `getResponseSender` to add correct sender to response message

* refactor(/api/ask|edit/anthropic): save messages to database after the final response is sent to the client, and do not save conversation from route controller

* fix(initializeClient/anthropic): correctly pass client options (endpointOption) to class initialization

* feat(ModelService/Anthropic): add claude-1.2
danny-avila authored Nov 26, 2023
1 parent 4b28964 commit d7ef459
Showing 9 changed files with 73 additions and 45 deletions.
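
Taken together, the client is now constructed with just two arguments, and everything endpoint-specific (including the reverse proxy) travels inside the options object. A minimal sketch of the new wiring; the req, res, and endpointOption values here are hypothetical stand-ins for what Express and the buildEndpointOption middleware normally supply:

const { AnthropicClient } = require('~/app');

// Hypothetical stand-ins; in the app these come from the route handler.
const req = {};
const res = {};
const endpointOption = { modelOptions: { model: 'claude-2.1' }, modelLabel: 'Claude' };

// Before: new AnthropicClient(apiKey, options, cacheOptions, baseURL)
// After:  two parameters only; the reverse proxy is read from options.
const client = new AnthropicClient(process.env.ANTHROPIC_API_KEY, {
  req,
  res,
  reverseProxyUrl: process.env.ANTHROPIC_REVERSE_PROXY ?? null,
  ...endpointOption,
});

// The client now derives these itself instead of hard-coding them:
//   client.sender           -> getResponseSender({ model, endpoint: EModelEndpoint.anthropic, modelLabel })
//   client.maxContextTokens -> getModelMaxTokens(model) ?? 100000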
44 changes: 24 additions & 20 deletions api/app/clients/AnthropicClient.js
@@ -1,21 +1,18 @@
// const { Agent, ProxyAgent } = require('undici');
const BaseClient = require('./BaseClient');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const Anthropic = require('@anthropic-ai/sdk');
const { encoding_for_model: encodingForModel, get_encoding: getEncoding } = require('tiktoken');
const { getResponseSender, EModelEndpoint } = require('~/server/routes/endpoints/schemas');
const { getModelMaxTokens } = require('~/utils');
const BaseClient = require('./BaseClient');

const HUMAN_PROMPT = '\n\nHuman:';
const AI_PROMPT = '\n\nAssistant:';

const tokenizersCache = {};

class AnthropicClient extends BaseClient {
constructor(apiKey, options = {}, cacheOptions = {}, baseURL) {
super(apiKey, options, cacheOptions);
constructor(apiKey, options = {}) {
super(apiKey, options);
this.apiKey = apiKey || process.env.ANTHROPIC_API_KEY;
this.sender = 'Anthropic';
if (baseURL) {
this.baseURL = baseURL;
}
this.userLabel = HUMAN_PROMPT;
this.assistantLabel = AI_PROMPT;
this.setOptions(options);
@@ -43,13 +40,13 @@ class AnthropicClient extends BaseClient {
...modelOptions,
// set some good defaults (check for undefined in some cases because they may be 0)
model: modelOptions.model || 'claude-1',
temperature: typeof modelOptions.temperature === 'undefined' ? 0.7 : modelOptions.temperature, // 0 - 1, 0.7 is recommended
temperature: typeof modelOptions.temperature === 'undefined' ? 1 : modelOptions.temperature, // 0 - 1, 1 is default
topP: typeof modelOptions.topP === 'undefined' ? 0.7 : modelOptions.topP, // 0 - 1, default: 0.7
topK: typeof modelOptions.topK === 'undefined' ? 40 : modelOptions.topK, // 1-40, default: 40
stop: modelOptions.stop, // no stop method for now
};

this.maxContextTokens = this.options.maxContextTokens || 99999;
this.maxContextTokens = getModelMaxTokens(this.modelOptions.model) ?? 100000;
this.maxResponseTokens = this.modelOptions.maxOutputTokens || 1500;
this.maxPromptTokens =
this.options.maxPromptTokens || this.maxContextTokens - this.maxResponseTokens;
@@ -62,6 +59,14 @@ class AnthropicClient extends BaseClient {
);
}

this.sender =
this.options.sender ??
getResponseSender({
model: this.modelOptions.model,
endpoint: EModelEndpoint.anthropic,
modelLabel: this.options.modelLabel,
});

this.startToken = '||>';
this.endToken = '';
this.gptEncoder = this.constructor.getTokenizer('cl100k_base');
@@ -81,16 +86,15 @@ class AnthropicClient extends BaseClient {
}

getClient() {
if (this.baseURL) {
return new Anthropic({
apiKey: this.apiKey,
baseURL: this.baseURL,
});
} else {
return new Anthropic({
apiKey: this.apiKey,
});
const options = {
apiKey: this.apiKey,
};

if (this.options.reverseProxyUrl) {
options.baseURL = this.options.reverseProxyUrl;
}

return new Anthropic(options);
}

async buildMessages(messages, parentMessageId) {
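For a model such as claude-2.1, the token bookkeeping in setOptions above now works out as follows (illustrative arithmetic only, using the defaults shown in the hunk):

// Illustrative numbers following the setOptions logic above, assuming claude-2.1
// with no user-supplied maxOutputTokens or maxPromptTokens.
const maxContextTokens = 200000;  // getModelMaxTokens('claude-2.1')
const maxResponseTokens = 1500;   // default maxOutputTokens
const maxPromptTokens = maxContextTokens - maxResponseTokens; // 198500
console.log({ maxContextTokens, maxResponseTokens, maxPromptTokens });

Previously maxContextTokens fell back to a hard-coded 99999 whenever the caller did not supply one; it now tracks the per-model limits added in api/utils/tokens.js.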
17 changes: 6 additions & 11 deletions api/server/routes/ask/anthropic.js
@@ -9,9 +9,9 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, saveConvo, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');

router.post('/abort', handleAbort());

@@ -109,14 +109,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
response.parentMessageId = overrideParentMessageId;
}

await saveConvo(user, {
...endpointOption,
...endpointOption.modelOptions,
conversationId,
endpoint: 'anthropic',
});

await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
@@ -126,6 +118,9 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
});
res.end();

await saveMessage({ ...response, user });
await saveMessage(userMessage);

// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
10 changes: 6 additions & 4 deletions api/server/routes/edit/anthropic.js
@@ -9,9 +9,9 @@ const {
setHeaders,
validateEndpoint,
buildEndpointOption,
} = require('../../middleware');
const { saveMessage, getConvoTitle, getConvo } = require('../../../models');
const { sendMessage, createOnProgress } = require('../../utils');
} = require('~/server/middleware');
const { saveMessage, getConvoTitle, getConvo } = require('~/models');
const { sendMessage, createOnProgress } = require('~/server/utils');

router.post('/abort', handleAbort());

@@ -119,7 +119,6 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
response.parentMessageId = overrideParentMessageId;
}

await saveMessage({ ...response, user });
sendMessage(res, {
title: await getConvoTitle(user, conversationId),
final: true,
@@ -129,6 +128,9 @@ router.post('/', validateEndpoint, buildEndpointOption, setHeaders, async (req,
});
res.end();

await saveMessage({ ...response, user });
await saveMessage(userMessage);

// TODO: add anthropic titling
} catch (error) {
const partialText = getPartialText();
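Both the /ask and /edit Anthropic routes above now share the same ordering: stream the final event, end the response, then persist the messages. Condensed, with the payload fields elided as in the hunks above, the tail of each handler is roughly:

// Condensed sketch of the handler tail; this runs inside the async route handler.
sendMessage(res, {
  title: await getConvoTitle(user, conversationId),
  final: true,
  // ...remaining payload fields as shown in the routes
});
res.end();

// Persist only after the client has received the final event; the conversation
// document itself is no longer saved from the route controller.
await saveMessage({ ...response, user });
await saveMessage(userMessage);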
21 changes: 13 additions & 8 deletions api/server/routes/endpoints/anthropic/initializeClient.js
@@ -1,14 +1,14 @@
const { AnthropicClient } = require('../../../../app');
const { getUserKey, checkUserKeyExpiry } = require('../../../services/UserService');
const { AnthropicClient } = require('~/app');
const { getUserKey, checkUserKeyExpiry } = require('~/server/services/UserService');

const initializeClient = async ({ req, res }) => {
const ANTHROPIC_API_KEY = process.env.ANTHROPIC_API_KEY;
const initializeClient = async ({ req, res, endpointOption }) => {
const { ANTHROPIC_API_KEY, ANTHROPIC_REVERSE_PROXY } = process.env;
const expiresAt = req.body.key;
const isUserProvided = ANTHROPIC_API_KEY === 'user_provided';

let anthropicApiKey = isUserProvided ? await getAnthropicUserKey(req.user.id) : ANTHROPIC_API_KEY;
let reverseProxy = process.env.ANTHROPIC_REVERSE_PROXY || undefined;
console.log('ANTHROPIC_REVERSE_PROXY', reverseProxy);
const anthropicApiKey = isUserProvided
? await getAnthropicUserKey(req.user.id)
: ANTHROPIC_API_KEY;

if (expiresAt && isUserProvided) {
checkUserKeyExpiry(
@@ -17,7 +17,12 @@ const initializeClient = async ({ req, res }) => {
);
}

const client = new AnthropicClient(anthropicApiKey, { req, res }, {}, reverseProxy);
const client = new AnthropicClient(anthropicApiKey, {
req,
res,
reverseProxyUrl: ANTHROPIC_REVERSE_PROXY ?? null,
...endpointOption,
});

return {
client,
1 change: 1 addition & 0 deletions api/server/services/ModelService.js
@@ -117,6 +117,7 @@ const getAnthropicModels = () => {
let models = [
'claude-2.1',
'claude-2',
'claude-1.2',
'claude-1',
'claude-1-100k',
'claude-instant-1',
2 changes: 2 additions & 0 deletions api/utils/tokens.js
@@ -51,6 +51,8 @@ const maxTokensMap = {
'gpt-3.5-turbo-16k-0613': 15999,
'gpt-3.5-turbo-1106': 16380, // -5 from max
'gpt-4-1106': 127995, // -5 from max
'claude-2.1': 200000,
'claude-': 100000,
};

/**
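The bare 'claude-' key relies on getModelMaxTokens falling back to partial matching when the exact model name is not in the map, so every Claude model other than claude-2.1 resolves to 100000 while claude-2.1 gets its own 200000 entry. A rough sketch of that lookup behaviour, not the actual implementation in api/utils/tokens.js:

// Rough sketch of the lookup implied by the map and the spec below;
// the real getModelMaxTokens may differ in details.
function lookupMaxTokens(modelName, map) {
  if (map[modelName] !== undefined) {
    return map[modelName]; // exact hit, e.g. 'claude-2.1' -> 200000
  }
  // fall back to the longest map key contained in the model name
  const partial = Object.keys(map)
    .filter((key) => modelName.includes(key))
    .sort((a, b) => b.length - a.length)[0];
  return partial ? map[partial] : undefined; // e.g. 'claude-1.2' -> 'claude-' -> 100000
}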
19 changes: 19 additions & 0 deletions api/utils/tokens.spec.js
@@ -62,6 +62,25 @@ describe('getModelMaxTokens', () => {
expect(getModelMaxTokens('gpt-4-1106-preview')).toBe(maxTokensMap['gpt-4-1106']);
expect(getModelMaxTokens('gpt-4-1106-vision-preview')).toBe(maxTokensMap['gpt-4-1106']);
});

test('should return correct tokens for Anthropic models', () => {
const models = [
'claude-2.1',
'claude-2',
'claude-1.2',
'claude-1',
'claude-1-100k',
'claude-instant-1',
'claude-instant-1-100k',
];

const claude21MaxTokens = maxTokensMap['claude-2.1'];
const claudeMaxTokens = maxTokensMap['claude-'];
models.forEach((model) => {
const expectedTokens = model === 'claude-2.1' ? claude21MaxTokens : claudeMaxTokens;
expect(getModelMaxTokens(model)).toEqual(expectedTokens);
});
});
});

describe('matchModelName', () => {
2 changes: 1 addition & 1 deletion client/src/components/Endpoints/Settings/Anthropic.tsx
@@ -86,7 +86,7 @@ export default function Settings({ conversation, setOption, models, readonly }:
<div className="flex justify-between">
<Label htmlFor="temp-int" className="text-left text-sm font-medium">
{localize('com_endpoint_temperature')}{' '}
<small className="opacity-40">({localize('com_endpoint_default')}: 0.2)</small>
<small className="opacity-40">({localize('com_endpoint_default')}: 1)</small>
</Label>
<InputNumber
id="temp-int"
2 changes: 1 addition & 1 deletion client/src/utils/presets.ts
@@ -24,7 +24,7 @@ export const getPresetTitle = (preset: TPreset) => {
if (model) {
_title += `: ${model}`;
}
} else if (endpoint === EModelEndpoint.google) {
} else if (endpoint === EModelEndpoint.google || endpoint === EModelEndpoint.anthropic) {
if (modelLabel) {
_title = modelLabel;
}
