feat: support ollama and deep seek r1
zhu-xiaowei committed Jan 23, 2025
1 parent ae66c38 commit d53c06b
Showing 14 changed files with 437 additions and 91 deletions.
35 changes: 24 additions & 11 deletions react-native/src/api/bedrock-api.ts
@@ -24,6 +24,7 @@ import {
TextContent,
} from '../chat/util/BedrockMessageConvertor.ts';
import { invokeOpenAIWithCallBack } from './open-api.ts';
import { invokeOllamaWithCallBack } from './ollama-api.ts';

type CallbackFunction = (
result: string,
@@ -43,7 +44,8 @@ export const invokeBedrockWithCallBack = async (
) => {
const isDeepSeek = getTextModel().modelId.includes('deepseek');
const isOpenAI = getTextModel().modelId.includes('gpt');
if (chatMode === ChatMode.Text && (isDeepSeek || isOpenAI)) {
const isOllama = getTextModel().modelId.startsWith('ollama-');
if (chatMode === ChatMode.Text && (isDeepSeek || isOpenAI || isOllama)) {
if (isDeepSeek && getDeepSeekApiKey().length === 0) {
callback('Please configure your DeepSeek API Key', true, true);
return;
@@ -52,13 +54,23 @@
callback('Please configure your OpenAI API Key', true, true);
return;
}
await invokeOpenAIWithCallBack(
messages,
prompt,
shouldStop,
controller,
callback
);
if (isOllama) {
await invokeOllamaWithCallBack(
messages,
prompt,
shouldStop,
controller,
callback
);
} else {
await invokeOpenAIWithCallBack(
messages,
prompt,
shouldStop,
controller,
callback
);
}
return;
}
if (!isConfigured()) {
@@ -88,7 +100,6 @@ export const invokeBedrockWithCallBack = async (
reactNative: { textStreaming: true },
};
const url = getApiPrefix() + '/converse';
let intervalId: ReturnType<typeof setInterval>;
let completeMessage = '';
const timeoutId = setTimeout(() => controller.abort(), 60000);
fetch(url!, options)
@@ -131,7 +142,6 @@ export const invokeBedrockWithCallBack = async (
}
})
.catch(error => {
clearInterval(intervalId);
if (shouldStop()) {
if (completeMessage === '') {
completeMessage = '...';
@@ -199,6 +209,7 @@ export const invokeBedrockWithCallBack = async (
};

export const requestAllModels = async (): Promise<AllModel> => {
const controller = new AbortController();
const url = getApiPrefix() + '/models';
const bodyObject = {
region: getRegion(),
@@ -213,16 +224,18 @@ export const requestAllModels = async (): Promise<AllModel> => {
body: JSON.stringify(bodyObject),
reactNative: { textStreaming: true },
};

const timeoutId = setTimeout(() => controller.abort(), 3000);
try {
const response = await fetch(url, options);
clearTimeout(timeoutId);
if (!response.ok) {
console.log(`HTTP error! status: ${response.status}`);
return { imageModel: [], textModel: [] };
}
return await response.json();
} catch (error) {
console.log('Error fetching models:', error);
clearTimeout(timeoutId);
return { imageModel: [], textModel: [] };
}
};
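The hunks above remove the unused `intervalId`/`clearInterval` pair and add a 3-second timeout to `requestAllModels`, with the timer cleared on both the success and error paths; `requestAllOllamaModels` further down wires the same `AbortController` signal into its fetch options. Below is a minimal sketch of that abort-on-timeout pattern as a standalone helper (`fetchWithTimeout` is hypothetical, not part of this commit):

```ts
// Sketch of the abort-on-timeout pattern used by the model-listing requests.
// Hypothetical helper; the name and signature are illustrative only.
async function fetchWithTimeout(
  url: string,
  init: RequestInit,
  timeoutMs: number
): Promise<Response> {
  const controller = new AbortController();
  const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
  try {
    // Sharing the signal lets the timer cancel the in-flight request.
    return await fetch(url, { ...init, signal: controller.signal });
  } finally {
    // Always clear the timer, whether the request resolved or aborted.
    clearTimeout(timeoutId);
  }
}
```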
186 changes: 186 additions & 0 deletions react-native/src/api/ollama-api.ts
@@ -0,0 +1,186 @@
import { Model, OllamaModel, SystemPrompt, Usage } from '../types/Chat.ts';
import { getOllamaApiUrl, getTextModel } from '../storage/StorageUtils.ts';
import {
BedrockMessage,
ImageContent,
OpenAIMessage,
TextContent,
} from '../chat/util/BedrockMessageConvertor.ts';

type CallbackFunction = (
result: string,
complete: boolean,
needStop: boolean,
usage?: Usage
) => void;
export const invokeOllamaWithCallBack = async (
messages: BedrockMessage[],
prompt: SystemPrompt | null,
shouldStop: () => boolean,
controller: AbortController,
callback: CallbackFunction
) => {
const bodyObject = {
model: getTextModel().modelId.split('ollama-')[1],
messages: getOllamaMessages(messages, prompt),
};
console.log(JSON.stringify(bodyObject, null, 2));
const options = {
method: 'POST',
headers: {
accept: '*/*',
'content-type': 'application/json',
},
body: JSON.stringify(bodyObject),
signal: controller.signal,
reactNative: { textStreaming: true },
};
const url = getOllamaApiUrl() + '/api/chat';
let completeMessage = '';
const timeoutId = setTimeout(() => controller.abort(), 60000);
fetch(url!, options)
.then(response => {
return response.body;
})
.then(async body => {
clearTimeout(timeoutId);
if (!body) {
return;
}
const reader = body.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader.read();
const chunk = decoder.decode(value, { stream: true });
if (!chunk) {
break;
}
const parsed = parseStreamData(chunk);
if (parsed.error) {
callback(parsed.error, true, true);
break;
}
completeMessage += parsed.content;
if (parsed.usage && parsed.usage.inputTokens) {
callback(completeMessage, true, false, parsed.usage);
} else {
callback(completeMessage, done, false);
}
if (done) {
break;
}
}
})
.catch(error => {
console.log(error);
clearTimeout(timeoutId);
if (shouldStop()) {
if (completeMessage === '') {
completeMessage = '...';
}
callback(completeMessage, true, true);
} else {
const errorMsg = String(error);
const errorInfo = 'Request error: ' + errorMsg;
callback(completeMessage + '\n\n' + errorInfo, true, true);
}
});
};

const parseStreamData = (chunk: string) => {
let content = '';
let usage: Usage | undefined;

try {
const parsedData: OllamaResponse = JSON.parse(chunk);

if (parsedData.message?.content) {
content = parsedData.message?.content;
}

if (parsedData.done) {
usage = {
modelName: getTextModel().modelName,
inputTokens: parsedData.prompt_eval_count,
outputTokens: parsedData.eval_count,
totalTokens: parsedData.prompt_eval_count + parsedData.eval_count,
};
}
} catch (error) {
console.info('parse error:', error, chunk);
return { error: chunk };
}
return { content, usage };
};

type OllamaResponse = {
model: string;
created_at: string;
message?: {
role: string;
content: string;
};
done: boolean;
prompt_eval_count: number;
eval_count: number;
};

function getOllamaMessages(
messages: BedrockMessage[],
prompt: SystemPrompt | null
): OpenAIMessage[] {
return [
...(prompt ? [{ role: 'system', content: prompt.prompt }] : []),
...messages.map(message => {
const images = message.content
.filter(content => (content as ImageContent).image)
.map(content => (content as ImageContent).image.source.bytes);

return {
role: message.role,
content: message.content
.map(content => {
if ((content as TextContent).text) {
return (content as TextContent).text;
}
return '';
})
.join('\n'),
images: images.length > 0 ? images : undefined,
};
}),
];
}

export const requestAllOllamaModels = async (): Promise<Model[]> => {
const controller = new AbortController();
const modelsUrl = getOllamaApiUrl() + '/api/tags';
console.log(modelsUrl);
const options = {
method: 'GET',
headers: {
accept: 'application/json',
'content-type': 'application/json',
},
signal: controller.signal,
reactNative: { textStreaming: true },
};
const timeoutId = setTimeout(() => controller.abort(), 3000);
try {
const response = await fetch(modelsUrl, options);
clearTimeout(timeoutId);
if (!response.ok) {
console.log(`HTTP error! status: ${response.status}`);
return [];
}
const data = await response.json();
return data.models.map((item: OllamaModel) => ({
modelId: 'ollama-' + item.name,
modelName: item.name,
}));
} catch (error) {
clearTimeout(timeoutId);
console.log('Error fetching models:', error);
return [];
}
};
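`invokeOllamaWithCallBack` streams from Ollama's `/api/chat` endpoint, which returns newline-delimited JSON objects; note that `parseStreamData` passes each decoded network chunk directly to `JSON.parse`, so a chunk that arrives split or coalesced falls through to the catch branch and is surfaced as an error. The model id convention ties the pieces together: `requestAllOllamaModels` prefixes each model name with `ollama-`, `bedrock-api.ts` routes on `startsWith('ollama-')`, and the prefix is stripped again before the request. A minimal usage sketch, under the assumption of a caller driving the stream (the message literal mirrors the `BedrockMessage` shape used in this file):

```ts
// Usage sketch only; assumes a surrounding chat screen supplies this state.
import { invokeOllamaWithCallBack } from './ollama-api';

const controller = new AbortController();
let stopped = false; // would be flipped by a hypothetical stop button

await invokeOllamaWithCallBack(
  [{ role: 'user', content: [{ text: 'Why is the sky blue?' }] }],
  null, // no system prompt
  () => stopped, // shouldStop
  controller,
  (result, complete, needStop, usage) => {
    // `result` is the accumulated reply so far, not a single delta.
    console.log(result);
    if (complete && usage) {
      console.log(`tokens in/out: ${usage.inputTokens}/${usage.outputTokens}`);
    }
  }
);
```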
50 changes: 40 additions & 10 deletions react-native/src/api/open-api.ts
@@ -44,7 +44,6 @@ export const invokeOpenAIWithCallBack = async (
reactNative: { textStreaming: true },
};
const url = getApiURL();
let intervalId: ReturnType<typeof setInterval>;
let completeMessage = '';
const timeoutId = setTimeout(() => controller.abort(), 60000);
fetch(url!, options)
@@ -58,15 +57,37 @@
}
const reader = body.getReader();
const decoder = new TextDecoder();
let isFirstReason = true;
let isFirstContent = true;
let lastChunk = '';
while (true) {
const { done, value } = await reader.read();
const chunk = decoder.decode(value, { stream: true });
const parsed = parseStreamData(chunk);
const parsed = parseStreamData(chunk, lastChunk);
if (parsed.error) {
callback(parsed.error, true, true);
callback(completeMessage + '\n\n' + parsed.error, true, true);
break;
}
completeMessage += parsed.content;
if (parsed.reason) {
const formattedReason = parsed.reason.replace(/\n\n/g, '\n>\n>');
if (isFirstReason) {
completeMessage += '> ';
isFirstReason = false;
}
completeMessage += formattedReason;
}
if (parsed.content) {
if (!isFirstReason && isFirstContent) {
completeMessage += '\n\n';
isFirstContent = false;
}
completeMessage += parsed.content;
}
if (parsed.dataChunk) {
lastChunk = parsed.dataChunk;
} else {
lastChunk = '';
}
if (parsed.usage && parsed.usage.inputTokens) {
callback(completeMessage, false, false, parsed.usage);
} else {
@@ -79,7 +100,6 @@
})
.catch(error => {
console.log(error);
clearInterval(intervalId);
if (shouldStop()) {
if (completeMessage === '') {
completeMessage = '...';
@@ -93,9 +113,10 @@
});
};

const parseStreamData = (chunk: string) => {
const dataChunks = chunk.split('\n\n');
const parseStreamData = (chunk: string, lastChunk: string = '') => {
const dataChunks = (lastChunk + chunk).split('\n\n');
let content = '';
let reason = '';
let usage: Usage | undefined;

for (const dataChunk of dataChunks) {
@@ -114,6 +135,10 @@ const parseStreamData = (chunk: string) => {
content += parsedData.choices[0].delta.content;
}

if (parsedData.choices[0]?.delta?.reasoning_content) {
reason += parsedData.choices[0].delta.reasoning_content;
}

if (parsedData.usage) {
usage = {
modelName: getTextModel().modelName,
@@ -125,17 +150,22 @@
};
}
} catch (error) {
console.info('parse error:', error, cleanedData);
return { error: cleanedData };
if (lastChunk.length > 0) {
return { error: error + cleanedData };
}
if (reason || content) {
return { reason, content, dataChunk, usage };
}
}
}
return { content, usage };
return { reason, content, usage };
};

type ChatResponse = {
choices: Array<{
delta: {
content: string;
reasoning_content: string;
};
}>;
usage?: {
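Two behaviors change in this file. First, DeepSeek R1's `reasoning_content` deltas are rendered as a Markdown blockquote: the first reasoning token is prefixed with `> `, every `\n\n` inside the reasoning is rewritten to `\n>\n>`, and a blank line separates the quote from the answer content that follows. Second, `parseStreamData` now accepts a `lastChunk` carry-over, so an SSE event split across network reads is prepended to the next chunk instead of being reported as a parse error. A sketch of that carry-over buffering, assuming events delimited by blank lines (`\n\n`); this is illustrative, not the commit's code:

```ts
// Carry-over buffer for "\n\n"-delimited stream events (illustrative).
function makeEventBuffer() {
  let pending = '';
  return (chunk: string): string[] => {
    pending += chunk;
    const parts = pending.split('\n\n');
    // The final piece may be a truncated event; hold it for the next read.
    pending = parts.pop() ?? '';
    return parts; // complete events, ready for the "data: " strip + JSON.parse
  };
}

// An event arriving in two network reads is emitted once, whole:
const feed = makeEventBuffer();
feed('data: {"choices":[{"delta":{"con'); // -> [] (nothing complete yet)
feed('tent":"Hi"}}]}\n\n'); // -> ['data: {"choices":[{"delta":{"content":"Hi"}}]}']
```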
Binary file added react-native/src/assets/ollama-white.png
