diff --git a/core/src/types/model/modelEntity.ts b/core/src/types/model/modelEntity.ts index 482dfa1ac9..d7db7b9d6a 100644 --- a/core/src/types/model/modelEntity.ts +++ b/core/src/types/model/modelEntity.ts @@ -71,7 +71,7 @@ export type Model = { /** * The model identifier, modern version of id. */ - mode?: string + model?: string /** * Human-readable name that is used for UI. diff --git a/extensions/engine-management-extension/rolldown.config.mjs b/extensions/engine-management-extension/rolldown.config.mjs index 1290338db5..7e8cdcd2b3 100644 --- a/extensions/engine-management-extension/rolldown.config.mjs +++ b/extensions/engine-management-extension/rolldown.config.mjs @@ -16,6 +16,13 @@ export default defineConfig([ CORTEX_ENGINE_VERSION: JSON.stringify('v0.1.49'), DEFAULT_REMOTE_ENGINES: JSON.stringify(engines), DEFAULT_REMOTE_MODELS: JSON.stringify(models), + DEFAULT_REQUEST_PAYLOAD_TRANSFORM: JSON.stringify('{{ tojson(value) }}'), + DEFAULT_RESPONSE_BODY_TRANSFORM: JSON.stringify( + '{ {% set first = true %} {% for key, value in input_request %} {% if key == "choices" or key == "created" or key == "model" or key == "service_tier" or key == "stream" or key == "object" or key == "usage" %} {% if not first %},{% endif %} "{{ key }}": {{ tojson(value) }} {% set first = false %} {% endif %} {% endfor %} }' + ), + DEFAULT_REQUEST_HEADERS_TRANSFORM: JSON.stringify( + 'Authorization: Bearer {{api_key}}' + ), }, }, { diff --git a/extensions/engine-management-extension/src/@types/global.d.ts b/extensions/engine-management-extension/src/@types/global.d.ts index 2d520d5f9e..1785095780 100644 --- a/extensions/engine-management-extension/src/@types/global.d.ts +++ b/extensions/engine-management-extension/src/@types/global.d.ts @@ -2,6 +2,9 @@ declare const API_URL: string declare const CORTEX_ENGINE_VERSION: string declare const SOCKET_URL: string declare const NODE: string +declare const DEFAULT_REQUEST_PAYLOAD_TRANSFORM: string +declare const DEFAULT_RESPONSE_BODY_TRANSFORM: string +declare const DEFAULT_REQUEST_HEADERS_TRANSFORM: string declare const DEFAULT_REMOTE_ENGINES: ({ id: string diff --git a/extensions/engine-management-extension/src/index.ts b/extensions/engine-management-extension/src/index.ts index 0d30bf4eac..4e3e30f751 100644 --- a/extensions/engine-management-extension/src/index.ts +++ b/extensions/engine-management-extension/src/index.ts @@ -20,6 +20,9 @@ import PQueue from 'p-queue' import { EngineError } from './error' import { getJanDataFolderPath } from '@janhq/core' +interface ModelList { + data: Model[] +} /** * JSONEngineManagementExtension is a EngineManagementExtension implementation that provides * functionality for managing engines. @@ -63,13 +66,12 @@ export default class JSONEngineManagementExtension extends EngineManagementExten * @returns A Promise that resolves to an object of list engines. */ async getRemoteModels(name: string): Promise { - return this.queue.add(() => - ky - .get(`${API_URL}/v1/models/remote/${name}`) - .json() - .then((e) => e) - .catch(() => []) - ) as Promise + return ky + .get(`${API_URL}/v1/models/remote/${name}`) + .json() + .catch(() => ({ + data: [], + })) as Promise } /** @@ -138,9 +140,36 @@ export default class JSONEngineManagementExtension extends EngineManagementExten * Add a new remote engine * @returns A Promise that resolves to intall of engine. */ - async addRemoteEngine(engineConfig: EngineConfig) { + async addRemoteEngine( + engineConfig: EngineConfig, + persistModels: boolean = true + ) { + // Populate default settings + if ( + engineConfig.metadata?.transform_req?.chat_completions && + !engineConfig.metadata.transform_req.chat_completions.template + ) + engineConfig.metadata.transform_req.chat_completions.template = + DEFAULT_REQUEST_PAYLOAD_TRANSFORM + + if ( + engineConfig.metadata?.transform_resp?.chat_completions && + !engineConfig.metadata.transform_resp.chat_completions?.template + ) + engineConfig.metadata.transform_resp.chat_completions.template = + DEFAULT_RESPONSE_BODY_TRANSFORM + + if (engineConfig.metadata && !engineConfig.metadata?.header_template) + engineConfig.metadata.header_template = DEFAULT_REQUEST_HEADERS_TRANSFORM + return this.queue.add(() => - ky.post(`${API_URL}/v1/engines`, { json: engineConfig }).then((e) => e) + ky.post(`${API_URL}/v1/engines`, { json: engineConfig }).then((e) => { + if (persistModels && engineConfig.metadata?.get_models_url) { + // Pull /models from remote models endpoint + return this.populateRemoteModels(engineConfig).then(() => e) + } + return e + }) ) as Promise<{ messages: string }> } @@ -161,9 +190,11 @@ export default class JSONEngineManagementExtension extends EngineManagementExten * @param model - Remote model object. */ async addRemoteModel(model: Model) { - return this.queue.add(() => - ky.post(`${API_URL}/v1/models/add`, { json: model }).then((e) => e) - ) + return this.queue + .add(() => + ky.post(`${API_URL}/v1/models/add`, { json: model }).then((e) => e) + ) + .then(() => {}) } /** @@ -293,7 +324,7 @@ export default class JSONEngineManagementExtension extends EngineManagementExten data.api_key = api_key /// END - Migrate legacy api key settings - await this.addRemoteEngine(data).catch(console.error) + await this.addRemoteEngine(data, false).catch(console.error) }) ) events.emit(EngineEvent.OnEngineUpdate, {}) @@ -303,4 +334,27 @@ export default class JSONEngineManagementExtension extends EngineManagementExten events.emit(ModelEvent.OnModelsUpdate, { fetch: true }) } } + + /** + * Pulls models list from the remote provider and persist + * @param engineConfig + * @returns + */ + private populateRemoteModels = async (engineConfig: EngineConfig) => { + return this.getRemoteModels(engineConfig.engine) + .then((models: ModelList) => { + Promise.all( + models.data?.map((model) => + this.addRemoteModel({ + ...model, + engine: engineConfig.engine as InferenceEngine, + model: model.model ?? model.id, + }).catch(console.info) + ) + ).then(() => { + events.emit(ModelEvent.OnModelsUpdate, { fetch: true }) + }) + }) + .catch(console.info) + } }