Skip to content

Commit

Permalink
✨ feat: Azure Vision Support & Docs Update (danny-avila#1389)
Browse files Browse the repository at this point in the history
* feat(AzureOpenAI): Vision Support

* chore(ci/OpenAIClient.test): update test to reflect Azure now uses chatCompletion method as opposed to getCompletion, while still testing the latter method

* docs: update documentation mainly revolving around Azure setup, but also reformatting the 'Tokens and API' section completely

* docs: add images and links to ai_setup.md

* docs: ai setup reference
  • Loading branch information
danny-avila authored Dec 18, 2023
1 parent cc51465 commit c694e6a
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 19 deletions.
11 changes: 10 additions & 1 deletion app/clients/OpenAIClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ class OpenAIClient extends BaseClient {
let streamResult = null;
this.modelOptions.user = this.user;
const invalidBaseUrl = this.completionsUrl && extractBaseURL(this.completionsUrl) === null;
const useOldMethod = !!(this.azure || invalidBaseUrl || !this.isChatCompletion);
const useOldMethod = !!(invalidBaseUrl || !this.isChatCompletion);
if (typeof opts.onProgress === 'function' && useOldMethod) {
await this.getCompletion(
payload,
Expand Down Expand Up @@ -764,6 +764,15 @@ ${convo}
modelOptions.max_tokens = 4000;
}

if (this.azure || this.options.azure) {
// Azure does not accept `model` in the body, so we need to remove it.
delete modelOptions.model;

opts.baseURL = this.azureEndpoint.split('/chat')[0];
opts.defaultQuery = { 'api-version': this.azure.azureOpenAIApiVersion };
opts.defaultHeaders = { ...opts.defaultHeaders, 'api-key': this.apiKey };
}

let chatCompletion;
const openai = new OpenAI({
apiKey: this.apiKey,
Expand Down
161 changes: 143 additions & 18 deletions app/clients/specs/OpenAIClient.test.js
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
require('dotenv').config();
const OpenAI = require('openai');
const { fetchEventSource } = require('@waylaidwanderer/fetch-event-source');
const { genAzureChatCompletion } = require('~/utils/azureUtils');
const OpenAIClient = require('../OpenAIClient');
Expand Down Expand Up @@ -41,6 +42,97 @@ jest.mock('langchain/chat_models/openai', () => {
};
});

jest.mock('openai');

jest.spyOn(OpenAI, 'constructor').mockImplementation(function (...options) {
// We can add additional logic here if needed
return new OpenAI(...options);
});

const finalChatCompletion = jest.fn().mockResolvedValue({
choices: [
{
message: { role: 'assistant', content: 'Mock message content' },
finish_reason: 'Mock finish reason',
},
],
});

const stream = jest.fn().mockImplementation(() => {
let isDone = false;
let isError = false;
let errorCallback = null;

const onEventHandlers = {
abort: () => {
// Mock abort behavior
},
error: (callback) => {
errorCallback = callback; // Save the error callback for later use
},
finalMessage: (callback) => {
callback({ role: 'assistant', content: 'Mock Response' });
isDone = true; // Set stream to done
},
};

const mockStream = {
on: jest.fn((event, callback) => {
if (onEventHandlers[event]) {
onEventHandlers[event](callback);
}
return mockStream;
}),
finalChatCompletion,
controller: { abort: jest.fn() },
triggerError: () => {
isError = true;
if (errorCallback) {
errorCallback(new Error('Mock error'));
}
},
[Symbol.asyncIterator]: () => {
return {
next: () => {
if (isError) {
return Promise.reject(new Error('Mock error'));
}
if (isDone) {
return Promise.resolve({ done: true });
}
const chunk = { choices: [{ delta: { content: 'Mock chunk' } }] };
return Promise.resolve({ value: chunk, done: false });
},
};
},
};
return mockStream;
});

const create = jest.fn().mockResolvedValue({
choices: [
{
message: { content: 'Mock message content' },
finish_reason: 'Mock finish reason',
},
],
});

OpenAI.mockImplementation(() => ({
beta: {
chat: {
completions: {
stream,
},
},
},
chat: {
completions: {
create,
},
},
}));

describe('OpenAIClient', () => {
let client, client2;
const model = 'gpt-4';
Expand Down Expand Up @@ -456,45 +548,78 @@ describe('OpenAIClient', () => {
});
});

describe('sendMessage/getCompletion', () => {
describe('sendMessage/getCompletion/chatCompletion', () => {
afterEach(() => {
delete process.env.AZURE_OPENAI_DEFAULT_MODEL;
delete process.env.AZURE_USE_MODEL_AS_DEPLOYMENT_NAME;
delete process.env.OPENROUTER_API_KEY;
});

it('[Azure OpenAI] should call getCompletion and fetchEventSource with correct args', async () => {
// Set a default model
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo';

it('should call getCompletion and fetchEventSource when using a text/instruct model', async () => {
const model = 'text-davinci-003';
const onProgress = jest.fn().mockImplementation(() => ({}));
client.azure = defaultAzureOptions;
const getCompletion = jest.spyOn(client, 'getCompletion');
await client.sendMessage('Hi mom!', {
replaceOptions: true,

const testClient = new OpenAIClient('test-api-key', {
...defaultOptions,
onProgress,
azure: defaultAzureOptions,
modelOptions: { model },
});

const getCompletion = jest.spyOn(testClient, 'getCompletion');
await testClient.sendMessage('Hi mom!', { onProgress });

expect(getCompletion).toHaveBeenCalled();
expect(getCompletion.mock.calls.length).toBe(1);
expect(getCompletion.mock.calls[0][0][0].role).toBe('user');
expect(getCompletion.mock.calls[0][0][0].content).toBe('Hi mom!');
expect(getCompletion.mock.calls[0][0]).toBe(
'||>Instructions:\nYou are ChatGPT, a large language model trained by OpenAI. Respond conversationally.\nCurrent date: December 18, 2023\n\n||>User:\nHi mom!\n||>Assistant:\n',
);

expect(fetchEventSource).toHaveBeenCalled();
expect(fetchEventSource.mock.calls.length).toBe(1);

// Check if the first argument (url) is correct
const expectedURL = genAzureChatCompletion(defaultAzureOptions);
const firstCallArgs = fetchEventSource.mock.calls[0];

const expectedURL = 'https://api.openai.com/v1/completions';
expect(firstCallArgs[0]).toBe(expectedURL);
// Should not have model in the deployment name
expect(firstCallArgs[0]).not.toContain('gpt4-turbo');

// Should not include the model in request body
const requestBody = JSON.parse(firstCallArgs[1].body);
expect(requestBody).not.toHaveProperty('model');
expect(requestBody).toHaveProperty('model');
expect(requestBody.model).toBe(model);
});

it('[Azure OpenAI] should call chatCompletion and OpenAI.stream with correct args', async () => {
// Set a default model
process.env.AZURE_OPENAI_DEFAULT_MODEL = 'gpt4-turbo';

const onProgress = jest.fn().mockImplementation(() => ({}));
client.azure = defaultAzureOptions;
const chatCompletion = jest.spyOn(client, 'chatCompletion');
await client.sendMessage('Hi mom!', {
replaceOptions: true,
...defaultOptions,
modelOptions: { model: 'gpt4-turbo', stream: true },
onProgress,
azure: defaultAzureOptions,
});

expect(chatCompletion).toHaveBeenCalled();
expect(chatCompletion.mock.calls.length).toBe(1);

const chatCompletionArgs = chatCompletion.mock.calls[0][0];
const { payload } = chatCompletionArgs;

expect(payload[0].role).toBe('user');
expect(payload[0].content).toBe('Hi mom!');

// Azure OpenAI does not use the model property, and will error if it's passed
// This check ensures the model property is not present
const streamArgs = stream.mock.calls[0][0];
expect(streamArgs).not.toHaveProperty('model');

// Check if the baseURL is correct
const constructorArgs = OpenAI.mock.calls[0][0];
const expectedURL = genAzureChatCompletion(defaultAzureOptions).split('/chat')[0];
expect(constructorArgs.baseURL).toBe(expectedURL);
});
});
});

0 comments on commit c694e6a

Please sign in to comment.