From d53c06b3a0edd4e52331bf91701063b92022a1a1 Mon Sep 17 00:00:00 2001
From: xiaoweii
Date: Thu, 23 Jan 2025 14:54:23 +0800
Subject: [PATCH] feat: support Ollama and DeepSeek R1

---
 react-native/src/api/bedrock-api.ts           |  35 ++--
 react-native/src/api/ollama-api.ts            | 186 ++++++++++++++++++
 react-native/src/api/open-api.ts              |  50 ++++-
 react-native/src/assets/ollama-white.png      | Bin 0 -> 2570 bytes
 .../chat/component/CustomMessageComponent.tsx |   4 +
 .../chat/component/CustomSendComponent.tsx    |   3 +-
 .../src/chat/component/EmptyChatComponent.tsx |   3 +
 react-native/src/settings/CustomTextInput.tsx |  59 ++++++
 .../src/settings/DropdownComponent.tsx        |   2 +-
 react-native/src/settings/ModelPrice.ts       |   6 +-
 react-native/src/settings/SettingsScreen.tsx  | 136 +++++++------
 react-native/src/storage/Constants.ts         |  24 +--
 react-native/src/storage/StorageUtils.ts      |  16 ++
 react-native/src/types/Chat.ts                |   4 +
 14 files changed, 437 insertions(+), 91 deletions(-)
 create mode 100644 react-native/src/api/ollama-api.ts
 create mode 100644 react-native/src/assets/ollama-white.png
 create mode 100644 react-native/src/settings/CustomTextInput.tsx

diff --git a/react-native/src/api/bedrock-api.ts b/react-native/src/api/bedrock-api.ts
index e4fd3b4..3ef07b3 100644
--- a/react-native/src/api/bedrock-api.ts
+++ b/react-native/src/api/bedrock-api.ts
@@ -24,6 +24,7 @@ import {
   TextContent,
 } from '../chat/util/BedrockMessageConvertor.ts';
 import { invokeOpenAIWithCallBack } from './open-api.ts';
+import { invokeOllamaWithCallBack } from './ollama-api.ts';

 type CallbackFunction = (
   result: string,
@@ -43,7 +44,8 @@ export const invokeBedrockWithCallBack = async (
 ) => {
   const isDeepSeek = getTextModel().modelId.includes('deepseek');
   const isOpenAI = getTextModel().modelId.includes('gpt');
-  if (chatMode === ChatMode.Text && (isDeepSeek || isOpenAI)) {
+  const isOllama = getTextModel().modelId.startsWith('ollama-');
+  if (chatMode === ChatMode.Text && (isDeepSeek || isOpenAI || isOllama)) {
     if (isDeepSeek && getDeepSeekApiKey().length === 0) {
       callback('Please configure your DeepSeek API Key', true, true);
       return;
@@ -52,13 +54,23 @@ export const invokeBedrockWithCallBack = async (
       callback('Please configure your OpenAI API Key', true, true);
       return;
     }
-    await invokeOpenAIWithCallBack(
-      messages,
-      prompt,
-      shouldStop,
-      controller,
-      callback
-    );
+    if (isOllama) {
+      await invokeOllamaWithCallBack(
+        messages,
+        prompt,
+        shouldStop,
+        controller,
+        callback
+      );
+    } else {
+      await invokeOpenAIWithCallBack(
+        messages,
+        prompt,
+        shouldStop,
+        controller,
+        callback
+      );
+    }
     return;
   }
   if (!isConfigured()) {
@@ -88,7 +100,6 @@ export const invokeBedrockWithCallBack = async (
     reactNative: { textStreaming: true },
   };
   const url = getApiPrefix() + '/converse';
-  let intervalId: ReturnType<typeof setInterval>;
   let completeMessage = '';
   const timeoutId = setTimeout(() => controller.abort(), 60000);
   fetch(url!, options)
@@ -131,7 +142,6 @@ export const invokeBedrockWithCallBack = async (
       }
     })
     .catch(error => {
-      clearInterval(intervalId);
       if (shouldStop()) {
         if (completeMessage === '') {
           completeMessage = '...';
@@ -199,6 +209,7 @@ export const invokeBedrockWithCallBack = async (
 };

 export const requestAllModels = async (): Promise<AllModel> => {
+  const controller = new AbortController();
   const url = getApiPrefix() + '/models';
   const bodyObject = {
     region: getRegion(),
@@ -213,9 +224,10 @@ export const requestAllModels = async (): Promise<AllModel> => {
     body: JSON.stringify(bodyObject),
     reactNative: { textStreaming: true },
   };
-
+  const timeoutId = setTimeout(() => controller.abort(), 3000);
   try {
     const response = await fetch(url, options);
+    clearTimeout(timeoutId);
     if (!response.ok) {
       console.log(`HTTP error! status: ${response.status}`);
       return { imageModel: [], textModel: [] };
@@ -223,6 +235,7 @@ export const requestAllModels = async (): Promise<AllModel> => {
     return await response.json();
   } catch (error) {
     console.log('Error fetching models:', error);
+    clearTimeout(timeoutId);
     return { imageModel: [], textModel: [] };
   }
 };
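Note: the `requestAllModels` change above introduces the timeout pattern this patch uses for its network calls: an `AbortController` whose `abort()` is scheduled with `setTimeout`, with the timer cleared on both the success and the error path. The hunk shown does not visibly attach `controller.signal` to `options`, so the abort only takes effect if the signal is wired up in context not shown here. A minimal standalone sketch of the pattern (the `fetchWithTimeout` name and the 3-second default are illustrative, not part of the patch):

```ts
// Sketch of the abort-on-timeout pattern used throughout this patch.
async function fetchWithTimeout(
  url: string,
  init: RequestInit = {},
  timeoutMs = 3000
): Promise<Response> {
  const controller = new AbortController();
  // Schedule an abort; an aborted fetch rejects with an AbortError.
  const timeoutId = setTimeout(() => controller.abort(), timeoutMs);
  try {
    return await fetch(url, { ...init, signal: controller.signal });
  } finally {
    // Always clear the timer, mirroring the clearTimeout calls the diff
    // adds to both the try and catch branches.
    clearTimeout(timeoutId);
  }
}
```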
diff --git a/react-native/src/api/ollama-api.ts b/react-native/src/api/ollama-api.ts
new file mode 100644
index 0000000..24734c3
--- /dev/null
+++ b/react-native/src/api/ollama-api.ts
@@ -0,0 +1,186 @@
+import { Model, OllamaModel, SystemPrompt, Usage } from '../types/Chat.ts';
+import { getOllamaApiUrl, getTextModel } from '../storage/StorageUtils.ts';
+import {
+  BedrockMessage,
+  ImageContent,
+  OpenAIMessage,
+  TextContent,
+} from '../chat/util/BedrockMessageConvertor.ts';
+
+type CallbackFunction = (
+  result: string,
+  complete: boolean,
+  needStop: boolean,
+  usage?: Usage
+) => void;
+
+export const invokeOllamaWithCallBack = async (
+  messages: BedrockMessage[],
+  prompt: SystemPrompt | null,
+  shouldStop: () => boolean,
+  controller: AbortController,
+  callback: CallbackFunction
+) => {
+  const bodyObject = {
+    model: getTextModel().modelId.split('ollama-')[1],
+    messages: getOllamaMessages(messages, prompt),
+  };
+  console.log(JSON.stringify(bodyObject, null, 2));
+  const options = {
+    method: 'POST',
+    headers: {
+      accept: '*/*',
+      'content-type': 'application/json',
+    },
+    body: JSON.stringify(bodyObject),
+    signal: controller.signal,
+    reactNative: { textStreaming: true },
+  };
+  const url = getOllamaApiUrl() + '/api/chat';
+  let completeMessage = '';
+  const timeoutId = setTimeout(() => controller.abort(), 60000);
+  fetch(url!, options)
+    .then(response => {
+      return response.body;
+    })
+    .then(async body => {
+      clearTimeout(timeoutId);
+      if (!body) {
+        return;
+      }
+      const reader = body.getReader();
+      const decoder = new TextDecoder();
+      while (true) {
+        const { done, value } = await reader.read();
+        const chunk = decoder.decode(value, { stream: true });
+        if (!chunk) {
+          break;
+        }
+        const parsed = parseStreamData(chunk);
+        if (parsed.error) {
+          callback(parsed.error, true, true);
+          break;
+        }
+        completeMessage += parsed.content;
+        if (parsed.usage && parsed.usage.inputTokens) {
+          callback(completeMessage, true, false, parsed.usage);
+        } else {
+          callback(completeMessage, done, false);
+        }
+        if (done) {
+          break;
+        }
+      }
+    })
+    .catch(error => {
+      console.log(error);
+      clearTimeout(timeoutId);
+      if (shouldStop()) {
+        if (completeMessage === '') {
+          completeMessage = '...';
+        }
+        callback(completeMessage, true, true);
+      } else {
+        const errorMsg = String(error);
+        const errorInfo = 'Request error: ' + errorMsg;
+        callback(completeMessage + '\n\n' + errorInfo, true, true);
+      }
+    });
+};
+
+const parseStreamData = (chunk: string) => {
+  let content = '';
+  let usage: Usage | undefined;
+
+  try {
+    const parsedData: OllamaResponse = JSON.parse(chunk);
+
+    if (parsedData.message?.content) {
+      content = parsedData.message?.content;
+    }
+
+    if (parsedData.done) {
+      usage = {
+        modelName: getTextModel().modelName,
+        inputTokens: parsedData.prompt_eval_count,
+        outputTokens: parsedData.eval_count,
+        totalTokens: parsedData.prompt_eval_count + parsedData.eval_count,
+      };
+    }
+  } catch (error) {
+    console.info('parse error:', error, chunk);
+    return { error: chunk };
+  }
+  return { content, usage };
+};
+
+type OllamaResponse = {
+  model: string;
+  created_at: string;
+  message?: {
+    role: string;
+    content: string;
+  };
+  done: boolean;
+  prompt_eval_count: number;
+  eval_count: number;
+};
+
+function getOllamaMessages(
+  messages: BedrockMessage[],
+  prompt: SystemPrompt | null
+): OpenAIMessage[] {
+  return [
+    ...(prompt ? [{ role: 'system', content: prompt.prompt }] : []),
+    ...messages.map(message => {
+      const images = message.content
+        .filter(content => (content as ImageContent).image)
+        .map(content => (content as ImageContent).image.source.bytes);
+
+      return {
+        role: message.role,
+        content: message.content
+          .map(content => {
+            if ((content as TextContent).text) {
+              return (content as TextContent).text;
+            }
+            return '';
+          })
+          .join('\n'),
+        images: images.length > 0 ? images : undefined,
+      };
+    }),
+  ];
+}
+
+export const requestAllOllamaModels = async (): Promise<Model[]> => {
+  const controller = new AbortController();
+  const modelsUrl = getOllamaApiUrl() + '/api/tags';
+  console.log(modelsUrl);
+  const options = {
+    method: 'GET',
+    headers: {
+      accept: 'application/json',
+      'content-type': 'application/json',
+    },
+    signal: controller.signal,
+    reactNative: { textStreaming: true },
+  };
+  const timeoutId = setTimeout(() => controller.abort(), 3000);
+  try {
+    const response = await fetch(modelsUrl, options);
+    clearTimeout(timeoutId);
+    if (!response.ok) {
+      console.log(`HTTP error! status: ${response.status}`);
+      return [];
+    }
+    const data = await response.json();
+    return data.models.map((item: OllamaModel) => ({
+      modelId: 'ollama-' + item.name,
+      modelName: item.name,
+    }));
+  } catch (error) {
+    clearTimeout(timeoutId);
+    console.log('Error fetching models:', error);
+    return [];
+  }
+};
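Note: Ollama's `/api/chat` endpoint streams newline-delimited JSON: each line is a complete object whose `message.content` carries a text delta, and the final line has `done: true` plus the `prompt_eval_count` / `eval_count` token counts that `parseStreamData` maps to `Usage`. The code above parses each decoded chunk as a single JSON document, which works as long as every read delivers exactly one line; a line-buffered variant along these lines (a sketch, not part of the patch) would also tolerate reads that carry several lines or a partial one:

```ts
// Sketch: line-buffered parsing of Ollama's NDJSON chat stream.
// Assumes each complete line looks like
// {"model":"llama3.2","message":{"role":"assistant","content":"Hi"},"done":false}
let buffer = '';

function feed(chunk: string): string[] {
  buffer += chunk;
  const lines = buffer.split('\n');
  buffer = lines.pop() ?? ''; // keep any trailing partial line for the next read
  const deltas: string[] = [];
  for (const line of lines) {
    if (!line.trim()) {
      continue;
    }
    const data = JSON.parse(line);
    if (data.message?.content) {
      deltas.push(data.message.content);
    }
  }
  return deltas;
}
```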
diff --git a/react-native/src/api/open-api.ts b/react-native/src/api/open-api.ts
index 3132cbe..f5567a3 100644
--- a/react-native/src/api/open-api.ts
+++ b/react-native/src/api/open-api.ts
@@ -44,7 +44,6 @@ export const invokeOpenAIWithCallBack = async (
     reactNative: { textStreaming: true },
   };
   const url = getApiURL();
-  let intervalId: ReturnType<typeof setInterval>;
   let completeMessage = '';
   const timeoutId = setTimeout(() => controller.abort(), 60000);
   fetch(url!, options)
@@ -58,15 +57,37 @@ export const invokeOpenAIWithCallBack = async (
       }
       const reader = body.getReader();
       const decoder = new TextDecoder();
+      let isFirstReason = true;
+      let isFirstContent = true;
+      let lastChunk = '';
       while (true) {
         const { done, value } = await reader.read();
         const chunk = decoder.decode(value, { stream: true });
-        const parsed = parseStreamData(chunk);
+        const parsed = parseStreamData(chunk, lastChunk);
         if (parsed.error) {
-          callback(parsed.error, true, true);
+          callback(completeMessage + '\n\n' + parsed.error, true, true);
           break;
         }
-        completeMessage += parsed.content;
+        if (parsed.reason) {
+          const formattedReason = parsed.reason.replace(/\n\n/g, '\n>\n>');
+          if (isFirstReason) {
+            completeMessage += '> ';
+            isFirstReason = false;
+          }
+          completeMessage += formattedReason;
+        }
+        if (parsed.content) {
+          if (!isFirstReason && isFirstContent) {
+            completeMessage += '\n\n';
+            isFirstContent = false;
+          }
+          completeMessage += parsed.content;
+        }
+        if (parsed.dataChunk) {
+          lastChunk = parsed.dataChunk;
+        } else {
+          lastChunk = '';
+        }
         if (parsed.usage && parsed.usage.inputTokens) {
           callback(completeMessage, false, false, parsed.usage);
         } else {
@@ -79,7 +100,6 @@ export const invokeOpenAIWithCallBack = async (
     })
     .catch(error => {
       console.log(error);
-      clearInterval(intervalId);
      if (shouldStop()) {
        if (completeMessage === '') {
          completeMessage = '...';
@@ -93,9 +113,10 @@ export const invokeOpenAIWithCallBack = async (
   });
 };

-const parseStreamData = (chunk: string) => {
-  const dataChunks = chunk.split('\n\n');
+const parseStreamData = (chunk: string, lastChunk: string = '') => {
+  const dataChunks = (lastChunk + chunk).split('\n\n');
   let content = '';
+  let reason = '';
   let usage: Usage | undefined;

   for (const dataChunk of dataChunks) {
@@ -114,6 +135,10 @@ const parseStreamData = (chunk: string, lastChunk: string = '') => {
         content += parsedData.choices[0].delta.content;
       }

+      if (parsedData.choices[0]?.delta?.reasoning_content) {
+        reason += parsedData.choices[0].delta.reasoning_content;
+      }
+
       if (parsedData.usage) {
         usage = {
           modelName: getTextModel().modelName,
@@ -125,17 +150,22 @@ const parseStreamData = (chunk: string, lastChunk: string = '') => {
         };
       }
     } catch (error) {
-      console.info('parse error:', error, cleanedData);
-      return { error: cleanedData };
+      if (lastChunk.length > 0) {
+        return { error: error + cleanedData };
+      }
+      if (reason || content) {
+        return { reason, content, dataChunk, usage };
+      }
     }
   }
-  return { content, usage };
+  return { reason, content, usage };
 };

 type ChatResponse = {
   choices: Array<{
     delta: {
       content: string;
+      reasoning_content: string;
     };
   }>;
   usage?: {
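Note: this is where DeepSeek-R1 support lands. R1 streams its chain of thought as a separate `reasoning_content` delta alongside the usual `content`, and the code above folds it into the transcript as a Markdown blockquote: the first reasoning delta opens the quote with `> `, paragraph breaks inside the reasoning become quoted breaks, and the first answer delta inserts a blank line that ends the quote (the new `blockquote` style in CustomMessageComponent.tsx below renders it). The carried-over `lastChunk` re-joins SSE events that were split across network reads. The formatting logic in isolation, as a sketch:

```ts
// Sketch of the reasoning-to-blockquote folding done in the diff above.
let message = '';
let isFirstReason = true;
let isFirstContent = true;

function append(reasonDelta: string, contentDelta: string): string {
  if (reasonDelta) {
    if (isFirstReason) {
      message += '> '; // open the Markdown blockquote once
      isFirstReason = false;
    }
    // Keep paragraph breaks inside the quote: "\n\n" -> "\n>\n>".
    message += reasonDelta.replace(/\n\n/g, '\n>\n>');
  }
  if (contentDelta) {
    if (!isFirstReason && isFirstContent) {
      message += '\n\n'; // a blank line ends the quote before the answer
      isFirstContent = false;
    }
    message += contentDelta;
  }
  return message;
}
```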
diff --git a/react-native/src/assets/ollama-white.png b/react-native/src/assets/ollama-white.png
new file mode 100644
index 0000000000000000000000000000000000000000..7f1140c3c372b1240a67986f71f8f250856cd92d
GIT binary patch
literal 2570
[base85-encoded PNG data (2570 bytes, the white Ollama icon) omitted]
diff --git a/react-native/src/chat/component/CustomMessageComponent.tsx b/react-native/src/chat/component/CustomMessageComponent.tsx
--- a/react-native/src/chat/component/CustomMessageComponent.tsx
+++ b/react-native/src/chat/component/CustomMessageComponent.tsx
@@ ... @@ const CustomMessageComponent: React.FC = ({
   const isDeepSeek = userName.includes('DeepSeek');
   const isOpenAI = userName.includes('GPT');
+  const isOllama = userName.includes(':');
   const modelIcon = isDeepSeek
     ? require('../../assets/deepseek.png')
     : isOpenAI
     ? require('../../assets/openai.png')
+    : isOllama
+    ? require('../../assets/ollama-white.png')
     : require('../../assets/bedrock.png');
   const imgSource =
@@ -253,6 +256,7 @@ const customMarkedStyles: MarkedStyles = {
   h2: { fontSize: 24 },
   h3: { fontSize: 20 },
   h4: { fontSize: 18 },
+  blockquote: { marginVertical: 8 },
 };

 export default CustomMessageComponent;
diff --git a/react-native/src/chat/component/CustomSendComponent.tsx b/react-native/src/chat/component/CustomSendComponent.tsx
index 4bcd669..1c454c6 100644
--- a/react-native/src/chat/component/CustomSendComponent.tsx
+++ b/react-native/src/chat/component/CustomSendComponent.tsx
@@ -87,7 +87,8 @@ const isMultiModalModel = (): boolean => {
   return (
     textModelId.includes('claude-3') ||
     textModelId.includes('nova-pro') ||
-    textModelId.includes('nova-lite')
+    textModelId.includes('nova-lite') ||
+    textModelId.startsWith('ollama')
   );
 };
diff --git a/react-native/src/chat/component/EmptyChatComponent.tsx b/react-native/src/chat/component/EmptyChatComponent.tsx
index b2a846a..b6a6b00 100644
--- a/react-native/src/chat/component/EmptyChatComponent.tsx
+++ b/react-native/src/chat/component/EmptyChatComponent.tsx
@@ -25,10 +25,13 @@ export const EmptyChatComponent = ({
   const navigation = useNavigation();
   const isDeepSeek = getTextModel().modelId.includes('deepseek');
   const isOpenAI = getTextModel().modelId.includes('gpt');
+  const isOllama = getTextModel().modelId.startsWith('ollama-');
   const modelIcon = isDeepSeek
     ? require('../../assets/deepseek.png')
     : isOpenAI
     ? require('../../assets/openai.png')
+    : isOllama
+    ? require('../../assets/ollama-white.png')
     : require('../../assets/bedrock.png');
   const source =
     chatMode === ChatMode.Text ? modelIcon : require('../../assets/image.png');
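Note: two details in these component hunks rely on Ollama naming conventions. The icon check `userName.includes(':')` works because Ollama model names carry a tag after a colon (for example `llama3.2:latest`), and `isMultiModalModel` now returns true for every `ollama*` model so the app offers image attachments; `getOllamaMessages` in ollama-api.ts above then sends those images as base64 strings in a per-message `images` array, which is the format Ollama's chat API expects for vision-capable models. An illustrative request body (the model name and base64 payload are placeholders):

```ts
// Illustrative Ollama /api/chat body with an attached image.
const body = {
  model: 'llava:latest', // placeholder vision-capable model
  messages: [
    { role: 'system', content: 'You are a helpful assistant.' },
    {
      role: 'user',
      content: 'What is in this picture?',
      images: ['iVBORw0KGgoAAAANSUhEUg...'], // base64-encoded image bytes
    },
  ],
};
```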
diff --git a/react-native/src/settings/CustomTextInput.tsx b/react-native/src/settings/CustomTextInput.tsx
new file mode 100644
index 0000000..2a688a5
--- /dev/null
+++ b/react-native/src/settings/CustomTextInput.tsx
@@ -0,0 +1,59 @@
+import React from 'react';
+import { StyleSheet, Text, TextInput, View } from 'react-native';
+
+interface CustomTextInputProps {
+  label: string;
+  value: string;
+  onChangeText: (text: string) => void;
+  placeholder: string;
+  secureTextEntry?: boolean;
+}
+
+const CustomTextInput: React.FC<CustomTextInputProps> = ({
+  label,
+  value,
+  onChangeText,
+  placeholder,
+  secureTextEntry = false,
+}) => {
+  return (
+    <View style={styles.container}>
+      <Text style={styles.label}>{label}</Text>
+      <TextInput
+        style={styles.input}
+        value={value}
+        onChangeText={onChangeText}
+        placeholder={placeholder}
+        secureTextEntry={secureTextEntry}
+      />
+    </View>
+  );
+};
+
+const styles = StyleSheet.create({
+  container: {
+    marginBottom: 12,
+    marginTop: 8,
+  },
+  label: {
+    position: 'absolute',
+    backgroundColor: 'white',
+    color: 'black',
+    left: 8,
+    top: -8,
+    zIndex: 999,
+    paddingHorizontal: 4,
+    fontSize: 12,
+    fontWeight: '500',
+  },
+  input: {
+    height: 44,
+    borderColor: 'gray',
+    borderWidth: 1,
+    borderRadius: 6,
+    paddingHorizontal: 10,
+    color: 'black',
+  },
+});
+
+export default CustomTextInput;
diff --git a/react-native/src/settings/DropdownComponent.tsx b/react-native/src/settings/DropdownComponent.tsx
index acd4345..fbe90d3 100644
--- a/react-native/src/settings/DropdownComponent.tsx
+++ b/react-native/src/settings/DropdownComponent.tsx
@@ -84,7 +84,7 @@ const styles = StyleSheet.create({
     zIndex: 999,
     paddingHorizontal: 4,
     fontSize: 12,
-    fontWeight: '600',
+    fontWeight: '500',
   },
   dropdown: {
     height: 44,
diff --git a/react-native/src/settings/ModelPrice.ts b/react-native/src/settings/ModelPrice.ts
index 6e111d0..054064b 100644
--- a/react-native/src/settings/ModelPrice.ts
+++ b/react-native/src/settings/ModelPrice.ts
@@ -81,10 +81,14 @@ function getImagePrice(
 export const ModelPrice: ModelPriceType = {
   textModelPrices: {
-    'DeepSeek v3': {
+    'DeepSeek-V3': {
       inputTokenPrice: 0.00014,
       outputTokenPrice: 0.00028,
     },
+    'DeepSeek-R1': {
+      inputTokenPrice: 0.00055,
+      outputTokenPrice: 0.00219,
+    },
     'GPT-4o': {
       inputTokenPrice: 0.0025,
       outputTokenPrice: 0.01,
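Note: the renamed `DeepSeek-V3` entry and the new `DeepSeek-R1` entry fit the table's apparent unit of USD per 1,000 tokens (0.00055 / 0.00219 per 1K corresponds to $0.55 / $2.19 per million tokens); that unit is inferred from the magnitudes, not stated in the diff. Under that assumption, the cost accounting reduces to:

```ts
// Sketch: cost estimate for DeepSeek-R1, assuming prices are USD per 1K tokens.
const r1 = { inputTokenPrice: 0.00055, outputTokenPrice: 0.00219 };

function estimateCost(inputTokens: number, outputTokens: number): number {
  return (
    (inputTokens / 1000) * r1.inputTokenPrice +
    (outputTokens / 1000) * r1.outputTokenPrice
  );
}

// Example: 1,200 input + 800 output tokens ≈ $0.00241.
console.log(estimateCost(1200, 800).toFixed(5));
```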
diff --git a/react-native/src/settings/SettingsScreen.tsx b/react-native/src/settings/SettingsScreen.tsx
index 0dd3013..fbdff69 100644
--- a/react-native/src/settings/SettingsScreen.tsx
+++ b/react-native/src/settings/SettingsScreen.tsx
@@ -1,5 +1,5 @@
 import * as React from 'react';
-import { useEffect, useState } from 'react';
+import { useEffect, useRef, useState } from 'react';
 import {
   Image,
   Linking,
@@ -9,7 +9,6 @@ import {
   StyleSheet,
   Switch,
   Text,
-  TextInput,
   TouchableOpacity,
   View,
 } from 'react-native';
@@ -26,6 +25,7 @@ import {
   getImageModel,
   getImageSize,
   getModelUsage,
+  getOllamaApiUrl,
   getOpenAIApiKey,
   getRegion,
   getTextModel,
@@ -35,6 +35,7 @@ import {
   saveImageModel,
   saveImageSize,
   saveKeys,
+  saveOllamaApiURL,
   saveOpenAIApiKey,
   saveRegion,
   saveTextModel,
@@ -42,14 +43,21 @@ import {
 } from '../storage/StorageUtils.ts';
 import { CustomHeaderRightButton } from '../chat/component/CustomHeaderRightButton.tsx';
 import { RouteParamList } from '../types/RouteTypes.ts';
 import { requestAllModels, requestUpgradeInfo } from '../api/bedrock-api.ts';
-import { DropdownItem, Model, UpgradeInfo } from '../types/Chat.ts';
+import { AllModel, DropdownItem, Model, UpgradeInfo } from '../types/Chat.ts';
 import packageJson from '../../package.json';
 import { isMac } from '../App.tsx';
 import CustomDropdown from './DropdownComponent.tsx';
 import { getTotalCost } from './ModelPrice.ts';
-import { getAllRegions } from '../storage/Constants.ts';
+import {
+  DeepSeekModels,
+  getAllRegions,
+  getDefaultTextModels,
+  GPTModels,
+} from '../storage/Constants.ts';
 import { showInfo } from '../chat/util/ToastUtils.ts';
+import CustomTextInput from './CustomTextInput.tsx';
+import { requestAllOllamaModels } from '../api/ollama-api.ts';

 const initUpgradeInfo: UpgradeInfo = {
   needUpgrade: false,
@@ -62,6 +70,7 @@ const GITHUB_LINK = 'https://github.com/aws-samples/swift-chat';
 function SettingsScreen(): React.JSX.Element {
   const [apiUrl, setApiUrl] = useState(getApiUrl);
   const [apiKey, setApiKey] = useState(getApiKey);
+  const [ollamaApiUrl, setOllamaApiUrl] = useState(getOllamaApiUrl);
   const [deepSeekApiKey, setDeepSeekApiKey] = useState(getDeepSeekApiKey);
   const [openAIApiKey, setOpenAIApiKey] = useState(getOpenAIApiKey);
   const [region, setRegion] = useState(getRegion);
@@ -74,7 +83,7 @@ function SettingsScreen(): React.JSX.Element {
   const [selectedImageModel, setSelectedImageModel] = useState('');
   const [upgradeInfo, setUpgradeInfo] = useState(initUpgradeInfo);
   const [cost, setCost] = useState('0.00');
-
+  const controllerRef = useRef<AbortController | null>(null);
   useEffect(() => {
     return navigation.addListener('focus', () => {
       setCost(getTotalCost(getModelUsage()).toString());
@@ -113,6 +122,13 @@ function SettingsScreen(): React.JSX.Element {
     }
   }, [apiUrl, apiKey]);

+  useEffect(() => {
+    saveOllamaApiURL(ollamaApiUrl);
+    if (ollamaApiUrl.length > 0) {
+      fetchAndSetModelNames().then();
+    }
+  }, [ollamaApiUrl]);
+
   useEffect(() => {
     saveDeepSeekApiKey(deepSeekApiKey);
   }, [deepSeekApiKey]);
@@ -122,6 +138,9 @@ function SettingsScreen(): React.JSX.Element {
   }, [openAIApiKey]);

   const fetchAndSetModelNames = async () => {
+    controllerRef.current = new AbortController();
+    let ollamaModels: Model[] = [];
+    ollamaModels = await requestAllOllamaModels();
     const response = await requestAllModels();
     if (response.imageModel.length > 0) {
       setImageModels(response.imageModel);
@@ -137,38 +156,31 @@ function SettingsScreen(): React.JSX.Element {
         saveImageModel(response.imageModel[0]);
       }
     }
-    if (response.textModel.length > 0) {
+    if (response.textModel.length === 0) {
+      response.textModel = [...getDefaultTextModels(), ...ollamaModels];
+    } else {
       response.textModel = [
-        {
-          modelName: 'DeepSeek v3',
-          modelId: 'deepseek-chat',
-        },
-        {
-          modelName: 'GPT-4o',
-          modelId: 'gpt-4o',
-        },
-        {
-          modelName: 'GPT-4o mini',
-          modelId: 'gpt-4o-mini',
-        },
         ...response.textModel,
+        ...GPTModels,
+        ...DeepSeekModels,
+        ...ollamaModels,
       ];
-      setTextModels(response.textModel);
-      const textModel = getTextModel();
-      const targetModels = response.textModel.filter(
-        model => model.modelName === textModel.modelName
+    }
+    setTextModels(response.textModel);
+    const textModel = getTextModel();
+    const targetModels = response.textModel.filter(
+      model => model.modelName === textModel.modelName
+    );
+    if (targetModels && targetModels.length === 1) {
+      setSelectedTextModel(targetModels[0].modelId);
+      saveTextModel(targetModels[0]);
+    } else {
+      const defaultMissMatchModel = response.textModel.filter(
+        model => model.modelName === 'Claude 3 Sonnet'
       );
-      if (targetModels && targetModels.length === 1) {
-        setSelectedTextModel(targetModels[0].modelId);
-        saveTextModel(targetModels[0]);
-      } else {
-        const defaultMissMatchModel = response.textModel.filter(
-          model => model.modelName === 'Claude 3 Sonnet'
-        );
-        if (defaultMissMatchModel && defaultMissMatchModel.length === 1) {
-          setSelectedTextModel(defaultMissMatchModel[0].modelId);
-          saveTextModel(defaultMissMatchModel[0]);
-        }
+      if (defaultMissMatchModel && defaultMissMatchModel.length === 1) {
+        setSelectedTextModel(defaultMissMatchModel[0].modelId);
+        saveTextModel(defaultMissMatchModel[0]);
       }
     }
     if (response.imageModel.length > 0 || response.textModel.length > 0) {
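Note: `fetchAndSetModelNames` now builds one combined list: if Bedrock is unreachable it falls back to `getDefaultTextModels()` plus whatever the local Ollama server reports, otherwise it appends the static `GPTModels` and `DeepSeekModels` entries and the Ollama models to the Bedrock list. The previously selected model is kept when its name still resolves, with `Claude 3 Sonnet` as the fallback. The selection step, condensed into a sketch (using `find` where the diff filters and checks `length === 1`):

```ts
// Sketch of the model-selection fallback in fetchAndSetModelNames.
type Model = { modelId: string; modelName: string };

function pickTextModel(models: Model[], savedName: string): Model | undefined {
  // Prefer the previously saved model if it is still in the merged list...
  const saved = models.find(m => m.modelName === savedName);
  if (saved) {
    return saved;
  }
  // ...otherwise fall back to the stock default used by the diff.
  return models.find(m => m.modelName === 'Claude 3 Sonnet');
}
```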
@@ -230,37 +242,20 @@ function SettingsScreen(): React.JSX.Element {
   return (
[render hunk garbled in extraction: the standalone "API URL", "API Key", "DeepSeek API Key", and "OpenAI API Key" text inputs are replaced by CustomTextInput fields, grouped under an "Amazon Bedrock" heading and a new "Other Model Provider" section for the Ollama server URL and the DeepSeek and OpenAI keys, above the "Select Model" dropdowns]