diff --git a/core/src/types/index.ts b/core/src/types/index.ts index 5b91fcc8a5..d5b51cfc04 100644 --- a/core/src/types/index.ts +++ b/core/src/types/index.ts @@ -71,9 +71,9 @@ export type ThreadMessage = { object: string; /** Thread id, default is a ulid. **/ thread_id: string; - /** The role of the author of this message. **/ + /** The assistant id of this thread. **/ assistant_id?: string; - // TODO: comment + /** The role of the author of this message. **/ role: ChatCompletionRole; /** The content of this message. **/ content: ThreadContent[]; @@ -125,8 +125,6 @@ export interface Thread { title: string; /** Assistants in this thread. **/ assistants: ThreadAssistantInfo[]; - // if the thread has been init will full assistant info - isFinishInit: boolean; /** The timestamp indicating when this thread was created, represented in ISO 8601 format. **/ created: number; /** The timestamp indicating when this thread was updated, represented in ISO 8601 format. **/ @@ -166,6 +164,7 @@ export type ThreadState = { waitingForResponse: boolean; error?: Error; lastMessage?: string; + isFinishInit?: boolean; }; /** * Represents the inference engine. @@ -291,6 +290,9 @@ export type ModelRuntimeParams = { top_p?: number; stream?: boolean; max_tokens?: number; + stop?: string[]; + frequency_penalty?: number; + presence_penalty?: number; }; /** diff --git a/extensions/assistant-extension/src/index.ts b/extensions/assistant-extension/src/index.ts index 3e1ab4898b..9b48ce6c94 100644 --- a/extensions/assistant-extension/src/index.ts +++ b/extensions/assistant-extension/src/index.ts @@ -22,10 +22,6 @@ export default class JanAssistantExtension implements AssistantExtension { onUnload(): void {} async createAssistant(assistant: Assistant): Promise { - // assuming that assistants/ directory is already created in the onLoad above - - // TODO: check if the directory already exists, then ignore creation for now - const assistantDir = join(JanAssistantExtension._homeDir, assistant.id); await fs.mkdir(assistantDir); @@ -91,7 +87,7 @@ export default class JanAssistantExtension implements AssistantExtension { avatar: "", thread_location: undefined, id: "jan", - object: "assistant", // TODO: maybe we can set default value for this? + object: "assistant", created_at: Date.now(), name: "Jan", description: "A default assistant that can use all downloaded models", diff --git a/extensions/inference-nitro-extension/src/helpers/sse.ts b/extensions/inference-nitro-extension/src/helpers/sse.ts index 978d9e3c56..a7b35f2f06 100644 --- a/extensions/inference-nitro-extension/src/helpers/sse.ts +++ b/extensions/inference-nitro-extension/src/helpers/sse.ts @@ -7,7 +7,6 @@ import { Observable } from "rxjs"; */ export function requestInference( recentMessages: any[], - engine: EngineSettings, model: Model, controller?: AbortController ): Observable { @@ -23,34 +22,41 @@ export function requestInference( headers: { "Content-Type": "application/json", "Access-Control-Allow-Origin": "*", - Accept: "text/event-stream", + Accept: model.parameters.stream + ? "text/event-stream" + : "application/json", }, body: requestBody, signal: controller?.signal, }) .then(async (response) => { - const stream = response.body; - const decoder = new TextDecoder("utf-8"); - const reader = stream?.getReader(); - let content = ""; + if (model.parameters.stream) { + const stream = response.body; + const decoder = new TextDecoder("utf-8"); + const reader = stream?.getReader(); + let content = ""; - while (true && reader) { - const { done, value } = await reader.read(); - if (done) { - break; - } - const text = decoder.decode(value); - const lines = text.trim().split("\n"); - for (const line of lines) { - if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { - const data = JSON.parse(line.replace("data: ", "")); - content += data.choices[0]?.delta?.content ?? ""; - if (content.startsWith("assistant: ")) { - content = content.replace("assistant: ", ""); + while (true && reader) { + const { done, value } = await reader.read(); + if (done) { + break; + } + const text = decoder.decode(value); + const lines = text.trim().split("\n"); + for (const line of lines) { + if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { + const data = JSON.parse(line.replace("data: ", "")); + content += data.choices[0]?.delta?.content ?? ""; + if (content.startsWith("assistant: ")) { + content = content.replace("assistant: ", ""); + } + subscriber.next(content); } - subscriber.next(content); } } + } else { + const data = await response.json(); + subscriber.next(data.choices[0]?.message?.content ?? ""); } subscriber.complete(); }) diff --git a/extensions/inference-nitro-extension/src/index.ts b/extensions/inference-nitro-extension/src/index.ts index 975d94100a..e5f3f43608 100644 --- a/extensions/inference-nitro-extension/src/index.ts +++ b/extensions/inference-nitro-extension/src/index.ts @@ -85,7 +85,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension { */ onUnload(): void {} - private async writeDefaultEngineSettings() { try { const engineFile = join( @@ -164,7 +163,6 @@ export default class JanInferenceNitroExtension implements InferenceExtension { return new Promise(async (resolve, reject) => { requestInference( data.messages ?? [], - JanInferenceNitroExtension._engineSettings, JanInferenceNitroExtension._currentModel ).subscribe({ next: (_content) => {}, @@ -210,8 +208,7 @@ export default class JanInferenceNitroExtension implements InferenceExtension { requestInference( data.messages ?? [], - JanInferenceNitroExtension._engineSettings, - JanInferenceNitroExtension._currentModel, + { ...JanInferenceNitroExtension._currentModel, ...data.model }, instance.controller ).subscribe({ next: (content) => { diff --git a/extensions/inference-openai-extension/src/helpers/sse.ts b/extensions/inference-openai-extension/src/helpers/sse.ts index dbbfc2cb26..14c55779ea 100644 --- a/extensions/inference-openai-extension/src/helpers/sse.ts +++ b/extensions/inference-openai-extension/src/helpers/sse.ts @@ -15,9 +15,9 @@ export function requestInference( controller?: AbortController ): Observable { return new Observable((subscriber) => { - let model_id: string = model.id - if (engine.full_url.includes("openai.azure.com")){ - model_id = engine.full_url.split("/")[5] + let model_id: string = model.id; + if (engine.full_url.includes("openai.azure.com")) { + model_id = engine.full_url.split("/")[5]; } const requestBody = JSON.stringify({ messages: recentMessages, @@ -29,7 +29,9 @@ export function requestInference( method: "POST", headers: { "Content-Type": "application/json", - Accept: "text/event-stream", + Accept: model.parameters.stream + ? "text/event-stream" + : "application/json", "Access-Control-Allow-Origin": "*", Authorization: `Bearer ${engine.api_key}`, "api-key": `${engine.api_key}`, @@ -38,28 +40,33 @@ export function requestInference( signal: controller?.signal, }) .then(async (response) => { - const stream = response.body; - const decoder = new TextDecoder("utf-8"); - const reader = stream?.getReader(); - let content = ""; + if (model.parameters.stream) { + const stream = response.body; + const decoder = new TextDecoder("utf-8"); + const reader = stream?.getReader(); + let content = ""; - while (true && reader) { - const { done, value } = await reader.read(); - if (done) { - break; - } - const text = decoder.decode(value); - const lines = text.trim().split("\n"); - for (const line of lines) { - if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { - const data = JSON.parse(line.replace("data: ", "")); - content += data.choices[0]?.delta?.content ?? ""; - if (content.startsWith("assistant: ")) { - content = content.replace("assistant: ", ""); + while (true && reader) { + const { done, value } = await reader.read(); + if (done) { + break; + } + const text = decoder.decode(value); + const lines = text.trim().split("\n"); + for (const line of lines) { + if (line.startsWith("data: ") && !line.includes("data: [DONE]")) { + const data = JSON.parse(line.replace("data: ", "")); + content += data.choices[0]?.delta?.content ?? ""; + if (content.startsWith("assistant: ")) { + content = content.replace("assistant: ", ""); + } + subscriber.next(content); } - subscriber.next(content); } } + } else { + const data = await response.json(); + subscriber.next(data.choices[0]?.message?.content ?? ""); } subscriber.complete(); }) diff --git a/uikit/package.json b/uikit/package.json index a96b5d37e3..43e73dcf22 100644 --- a/uikit/package.json +++ b/uikit/package.json @@ -25,6 +25,7 @@ "@radix-ui/react-progress": "^1.0.3", "@radix-ui/react-scroll-area": "^1.0.5", "@radix-ui/react-select": "^2.0.0", + "@radix-ui/react-slider": "^1.1.2", "@radix-ui/react-slot": "^1.0.2", "@radix-ui/react-switch": "^1.0.3", "@radix-ui/react-toast": "^1.1.5", diff --git a/uikit/src/index.ts b/uikit/src/index.ts index 067752de0b..3d5eaa82a9 100644 --- a/uikit/src/index.ts +++ b/uikit/src/index.ts @@ -11,3 +11,4 @@ export * from './modal' export * from './command' export * from './textarea' export * from './select' +export * from './slider' diff --git a/uikit/src/input/styles.scss b/uikit/src/input/styles.scss index 76e6c0408a..ba4d81a035 100644 --- a/uikit/src/input/styles.scss +++ b/uikit/src/input/styles.scss @@ -1,5 +1,5 @@ .input { - @apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-md border bg-transparent px-3 py-1 transition-colors; + @apply border-border placeholder:text-muted-foreground flex h-9 w-full rounded-lg border bg-transparent px-3 py-1 transition-colors; @apply disabled:cursor-not-allowed disabled:opacity-50; @apply focus-visible:ring-secondary focus-visible:outline-none focus-visible:ring-1; @apply file:border-0 file:bg-transparent file:font-medium; diff --git a/uikit/src/main.scss b/uikit/src/main.scss index 1eca363b43..546f22811d 100644 --- a/uikit/src/main.scss +++ b/uikit/src/main.scss @@ -15,6 +15,7 @@ @import './command/styles.scss'; @import './textarea/styles.scss'; @import './select/styles.scss'; +@import './slider/styles.scss'; .animate-spin { animation: spin 1s linear infinite; diff --git a/uikit/src/slider/index.tsx b/uikit/src/slider/index.tsx new file mode 100644 index 0000000000..8994d833ed --- /dev/null +++ b/uikit/src/slider/index.tsx @@ -0,0 +1,25 @@ +'use client' + +import * as React from 'react' +import * as SliderPrimitive from '@radix-ui/react-slider' + +import { twMerge } from 'tailwind-merge' + +const Slider = React.forwardRef< + React.ElementRef, + React.ComponentPropsWithoutRef +>(({ className, ...props }, ref) => ( + + + + + + +)) +Slider.displayName = SliderPrimitive.Root.displayName + +export { Slider } diff --git a/uikit/src/slider/styles.scss b/uikit/src/slider/styles.scss new file mode 100644 index 0000000000..889c351517 --- /dev/null +++ b/uikit/src/slider/styles.scss @@ -0,0 +1,15 @@ +.slider { + @apply relative flex w-full touch-none select-none items-center; + + &-track { + @apply relative h-1.5 w-full grow overflow-hidden rounded-full bg-gray-200 dark:bg-gray-800; + } + + &-range { + @apply absolute h-full bg-blue-600; + } + + &-thumb { + @apply border-primary/50 bg-background focus-visible:ring-ring block h-4 w-4 rounded-full border shadow transition-colors focus-visible:outline-none focus-visible:ring-1 disabled:pointer-events-none disabled:opacity-50; + } +} diff --git a/web/containers/CardSidebar/index.tsx b/web/containers/CardSidebar/index.tsx index c34e9132de..8c7fe33142 100644 --- a/web/containers/CardSidebar/index.tsx +++ b/web/containers/CardSidebar/index.tsx @@ -32,7 +32,7 @@ export default function CardSidebar({ return (
+
+ ) +} + +export default SliderRightPanel diff --git a/web/helpers/atoms/ChatMessage.atom.ts b/web/helpers/atoms/ChatMessage.atom.ts index d365304e4b..33309e6fc3 100644 --- a/web/helpers/atoms/ChatMessage.atom.ts +++ b/web/helpers/atoms/ChatMessage.atom.ts @@ -9,7 +9,7 @@ import { atom } from 'jotai' import { getActiveThreadIdAtom, updateThreadStateLastMessageAtom, -} from './Conversation.atom' +} from './Thread.atom' /** * Stores all chat messages for all threads @@ -76,15 +76,18 @@ export const addNewMessageAtom = atom( } ) -export const deleteConversationMessage = atom(null, (get, set, id: string) => { - const newData: Record = { - ...get(chatMessages), +export const deleteChatMessageAtom = atom( + null, + (get, set, threadId: string) => { + const newData: Record = { + ...get(chatMessages), + } + newData[threadId] = [] + set(chatMessages, newData) } - newData[id] = [] - set(chatMessages, newData) -}) +) -export const cleanConversationMessages = atom(null, (get, set, id: string) => { +export const cleanChatMessageAtom = atom(null, (get, set, id: string) => { const newData: Record = { ...get(chatMessages), } diff --git a/web/helpers/atoms/Conversation.atom.ts b/web/helpers/atoms/Thread.atom.ts similarity index 53% rename from web/helpers/atoms/Conversation.atom.ts rename to web/helpers/atoms/Thread.atom.ts index 21a89c26b4..5ec7173b19 100644 --- a/web/helpers/atoms/Conversation.atom.ts +++ b/web/helpers/atoms/Thread.atom.ts @@ -1,8 +1,13 @@ -import { Thread, ThreadContent, ThreadState } from '@janhq/core' +import { + ModelRuntimeParams, + Thread, + ThreadContent, + ThreadState, +} from '@janhq/core' import { atom } from 'jotai' /** - * Stores the current active conversation id. + * Stores the current active thread id. */ const activeThreadIdAtom = atom(undefined) @@ -10,7 +15,7 @@ export const getActiveThreadIdAtom = atom((get) => get(activeThreadIdAtom)) export const setActiveThreadIdAtom = atom( null, - (_get, set, convoId: string | undefined) => set(activeThreadIdAtom, convoId) + (_get, set, threadId: string | undefined) => set(activeThreadIdAtom, threadId) ) export const waitingToSendMessage = atom(undefined) @@ -20,47 +25,48 @@ export const waitingToSendMessage = atom(undefined) */ export const threadStatesAtom = atom>({}) export const activeThreadStateAtom = atom((get) => { - const activeConvoId = get(activeThreadIdAtom) - if (!activeConvoId) { - console.debug('Active convo id is undefined') + const threadId = get(activeThreadIdAtom) + if (!threadId) { + console.debug('Active thread id is undefined') return undefined } - return get(threadStatesAtom)[activeConvoId] + return get(threadStatesAtom)[threadId] }) -export const updateThreadWaitingForResponseAtom = atom( +export const deleteThreadStateAtom = atom( null, - (get, set, conversationId: string, waitingForResponse: boolean) => { + (get, set, threadId: string) => { const currentState = { ...get(threadStatesAtom) } - currentState[conversationId] = { - ...currentState[conversationId], - waitingForResponse, - error: undefined, - } + delete currentState[threadId] set(threadStatesAtom, currentState) } ) -export const updateConversationErrorAtom = atom( + +export const updateThreadInitSuccessAtom = atom( null, - (get, set, conversationId: string, error?: Error) => { + (get, set, threadId: string) => { const currentState = { ...get(threadStatesAtom) } - currentState[conversationId] = { - ...currentState[conversationId], - error, + currentState[threadId] = { + ...currentState[threadId], + isFinishInit: true, } set(threadStatesAtom, currentState) } ) -export const updateConversationHasMoreAtom = atom( + +export const updateThreadWaitingForResponseAtom = atom( null, - (get, set, conversationId: string, hasMore: boolean) => { + (get, set, threadId: string, waitingForResponse: boolean) => { const currentState = { ...get(threadStatesAtom) } - currentState[conversationId] = { ...currentState[conversationId], hasMore } + currentState[threadId] = { + ...currentState[threadId], + waitingForResponse, + error: undefined, + } set(threadStatesAtom, currentState) } ) - export const updateThreadStateLastMessageAtom = atom( null, (get, set, threadId: string, lastContent?: ThreadContent[]) => { @@ -100,3 +106,42 @@ export const threadsAtom = atom([]) export const activeThreadAtom = atom((get) => get(threadsAtom).find((c) => c.id === get(getActiveThreadIdAtom)) ) + +/** + * Store model params at thread level settings + */ +export const threadModelRuntimeParamsAtom = atom< + Record +>({}) + +export const getActiveThreadModelRuntimeParamsAtom = atom< + ModelRuntimeParams | undefined +>((get) => { + const threadId = get(activeThreadIdAtom) + if (!threadId) { + console.debug('Active thread id is undefined') + return undefined + } + + return get(threadModelRuntimeParamsAtom)[threadId] +}) + +export const getThreadModelRuntimeParamsAtom = atom( + (get, threadId: string) => get(threadModelRuntimeParamsAtom)[threadId] +) + +export const setThreadModelRuntimeParamsAtom = atom( + null, + (get, set, threadId: string, params: ModelRuntimeParams) => { + const currentState = { ...get(threadModelRuntimeParamsAtom) } + currentState[threadId] = params + console.debug( + `Update model params for thread ${threadId}, ${JSON.stringify( + params, + null, + 2 + )}` + ) + set(threadModelRuntimeParamsAtom, currentState) + } +) diff --git a/web/hooks/useActiveModel.ts b/web/hooks/useActiveModel.ts index 699b162790..15084278ca 100644 --- a/web/hooks/useActiveModel.ts +++ b/web/hooks/useActiveModel.ts @@ -1,8 +1,5 @@ /* eslint-disable @typescript-eslint/no-explicit-any */ -import { - EventName, - events, -} from '@janhq/core' +import { EventName, events } from '@janhq/core' import { Model, ModelSettingParams } from '@janhq/core' import { atom, useAtom } from 'jotai' diff --git a/web/hooks/useCreateNewThread.ts b/web/hooks/useCreateNewThread.ts index e2f2aa35dc..ff0b4d0494 100644 --- a/web/hooks/useCreateNewThread.ts +++ b/web/hooks/useCreateNewThread.ts @@ -6,9 +6,9 @@ import { ThreadAssistantInfo, ThreadState, } from '@janhq/core' -import { atom, useAtom, useAtomValue, useSetAtom } from 'jotai' +import { atom, useAtomValue, useSetAtom } from 'jotai' -import { generateThreadId } from '@/utils/conversation' +import { generateThreadId } from '@/utils/thread' import { extensionManager } from '@/extension' import { @@ -16,7 +16,8 @@ import { setActiveThreadIdAtom, threadStatesAtom, updateThreadAtom, -} from '@/helpers/atoms/Conversation.atom' + setThreadModelRuntimeParamsAtom, +} from '@/helpers/atoms/Thread.atom' const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { // create thread state for this new thread @@ -25,6 +26,8 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { const threadState: ThreadState = { hasMore: false, waitingForResponse: false, + lastMessage: undefined, + isFinishInit: false, } currentState[newThread.id] = threadState set(threadStatesAtom, currentState) @@ -35,15 +38,26 @@ const createNewThreadAtom = atom(null, (get, set, newThread: Thread) => { }) export const useCreateNewThread = () => { + const threadStates = useAtomValue(threadStatesAtom) const createNewThread = useSetAtom(createNewThreadAtom) const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) - const [threadStates, setThreadStates] = useAtom(threadStatesAtom) - const threads = useAtomValue(threadsAtom) const updateThread = useSetAtom(updateThreadAtom) + const setThreadModelRuntimeParams = useSetAtom( + setThreadModelRuntimeParamsAtom + ) const requestCreateNewThread = async (assistant: Assistant) => { - const unfinishedThreads = threads.filter((t) => t.isFinishInit === false) - if (unfinishedThreads.length > 0) { + // loop through threads state and filter if there's any thread that is not finish init + let hasUnfinishedInitThread = false + for (const key in threadStates) { + const isFinishInit = threadStates[key].isFinishInit ?? true + if (!isFinishInit) { + hasUnfinishedInitThread = true + break + } + } + + if (hasUnfinishedInitThread) { return } @@ -53,18 +67,10 @@ export const useCreateNewThread = () => { assistant_name: assistant.name, model: { id: '*', - settings: { - ctx_len: 0, - ngl: 0, - embedding: false, - n_parallel: 0, - }, + settings: {}, parameters: { - temperature: 0, - token_limit: 0, - top_k: 0, - top_p: 0, - stream: false, + stream: true, + max_tokens: 1024, }, engine: undefined, }, @@ -78,29 +84,20 @@ export const useCreateNewThread = () => { assistants: [assistantInfo], created: createdAt, updated: createdAt, - isFinishInit: false, } - // TODO: move isFinishInit here - const threadState: ThreadState = { - hasMore: false, - waitingForResponse: false, - lastMessage: undefined, - } - setThreadStates({ ...threadStates, [threadId]: threadState }) + setThreadModelRuntimeParams(thread.id, assistantInfo.model.parameters) + // add the new thread on top of the thread list to the state createNewThread(thread) setActiveThreadId(thread.id) } function updateThreadMetadata(thread: Thread) { - const updatedThread: Thread = { - ...thread, - } - updateThread(updatedThread) + updateThread(thread) extensionManager .get(ExtensionType.Conversational) - ?.saveThread(updatedThread) + ?.saveThread(thread) } return { diff --git a/web/hooks/useDeleteConversation.ts b/web/hooks/useDeleteThread.ts similarity index 62% rename from web/hooks/useDeleteConversation.ts rename to web/hooks/useDeleteThread.ts index b02796b104..8822b6aa86 100644 --- a/web/hooks/useDeleteConversation.ts +++ b/web/hooks/useDeleteThread.ts @@ -11,14 +11,15 @@ import { useActiveModel } from './useActiveModel' import { extensionManager } from '@/extension/ExtensionManager' import { - cleanConversationMessages, - deleteConversationMessage, + cleanChatMessageAtom as cleanChatMessagesAtom, + deleteChatMessageAtom as deleteChatMessagesAtom, getCurrentChatMessagesAtom, } from '@/helpers/atoms/ChatMessage.atom' import { threadsAtom, setActiveThreadIdAtom, -} from '@/helpers/atoms/Conversation.atom' + deleteThreadStateAtom, +} from '@/helpers/atoms/Thread.atom' export default function useDeleteThread() { const { activeModel } = useActiveModel() @@ -26,45 +27,51 @@ export default function useDeleteThread() { const setCurrentPrompt = useSetAtom(currentPromptAtom) const messages = useAtomValue(getCurrentChatMessagesAtom) - const setActiveConvoId = useSetAtom(setActiveThreadIdAtom) - const deleteMessages = useSetAtom(deleteConversationMessage) - const cleanMessages = useSetAtom(cleanConversationMessages) + const setActiveThreadId = useSetAtom(setActiveThreadIdAtom) + const deleteMessages = useSetAtom(deleteChatMessagesAtom) + const cleanMessages = useSetAtom(cleanChatMessagesAtom) + const deleteThreadState = useSetAtom(deleteThreadStateAtom) + + const cleanThread = async (threadId: string) => { + if (threadId) { + const thread = threads.filter((c) => c.id === threadId)[0] + cleanMessages(threadId) - const cleanThread = async (activeThreadId: string) => { - if (activeThreadId) { - const thread = threads.filter((c) => c.id === activeThreadId)[0] - cleanMessages(activeThreadId) if (thread) await extensionManager .get(ExtensionType.Conversational) ?.writeMessages( - activeThreadId, + threadId, messages.filter((msg) => msg.role === ChatCompletionRole.System) ) } } - const deleteThread = async (activeThreadId: string) => { - if (!activeThreadId) { + const deleteThread = async (threadId: string) => { + if (!threadId) { alert('No active thread') return } try { await extensionManager .get(ExtensionType.Conversational) - ?.deleteThread(activeThreadId) - const availableThreads = threads.filter((c) => c.id !== activeThreadId) + ?.deleteThread(threadId) + const availableThreads = threads.filter((c) => c.id !== threadId) setThreads(availableThreads) - deleteMessages(activeThreadId) + + // delete the thread state + deleteThreadState(threadId) + + deleteMessages(threadId) setCurrentPrompt('') toaster({ title: 'Thread successfully deleted.', description: `Thread with ${activeModel?.name} has been successfully deleted.`, }) if (availableThreads.length > 0) { - setActiveConvoId(availableThreads[0].id) + setActiveThreadId(availableThreads[0].id) } else { - setActiveConvoId(undefined) + setActiveThreadId(undefined) } } catch (err) { console.error(err) diff --git a/web/hooks/useGetAllThreads.ts b/web/hooks/useGetAllThreads.ts index 488e64f646..8674346175 100644 --- a/web/hooks/useGetAllThreads.ts +++ b/web/hooks/useGetAllThreads.ts @@ -1,35 +1,50 @@ -import { ExtensionType, ThreadState } from '@janhq/core' +import { ExtensionType, ModelRuntimeParams, ThreadState } from '@janhq/core' import { ConversationalExtension } from '@janhq/core' import { useSetAtom } from 'jotai' import { extensionManager } from '@/extension/ExtensionManager' import { + threadModelRuntimeParamsAtom, threadStatesAtom, threadsAtom, -} from '@/helpers/atoms/Conversation.atom' +} from '@/helpers/atoms/Thread.atom' const useGetAllThreads = () => { - const setConversationStates = useSetAtom(threadStatesAtom) - const setConversations = useSetAtom(threadsAtom) + const setThreadStates = useSetAtom(threadStatesAtom) + const setThreads = useSetAtom(threadsAtom) + const setThreadModelRuntimeParams = useSetAtom(threadModelRuntimeParamsAtom) const getAllThreads = async () => { try { - const threads = await extensionManager - .get(ExtensionType.Conversational) - ?.getThreads() + const threads = + (await extensionManager + .get(ExtensionType.Conversational) + ?.getThreads()) ?? [] + const threadStates: Record = {} - threads?.forEach((thread) => { + const threadModelParams: Record = {} + + threads.forEach((thread) => { if (thread.id != null) { const lastMessage = (thread.metadata?.lastMessage as string) ?? '' + threadStates[thread.id] = { hasMore: true, waitingForResponse: false, lastMessage, + isFinishInit: true, } + + // model params + const modelParams = thread.assistants?.[0]?.model?.parameters + threadModelParams[thread.id] = modelParams } }) - setConversationStates(threadStates) - setConversations(threads ?? []) + + // updating app states + setThreadStates(threadStates) + setThreads(threads) + setThreadModelRuntimeParams(threadModelParams) } catch (error) { console.error(error) } diff --git a/web/hooks/useGetConfiguredModels.ts b/web/hooks/useGetConfiguredModels.ts index 7c4d94edd4..d79778a00c 100644 --- a/web/hooks/useGetConfiguredModels.ts +++ b/web/hooks/useGetConfiguredModels.ts @@ -19,9 +19,6 @@ export function useGetConfiguredModels() { async function fetchModels() { setLoading(true) const models = await getConfiguredModels() - if (process.env.NODE_ENV === 'development') { - // models = [dummyModel, ...models] // TODO: NamH add back dummy model later - } setLoading(false) setModels(models) } diff --git a/web/hooks/useSendChatMessage.ts b/web/hooks/useSendChatMessage.ts index 970aedbecd..8913104d32 100644 --- a/web/hooks/useSendChatMessage.ts +++ b/web/hooks/useSendChatMessage.ts @@ -12,8 +12,9 @@ import { ThreadMessage, events, Model, + ConversationalExtension, + ModelRuntimeParams, } from '@janhq/core' -import { ConversationalExtension } from '@janhq/core' import { useAtom, useAtomValue, useSetAtom } from 'jotai' import { ulid } from 'ulid' @@ -32,9 +33,12 @@ import { } from '@/helpers/atoms/ChatMessage.atom' import { activeThreadAtom, + getActiveThreadModelRuntimeParamsAtom, + threadStatesAtom, updateThreadAtom, + updateThreadInitSuccessAtom, updateThreadWaitingForResponseAtom, -} from '@/helpers/atoms/Conversation.atom' +} from '@/helpers/atoms/Thread.atom' export default function useSendChatMessage() { const activeThread = useAtomValue(activeThreadAtom) @@ -50,6 +54,10 @@ export default function useSendChatMessage() { const [queuedMessage, setQueuedMessage] = useState(false) const modelRef = useRef() + const threadStates = useAtomValue(threadStatesAtom) + const updateThreadInitSuccess = useSetAtom(updateThreadInitSuccessAtom) + const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom) + useEffect(() => { modelRef.current = activeModel }, [activeModel]) @@ -109,7 +117,7 @@ export default function useSendChatMessage() { return new Promise((resolve) => { setTimeout(async () => { if (modelRef.current?.id !== modelId) { - console.log('waiting for model to start') + console.debug('waiting for model to start') await WaitForModelStarting(modelId) resolve() } else { @@ -127,8 +135,10 @@ export default function useSendChatMessage() { console.error('No active thread') return } + const activeThreadState = threadStates[activeThread.id] - if (!activeThread.isFinishInit) { + // if the thread is not initialized, we need to initialize it first + if (!activeThreadState.isFinishInit) { if (!selectedModel) { toaster({ title: 'Please select a model' }) return @@ -136,9 +146,14 @@ export default function useSendChatMessage() { const assistantId = activeThread.assistants[0].assistant_id ?? '' const assistantName = activeThread.assistants[0].assistant_name ?? '' const instructions = activeThread.assistants[0].instructions ?? '' + + const modelParams: ModelRuntimeParams = { + ...selectedModel.parameters, + ...activeModelParams, + } + const updatedThread: Thread = { ...activeThread, - isFinishInit: true, assistants: [ { assistant_id: assistantId, @@ -147,13 +162,13 @@ export default function useSendChatMessage() { model: { id: selectedModel.id, settings: selectedModel.settings, - parameters: selectedModel.parameters, + parameters: modelParams, engine: selectedModel.engine, }, }, ], } - + updateThreadInitSuccess(activeThread.id) updateThread(updatedThread) extensionManager @@ -191,11 +206,16 @@ export default function useSendChatMessage() { ]) ) const msgId = ulid() + + const modelRequest = selectedModel ?? activeThread.assistants[0].model const messageRequest: MessageRequest = { id: msgId, threadId: activeThread.id, messages, - model: selectedModel ?? activeThread.assistants[0].model, + model: { + ...modelRequest, + ...(activeModelParams ? { parameters: activeModelParams } : {}), + }, } const timestamp = Date.now() const threadMessage: ThreadMessage = { diff --git a/web/hooks/useSetActiveThread.ts b/web/hooks/useSetActiveThread.ts index a0d5841165..0705901c3f 100644 --- a/web/hooks/useSetActiveThread.ts +++ b/web/hooks/useSetActiveThread.ts @@ -9,7 +9,7 @@ import { setConvoMessagesAtom } from '@/helpers/atoms/ChatMessage.atom' import { getActiveThreadIdAtom, setActiveThreadIdAtom, -} from '@/helpers/atoms/Conversation.atom' +} from '@/helpers/atoms/Thread.atom' export default function useSetActiveThread() { const activeThreadId = useAtomValue(getActiveThreadIdAtom) diff --git a/web/hooks/useUpdateModelParameters.ts b/web/hooks/useUpdateModelParameters.ts new file mode 100644 index 0000000000..d6bb2d0db7 --- /dev/null +++ b/web/hooks/useUpdateModelParameters.ts @@ -0,0 +1,66 @@ +import { + ConversationalExtension, + ExtensionType, + ModelRuntimeParams, + Thread, +} from '@janhq/core' + +import { useAtomValue, useSetAtom } from 'jotai' + +import { extensionManager } from '@/extension' +import { + activeThreadStateAtom, + setThreadModelRuntimeParamsAtom, + threadsAtom, + updateThreadAtom, +} from '@/helpers/atoms/Thread.atom' + +export default function useUpdateModelParameters() { + const threads = useAtomValue(threadsAtom) + const updateThread = useSetAtom(updateThreadAtom) + const setThreadModelRuntimeParams = useSetAtom( + setThreadModelRuntimeParamsAtom + ) + const activeThreadState = useAtomValue(activeThreadStateAtom) + + const updateModelParameter = async ( + threadId: string, + params: ModelRuntimeParams + ) => { + const thread = threads.find((thread) => thread.id === threadId) + if (!thread) { + console.error(`Thread ${threadId} not found`) + return + } + + if (!activeThreadState) { + console.error('No active thread') + return + } + + // update the state + setThreadModelRuntimeParams(thread.id, params) + + if (!activeThreadState.isFinishInit) { + // if thread is not initialized, we don't need to update thread.json + return + } + + const assistants = thread.assistants.map((assistant) => { + assistant.model.parameters = params + return assistant + }) + + // update thread + const updatedThread: Thread = { + ...thread, + assistants, + } + updateThread(updatedThread) + extensionManager + .get(ExtensionType.Conversational) + ?.saveThread(updatedThread) + } + + return { updateModelParameter } +} diff --git a/web/package.json b/web/package.json index dd7faeb1d5..15d2830b09 100644 --- a/web/package.json +++ b/web/package.json @@ -38,7 +38,6 @@ "sass": "^1.69.4", "tailwind-merge": "^2.0.0", "tailwindcss": "3.3.5", - "typescript": "5.2.2", "ulid": "^2.3.0", "uuid": "^9.0.1", "zod": "^3.22.4" diff --git a/web/screens/Chat/MessageToolbar/index.tsx b/web/screens/Chat/MessageToolbar/index.tsx index fe7cac1f5f..7f8e5ca7eb 100644 --- a/web/screens/Chat/MessageToolbar/index.tsx +++ b/web/screens/Chat/MessageToolbar/index.tsx @@ -21,7 +21,7 @@ import { deleteMessageAtom, getCurrentChatMessagesAtom, } from '@/helpers/atoms/ChatMessage.atom' -import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom' +import { activeThreadAtom } from '@/helpers/atoms/Thread.atom' const MessageToolbar = ({ message }: { message: ThreadMessage }) => { const deleteMessage = useSetAtom(deleteMessageAtom) diff --git a/web/screens/Chat/ModelSetting/index.tsx b/web/screens/Chat/ModelSetting/index.tsx new file mode 100644 index 0000000000..e8c9b24538 --- /dev/null +++ b/web/screens/Chat/ModelSetting/index.tsx @@ -0,0 +1,60 @@ +import { useEffect, useState } from 'react' + +import { useForm } from 'react-hook-form' + +import { ModelRuntimeParams } from '@janhq/core' + +import { useAtomValue } from 'jotai' + +import { presetConfiguration } from './predefinedComponent' +import settingComponentBuilder, { + SettingComponentData, +} from './settingComponentBuilder' + +import { + getActiveThreadIdAtom, + getActiveThreadModelRuntimeParamsAtom, +} from '@/helpers/atoms/Thread.atom' + +export default function ModelSetting() { + const threadId = useAtomValue(getActiveThreadIdAtom) + const activeModelParams = useAtomValue(getActiveThreadModelRuntimeParamsAtom) + const [modelParams, setModelParams] = useState< + ModelRuntimeParams | undefined + >(activeModelParams) + + const { register } = useForm() + + useEffect(() => { + setModelParams(activeModelParams) + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [threadId]) + + if (!modelParams) { + return
This thread has no model parameters
+ } + + const componentData: SettingComponentData[] = [] + Object.keys(modelParams).forEach((key) => { + const componentSetting = presetConfiguration[key] + + if (componentSetting) { + if ('value' in componentSetting.controllerData) { + componentSetting.controllerData.value = Number( + modelParams[key as keyof ModelRuntimeParams] + ) + } else if ('checked' in componentSetting.controllerData) { + componentSetting.controllerData.checked = modelParams[ + key as keyof ModelRuntimeParams + ] as boolean + } + componentData.push(componentSetting) + } + }) + + return ( +
+ {settingComponentBuilder(componentData, register)} +
+ ) +} diff --git a/web/screens/Chat/ModelSetting/predefinedComponent.ts b/web/screens/Chat/ModelSetting/predefinedComponent.ts new file mode 100644 index 0000000000..d8299ae10f --- /dev/null +++ b/web/screens/Chat/ModelSetting/predefinedComponent.ts @@ -0,0 +1,59 @@ +import { SettingComponentData } from './settingComponentBuilder' + +export const presetConfiguration: Record = { + max_tokens: { + name: 'max_tokens', + title: 'Max Tokens', + description: 'Maximum context length the model can handle.', + controllerType: 'slider', + controllerData: { + min: 0, + max: 4096, + step: 128, + value: 2048, + }, + }, + ngl: { + name: 'ngl', + title: 'NGL', + description: 'Number of layers in the neural network.', + controllerType: 'slider', + controllerData: { + min: 1, + max: 100, + step: 1, + value: 100, + }, + }, + embedding: { + name: 'embedding', + title: 'Embedding', + description: 'Indicates if embedding layers are used.', + controllerType: 'checkbox', + controllerData: { + checked: true, + }, + }, + stream: { + name: 'stream', + title: 'Stream', + description: 'Stream', + controllerType: 'checkbox', + controllerData: { + checked: false, + }, + }, + temperature: { + name: 'temperature', + title: 'Temperature', + description: + "Controls randomness in model's responses. Higher values lead to more random responses.", + controllerType: 'slider', + controllerData: { + min: 0, + max: 2, + step: 0.1, + value: 0.7, + }, + }, +} diff --git a/web/screens/Chat/ModelSetting/settingComponentBuilder.tsx b/web/screens/Chat/ModelSetting/settingComponentBuilder.tsx new file mode 100644 index 0000000000..604e707733 --- /dev/null +++ b/web/screens/Chat/ModelSetting/settingComponentBuilder.tsx @@ -0,0 +1,67 @@ +/* eslint-disable no-case-declarations */ +/* eslint-disable @typescript-eslint/no-explicit-any */ + +import Checkbox from '@/containers/Checkbox' +import Slider from '@/containers/Slider' + +export type ControllerType = 'slider' | 'checkbox' + +export type SettingComponentData = { + name: string + title: string + description: string + controllerType: ControllerType + controllerData: SliderData | CheckboxData +} + +export type SliderData = { + min: number + max: number + step: number + value: number +} + +type CheckboxData = { + checked: boolean +} + +const settingComponentBuilder = ( + componentData: SettingComponentData[], + register: any +) => { + const components = componentData.map((data) => { + switch (data.controllerType) { + case 'slider': + const { min, max, step, value } = data.controllerData as SliderData + return ( + + ) + case 'checkbox': + const { checked } = data.controllerData as CheckboxData + return ( + + ) + default: + return null + } + }) + + return
{components}
+} + +export default settingComponentBuilder diff --git a/web/screens/Chat/Sidebar/index.tsx b/web/screens/Chat/Sidebar/index.tsx index cf8c46b489..7c3fc57db4 100644 --- a/web/screens/Chat/Sidebar/index.tsx +++ b/web/screens/Chat/Sidebar/index.tsx @@ -16,7 +16,9 @@ import DropdownListSidebar, { import { useCreateNewThread } from '@/hooks/useCreateNewThread' -import { activeThreadAtom } from '@/helpers/atoms/Conversation.atom' +import ModelSetting from '../ModelSetting' + +import { activeThreadAtom, threadStatesAtom } from '@/helpers/atoms/Thread.atom' export const showRightSideBarAtom = atom(true) @@ -25,10 +27,12 @@ export default function Sidebar() { const activeThread = useAtomValue(activeThreadAtom) const selectedModel = useAtomValue(selectedModelAtom) const { updateThreadMetadata } = useCreateNewThread() + const threadStates = useAtomValue(threadStatesAtom) const onReviewInFinderClick = async (type: string) => { if (!activeThread) return - if (!activeThread.isFinishInit) { + const activeThreadState = threadStates[activeThread.id] + if (!activeThreadState.isFinishInit) { alert('Thread is not started yet') return } @@ -60,7 +64,8 @@ export default function Sidebar() { const onViewJsonClick = async (type: string) => { if (!activeThread) return - if (!activeThread.isFinishInit) { + const activeThreadState = threadStates[activeThread.id] + if (!activeThreadState.isFinishInit) { alert('Thread is not started yet') return } @@ -189,6 +194,9 @@ export default function Sidebar() { >
+
+ +
diff --git a/web/screens/Chat/ThreadList/index.tsx b/web/screens/Chat/ThreadList/index.tsx index 8a4b85d17a..5b5a8d91da 100644 --- a/web/screens/Chat/ThreadList/index.tsx +++ b/web/screens/Chat/ThreadList/index.tsx @@ -12,7 +12,7 @@ import { import { twMerge } from 'tailwind-merge' import { useCreateNewThread } from '@/hooks/useCreateNewThread' -import useDeleteThread from '@/hooks/useDeleteConversation' +import useDeleteThread from '@/hooks/useDeleteThread' import useGetAllThreads from '@/hooks/useGetAllThreads' import useGetAssistants from '@/hooks/useGetAssistants' @@ -25,7 +25,7 @@ import { activeThreadAtom, threadStatesAtom, threadsAtom, -} from '@/helpers/atoms/Conversation.atom' +} from '@/helpers/atoms/Thread.atom' export default function ThreadList() { const threads = useAtomValue(threadsAtom) diff --git a/web/screens/Chat/index.tsx b/web/screens/Chat/index.tsx index 7053ae1a29..741fadbaf8 100644 --- a/web/screens/Chat/index.tsx +++ b/web/screens/Chat/index.tsx @@ -29,9 +29,9 @@ import { activeThreadAtom, getActiveThreadIdAtom, waitingToSendMessage, -} from '@/helpers/atoms/Conversation.atom' +} from '@/helpers/atoms/Thread.atom' -import { activeThreadStateAtom } from '@/helpers/atoms/Conversation.atom' +import { activeThreadStateAtom } from '@/helpers/atoms/Thread.atom' const ChatScreen = () => { const activeThread = useAtomValue(activeThreadAtom) diff --git a/web/utils/conversation.ts b/web/utils/thread.ts similarity index 100% rename from web/utils/conversation.ts rename to web/utils/thread.ts