Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: allow users to add remote models #4534

Merged
merged 3 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/src/browser/extensions/enginesManagement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
EngineReleased,
EngineConfig,
DefaultEngineVariant,
Model,
} from '../../types'
import { BaseExtension, ExtensionTypeEnum } from '../extension'

Expand All @@ -15,7 +16,7 @@
*/
export abstract class EngineManagementExtension extends BaseExtension {
type(): ExtensionTypeEnum | undefined {
return ExtensionTypeEnum.Engine

Check warning on line 19 in core/src/browser/extensions/enginesManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

19 line is not covered with tests
}

/**
Expand Down Expand Up @@ -103,6 +104,11 @@
engineConfig?: EngineConfig
): Promise<{ messages: string }>

/**
* Add a new remote model for a specific engine
*/
abstract addRemoteModel(model: Model): Promise<void>

/**
* @returns A Promise that resolves to an object of remote models list .
*/
Expand Down
7 changes: 5 additions & 2 deletions core/src/types/model/modelEntity.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import { FileMetadata } from '../file'

/**
* Represents the information about a model.
* @stored
Expand Down Expand Up @@ -70,6 +68,11 @@ export type Model = {
*/
id: string

/**
* The model identifier, modern version of id.
*/
mode?: string

/**
* Human-readable name that is used for UI.
*/
Expand Down
30 changes: 30 additions & 0 deletions web/hooks/useEngineManagement.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
EngineConfig,
events,
EngineEvent,
Model,
ModelEvent,
} from '@janhq/core'
import { useAtom } from 'jotai'
import { atomWithStorage } from 'jotai/utils'
Expand All @@ -30,10 +32,10 @@
extension: EngineManagementExtension | null,
method: (extension: EngineManagementExtension) => Promise<T>
): Promise<T> {
if (!extension) {
throw new Error('Extension not found')

Check warning on line 36 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

35-36 lines are not covered with tests
}
return method(extension)

Check warning on line 38 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

38 line is not covered with tests
}

/**
Expand All @@ -54,7 +56,7 @@
mutate,
} = useSWR(
extension ? 'engines' : null,
() => fetchExtensionData(extension, (ext) => ext.getEngines()),

Check warning on line 59 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

59 line is not covered with tests
{
revalidateOnFocus: false,
revalidateOnReconnect: true,
Expand All @@ -67,10 +69,10 @@
/**
* @returns A Promise that resolves to an object of remote models.
*/
export function useGetRemoteModels(name: string) {
const extension = useMemo(

Check warning on line 73 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

72-73 lines are not covered with tests
() =>
extensionManager.get<EngineManagementExtension>(

Check warning on line 75 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

75 line is not covered with tests
ExtensionTypeEnum.Engine
) ?? null,
[]
Expand All @@ -80,24 +82,24 @@
data: remoteModels,
error,
mutate,
} = useSWR(

Check warning on line 85 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

85 line is not covered with tests
extension ? 'remoteModels' : null,
() => fetchExtensionData(extension, (ext) => ext.getRemoteModels(name)),

Check warning on line 87 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

87 line is not covered with tests
{
revalidateOnFocus: false,
revalidateOnReconnect: true,
}
)

return { remoteModels, error, mutate }

Check warning on line 94 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

94 line is not covered with tests
}

/**
* @param name - Inference engine name.
* @returns A Promise that resolves to an array of installed engine.
*/
export function useGetInstalledEngines(name: InferenceEngine) {
const extension = useMemo(

Check warning on line 102 in web/hooks/useEngineManagement.ts

View workflow job for this annotation

GitHub Actions / coverage-check

101-102 lines are not covered with tests
() =>
extensionManager.get<EngineManagementExtension>(
ExtensionTypeEnum.Engine
Expand Down Expand Up @@ -385,3 +387,31 @@
throw error
}
}

/**
* Add a new remote engine model
* @param name
* @param engine
* @returns
*/
export const addRemoteEngineModel = async (name: string, engine: string) => {
const extension = getExtension()

if (!extension) {
throw new Error('Extension is not available')
}

try {
// Call the extension's method
const response = await extension.addRemoteModel({
id: name,
model: name,
engine: engine as InferenceEngine,
} as unknown as Model)
events.emit(ModelEvent.OnModelsUpdate, { fetch: true })
return response
} catch (error) {
console.error('Failed to install engine variant:', error)
throw error
}
}
155 changes: 155 additions & 0 deletions web/screens/Settings/Engines/ModalAddModel.tsx
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
import { memo, ReactNode, useState } from 'react'

import { useForm } from 'react-hook-form'

import Image from 'next/image'

import { zodResolver } from '@hookform/resolvers/zod'

import { InferenceEngine, Model } from '@janhq/core'

import { Button, Input, Modal } from '@janhq/joi'
import { useAtomValue } from 'jotai'
import { PlusIcon } from 'lucide-react'

import { z } from 'zod'

import {
addRemoteEngineModel,
useGetEngines,
useGetRemoteModels,
} from '@/hooks/useEngineManagement'

import { getLogoEngine, getTitleByEngine } from '@/utils/modelEngine'

import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'

const modelSchema = z.object({
modelName: z.string().min(1, 'Model name is required'),
})

const ModelAddModel = ({ engine }: { engine: string }) => {
const [open, setOpen] = useState(false)
const { mutate: mutateListEngines } = useGetRemoteModels(engine)
const { engines } = useGetEngines()
const models = useAtomValue(downloadedModelsAtom)
const {
register,
handleSubmit,
formState: { errors },
setError,
} = useForm({
resolver: zodResolver(modelSchema),
defaultValues: {
modelName: '',
},
})

const onSubmit = async (data: z.infer<typeof modelSchema>) => {
if (models.some((e: Model) => e.id === data.modelName)) {
setError('modelName', {
type: 'manual',
message: 'Model already exists',
})
return
}
await addRemoteEngineModel(data.modelName, engine)
mutateListEngines()

setOpen(false)
}

// Helper to render labels with asterisks for required fields
const renderLabel = (
prefix: ReactNode,
label: string,
isRequired: boolean,
desc?: string
) => (
<>
<span className="flex flex-row items-center gap-1">
{prefix}
{label}
</span>
<p className="mt-1 font-normal text-[hsla(var(--text-secondary))]">
{desc}
{isRequired && <span className="text-red-500">*</span>}
</p>
</>
)

return (
<Modal
title={
<div>
<p>Add Model</p>
</div>
}
fullPage
open={open}
onOpenChange={() => setOpen(!open)}
trigger={
<Button>
<PlusIcon className="mr-2" size={14} />
Add Model
</Button>
}
className="w-[500px]"
content={
<div>
<form className="mt-8 space-y-6" onSubmit={handleSubmit(onSubmit)}>
<div className="space-y-2">
<label htmlFor="modelName" className="font-semibold">
{renderLabel(
getLogoEngine(engine as InferenceEngine) ? (
<Image
src={getLogoEngine(engine as InferenceEngine) ?? ''}
width={40}
height={40}
alt="Engine logo"
className="h-5 w-5 flex-shrink-0"
/>
) : (
<></>
),
getTitleByEngine(engine as InferenceEngine) ?? engine,
false,
'Model ID'
)}
</label>
<Input placeholder="Enter model ID" {...register('modelName')} />
{errors.modelName && (
<p className="text-sm text-red-500">
{errors.modelName.message}
</p>
)}
<div className="pt-4">
<a
target="_blank"
href={engines?.[engine as InferenceEngine]?.[0]?.url}
className="text-[hsla(var(--app-link))]"
>
See model list from{' '}
{getTitleByEngine(engine as InferenceEngine)}
</a>
</div>
</div>

<div className="mt-8 flex justify-end gap-x-2">
<Button
theme="ghost"
variant="outline"
onClick={() => setOpen(false)}
>
Cancel
</Button>
<Button type="submit">Add</Button>
</div>
</form>
</div>
}
/>
)
}

export default memo(ModelAddModel)
5 changes: 4 additions & 1 deletion web/screens/Settings/Engines/RemoteEngineSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import { updateEngine, useGetEngines } from '@/hooks/useEngineManagement'

import { getTitleByEngine } from '@/utils/modelEngine'

import ModalAddModel from './ModalAddModel'

import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'

const RemoteEngineSettings = ({
Expand Down Expand Up @@ -194,10 +196,11 @@ const RemoteEngineSettings = ({
<div className="mb-3 mt-4 pb-4">
<div className="flex w-full flex-col items-start justify-between sm:flex-row">
<div className="w-full flex-shrink-0 ">
<div className="flex items-center justify-between gap-x-2">
<div className="mb-4 flex items-center justify-between gap-x-2">
<div>
<h6 className="mb-2 line-clamp-1 font-semibold">Model</h6>
</div>
<ModalAddModel engine={name} />
</div>

<div>
Expand Down
Loading