diff --git a/examples/cohere_client.py b/examples/cohere_client.py
new file mode 100644
index 000000000..a7b75b109
--- /dev/null
+++ b/examples/cohere_client.py
@@ -0,0 +1,40 @@
+import os
+
+import cohere
+from cohere.responses.chat import StreamTextGeneration
+
+co = cohere.Client(api_key='na', api_url=os.getenv('OPENLLM_ENDPOINT', 'http://localhost:3000') + '/cohere')
+
+generation = co.generate(prompt='Write me a tag line for an ice cream shop.')
+print(generation.generations[0].text)
+
+for it in co.generate(prompt='Write me a tag line for an ice cream shop.', stream=True):
+  print(it.text, flush=True, end='')
+
+response = co.chat(
+  message="What is Epicurus's philosophy of life?",
+  temperature=0.6,
+  chat_history=[
+    {'role': 'User', 'message': 'What is the meaning of life?'},
+    {
+      'role': 'Chatbot',
+      'message': "Many thinkers have proposed theories about the meaning of life. \n\nFor instance, Jean-Paul Sartre believed that existence precedes essence, meaning that the essence, or meaning, of one's life arises after birth. Søren Kierkegaard argued that life is full of absurdity and that one must make one's own values in an indifferent world. Arthur Schopenhauer stated that one's life reflects one's will, and that the will (or life) is without aim, irrational, and full of pain. \n\nEarly thinkers such as John Locke, Jean-Jacques Rousseau and Adam Smith believed that humankind should find meaning through labour, property and social contracts. \n\nAnother way of thinking about the meaning of life is to focus on the pursuit of happiness or pleasure. Aristippus of Cyrene, a student of Socrates, founded an early Socratic school that emphasised one aspect of Socrates's teachings: that happiness is the end goal of moral action and that pleasure is the supreme good. Epicurus taught that the pursuit of modest pleasures was the greatest good, as it leads to tranquility, freedom from fear and absence of bodily pain. \n\nUltimately, the meaning of life is a subjective concept and what provides life with meaning differs for each individual.",
+    },
+  ],
+)
+print(response)
+
+for it in co.chat(
+  message="What is Epicurus's philosophy of life?",
+  temperature=0.6,
+  chat_history=[
+    {'role': 'User', 'message': 'What is the meaning of life?'},
+    {
+      'role': 'Chatbot',
+      'message': "Many thinkers have proposed theories about the meaning of life. \n\nFor instance, Jean-Paul Sartre believed that existence precedes essence, meaning that the essence, or meaning, of one's life arises after birth. Søren Kierkegaard argued that life is full of absurdity and that one must make one's own values in an indifferent world. Arthur Schopenhauer stated that one's life reflects one's will, and that the will (or life) is without aim, irrational, and full of pain. \n\nEarly thinkers such as John Locke, Jean-Jacques Rousseau and Adam Smith believed that humankind should find meaning through labour, property and social contracts. \n\nAnother way of thinking about the meaning of life is to focus on the pursuit of happiness or pleasure. Aristippus of Cyrene, a student of Socrates, founded an early Socratic school that emphasised one aspect of Socrates's teachings: that happiness is the end goal of moral action and that pleasure is the supreme good. Epicurus taught that the pursuit of modest pleasures was the greatest good, as it leads to tranquility, freedom from fear and absence of bodily pain. \n\nUltimately, the meaning of life is a subjective concept and what provides life with meaning differs for each individual.",
+    },
+  ],
+  stream=True,
+):
+  if isinstance(it, StreamTextGeneration):
+    print(it.text, flush=True, end='')
diff --git a/openllm-core/src/openllm_core/_configuration.py b/openllm-core/src/openllm_core/_configuration.py
index f9c7e2992..f039e4a2c 100644
--- a/openllm-core/src/openllm_core/_configuration.py
+++ b/openllm-core/src/openllm_core/_configuration.py
@@ -1,6 +1,7 @@
 # mypy: disable-error-code="attr-defined,no-untyped-call,type-var,operator,arg-type,no-redef,misc"
 from __future__ import annotations
 import copy
+import importlib.util
 import logging
 import os
 import sys
@@ -31,7 +32,7 @@
   Self,
   overload,
 )
-from .exceptions import ForbiddenAttributeError
+from .exceptions import ForbiddenAttributeError, MissingDependencyError
 from .utils import LazyLoader, ReprMixin, codegen, converter, dantic, field_env_key, first_not_none, lenient_issubclass
 from .utils.peft import PEFT_TASK_TYPE_TARGET_MAPPING, FineTuneConfig
@@ -39,7 +40,9 @@
   import click
   import transformers
   import vllm
+
   from attrs import AttrsInstance
+  from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest
   from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest
 else:
   vllm = LazyLoader('vllm', globals(), 'vllm')
@@ -1460,7 +1463,28 @@ def to_generation_config(self, return_as_dict: bool = False) -> transformers.Gen
   def to_sampling_config(self) -> vllm.SamplingParams:
     return self.sampling_config.build()

-  def with_openai_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]:
+  @overload
+  def with_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]: ...
+
+  @overload
+  def with_request(self, request: CohereChatRequest | CohereGenerateRequest) -> dict[str, t.Any]: ...
+
+  def with_request(self, request: AttrsInstance) -> dict[str, t.Any]:
+    if importlib.util.find_spec('openllm') is None:
+      raise MissingDependencyError(
+        "'openllm' is required to use 'with_request'. Make sure to install with 'pip install openllm'."
+      )
+    from openllm.protocol.cohere import CohereChatRequest, CohereGenerateRequest
+    from openllm.protocol.openai import ChatCompletionRequest, CompletionRequest
+
+    if isinstance(request, (ChatCompletionRequest, CompletionRequest)):
+      return self._with_openai_request(request)
+    elif isinstance(request, (CohereChatRequest, CohereGenerateRequest)):
+      return self._with_cohere_request(request)
+    else:
+      raise TypeError(f'Unknown request type {type(request)}')
+
+  def _with_openai_request(self, request: ChatCompletionRequest | CompletionRequest) -> dict[str, t.Any]:
     d = dict(
       temperature=first_not_none(request.temperature, self['temperature']),
       top_p=first_not_none(request.top_p, self['top_p']),
@@ -1476,6 +1500,21 @@
       d['logprobs'] = first_not_none(request.logprobs, default=self['logprobs'])
     return d

+  def _with_cohere_request(self, request: CohereGenerateRequest | CohereChatRequest) -> dict[str, t.Any]:
+    d = dict(
+      max_new_tokens=first_not_none(request.max_tokens, default=self['max_new_tokens']),
+      temperature=first_not_none(request.temperature, default=self['temperature']),
+      top_k=first_not_none(request.k, default=self['top_k']),
+      top_p=first_not_none(request.p, default=self['top_p']),
+    )
+    if hasattr(request, 'num_generations'):
+      d['n'] = first_not_none(request.num_generations, default=self['n'])
+    if hasattr(request, 'frequency_penalty'):
+      d['frequency_penalty'] = first_not_none(request.frequency_penalty, default=self['frequency_penalty'])
+    if hasattr(request, 'presence_penalty'):
+      d['presence_penalty'] = first_not_none(request.presence_penalty, default=self['presence_penalty'])
+    return d
+
   @classmethod
   def to_click_options(cls, f: AnyCallable) -> click.Command:
     """Convert current configuration to click options.
diff --git a/openllm-python/src/openllm/entrypoints/__init__.py b/openllm-python/src/openllm/entrypoints/__init__.py
index 6b75ea269..5e31d9d3b 100644
--- a/openllm-python/src/openllm/entrypoints/__init__.py
+++ b/openllm-python/src/openllm/entrypoints/__init__.py
@@ -8,21 +8,28 @@
 """
 from __future__ import annotations
+import importlib
 import typing as t

 from openllm_core.utils import LazyModule

-from . import hf as hf, openai as openai
-
 if t.TYPE_CHECKING:
   import bentoml
   import openllm

-_import_structure: dict[str, list[str]] = {'openai': [], 'hf': []}
+
+class IntegrationModule(t.Protocol):
+  def mount_to_svc(self, svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service: ...
+
+
+_import_structure: dict[str, list[str]] = {'openai': [], 'hf': [], 'cohere': []}


 def mount_entrypoints(svc: bentoml.Service, llm: openllm.LLM[t.Any, t.Any]) -> bentoml.Service:
-  return openai.mount_to_svc(hf.mount_to_svc(svc, llm), llm)
+  for module_name in _import_structure:
+    module = t.cast(IntegrationModule, importlib.import_module(f'.{module_name}', __name__))
+    svc = module.mount_to_svc(svc, llm)
+  return svc


 __lazy = LazyModule(
diff --git a/openllm-python/src/openllm/entrypoints/_openapi.py b/openllm-python/src/openllm/entrypoints/_openapi.py
index 198956fb3..427b55ce6 100644
--- a/openllm-python/src/openllm/entrypoints/_openapi.py
+++ b/openllm-python/src/openllm/entrypoints/_openapi.py
@@ -396,7 +396,7 @@
 summary: Describes a model offering that can be used with the API.
 tags:
   - HF
-x-bentoml-name: adapters_map
+x-bentoml-name: hf_adapters
 responses:
   200:
     description: Return list of LoRA adapters.
@@ -416,6 +416,65 @@
       $ref: '#/components/schemas/HFErrorResponse'
     description: Not Found
 """
+COHERE_GENERATE_SCHEMA = """\
+---
+consumes:
+  - application/json
+description: >-
+  Given a prompt, the model will return one or more predicted completions, and
+  can also return the probabilities of alternative tokens at each position.
+operationId: cohere__generate
+produces:
+  - application/json
+tags:
+  - Cohere
+x-bentoml-name: cohere_generate
+summary: Creates a completion for the provided prompt and parameters.
+requestBody:
+  required: true
+  content:
+    application/json:
+      schema:
+        $ref: '#/components/schemas/CohereGenerateRequest'
+      examples:
+        one-shot:
+          summary: One-shot input example
+          value:
+            prompt: This is a test
+            max_tokens: 256
+            temperature: 0.7
+            p: 0.43
+            k: 12
+            num_generations: 2
+            stream: false
+        streaming:
+          summary: Streaming input example
+          value:
+            prompt: This is a test
+            max_tokens: 256
+            temperature: 0.7
+            p: 0.43
+            k: 12
+            num_generations: 2
+            stream: true
+            stop_sequences:
+              - "\\n"
+              - "<|endoftext|>"
+"""
+COHERE_CHAT_SCHEMA = """\
+---
+consumes:
+- application/json
+description: >-
+  Given a list of messages comprising a conversation, the model will return a response.
+operationId: cohere__chat
+produces:
+  - application/json
+tags:
+  - Cohere
+x-bentoml-name: cohere_chat
+summary: Creates a model response for the given chat conversation.
+"""

 _SCHEMAS = {k[:-7].lower(): v for k, v in locals().items() if k.endswith('_SCHEMA')}
@@ -485,12 +544,15 @@ def get_schema(self, routes: list[BaseRoute], mount_path: str | None = None) ->


 def get_generator(
-  title: str, components: list[type[AttrsInstance]] | None = None, tags: list[dict[str, t.Any]] | None = None
+  title: str,
+  components: list[type[AttrsInstance]] | None = None,
+  tags: list[dict[str, t.Any]] | None = None,
+  inject: bool = True,
 ) -> OpenLLMSchemaGenerator:
   base_schema: dict[str, t.Any] = dict(info={'title': title, 'version': API_VERSION}, version=OPENAPI_VERSION)
-  if components:
+  if components and inject:
     base_schema['components'] = {'schemas': {c.__name__: component_schema_generator(c) for c in components}}
-  if tags is not None and tags:
+  if tags is not None and tags and inject:
     base_schema['tags'] = tags
   return OpenLLMSchemaGenerator(base_schema)
diff --git a/openllm-python/src/openllm/entrypoints/cohere.py b/openllm-python/src/openllm/entrypoints/cohere.py
new file mode 100644
index 000000000..645dbc360
--- /dev/null
+++ b/openllm-python/src/openllm/entrypoints/cohere.py
@@ -0,0 +1,317 @@
+from __future__ import annotations
+import functools
+import json
+import logging
+import traceback
+import typing as t
+from http import HTTPStatus
+
+import orjson
+from starlette.applications import Starlette
+from starlette.responses import JSONResponse, StreamingResponse
+from starlette.routing import Route
+
+from openllm_core.utils import converter, gen_random_uuid
+
+from ._openapi import append_schemas, get_generator
+
+# from ._openapi import add_schema_definitions
+from ..protocol.cohere import (
+  Chat,
+  ChatStreamEnd,
+  ChatStreamStart,
+  ChatStreamTextGeneration,
+  CohereChatRequest,
+  CohereErrorResponse,
+  CohereGenerateRequest,
+  Generation,
+  Generations,
+  StreamingGenerations,
+  StreamingText,
+)
+
+schemas = get_generator(
+  'cohere',
+  components=[
+    CohereChatRequest,
+    CohereErrorResponse,
+    CohereGenerateRequest,
+    Generation,
+    Generations,
+    StreamingGenerations,
+    StreamingText,
+    Chat,
+    ChatStreamStart,
+    ChatStreamEnd,
+    ChatStreamTextGeneration,
+  ],
+  tags=[
+    {
+      'name': 'Cohere',
+      'description': 'Cohere-compatible API. Currently supports /generate and /chat.',
+      'externalDocs': 'https://docs.cohere.com/docs/the-cohere-platform',
+    }
+  ],
+  inject=False,
+)
+logger = logging.getLogger(__name__)
+
+if t.TYPE_CHECKING:
+  from attr import AttrsInstance
+  from starlette.requests import Request
+  from starlette.responses import Response
+
+  import bentoml
+  import openllm
+  from openllm_core._schemas import GenerationOutput
+  from openllm_core._typing_compat import M, T
+
+
+def jsonify_attr(obj: AttrsInstance) -> str:
+  return json.dumps(converter.unstructure(obj))
+
+
+def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
+  return JSONResponse(converter.unstructure(CohereErrorResponse(text=message)), status_code=status_code.value)
+
+
+async def check_model(request: CohereGenerateRequest | CohereChatRequest, model: str) -> JSONResponse | None:
+  if request.model is None or request.model == model:
+    return None
+  return error_response(
+    HTTPStatus.NOT_FOUND,
+    f"Model '{request.model}' does not exist. Try 'GET /v1/models' to see currently running models.",
+  )
+
+
+def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Service:
+  app = Starlette(
+    debug=True,
+    routes=[
+      Route(
+        '/v1/generate',
+        endpoint=functools.partial(cohere_generate, llm=llm),
+        name='cohere_generate',
+        methods=['POST'],
+        include_in_schema=False,
+      ),
+      Route(
+        '/v1/chat',
+        endpoint=functools.partial(cohere_chat, llm=llm),
+        name='cohere_chat',
+        methods=['POST'],
+        include_in_schema=False,
+      ),
+      Route('/schema', endpoint=openapi_schema, include_in_schema=False),
+    ],
+  )
+  mount_path = '/cohere'
+
+  svc.mount_asgi_app(app, path=mount_path)
+  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')
+
+
+# @add_schema_definitions
+async def cohere_generate(req: Request, llm: openllm.LLM[M, T]) -> Response:
+  json_str = await req.body()
+  try:
+    request = converter.structure(orjson.loads(json_str), CohereGenerateRequest)
+  except orjson.JSONDecodeError as err:
+    logger.debug('Sent body: %s', json_str)
+    logger.error('Invalid JSON input received: %s', err)
+    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
+  logger.debug('Received generate request: %s', request)
+
+  err_check = await check_model(request, llm.llm_type)
+  if err_check is not None:
+    return err_check
+
+  request_id = gen_random_uuid('cohere-generate')
+  config = llm.config.with_request(request)
+
+  if request.prompt_vars is not None:
+    prompt = request.prompt.format(**request.prompt_vars)
+  else:
+    prompt = request.prompt
+
+  # TODO: support end_sequences, stop_sequences, logit_bias, return_likelihoods, truncate
+
+  try:
+    result_generator = llm.generate_iterator(prompt, request_id=request_id, stop=request.stop_sequences, **config)
+  except Exception as err:
+    traceback.print_exc()
+    logger.error('Error generating completion: %s', err)
+    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+  def create_stream_response_json(index: int, text: str, is_finished: bool) -> str:
+    return f'{jsonify_attr(StreamingText(index=index, text=text, is_finished=is_finished))}\n'
+
+  async def generate_stream_generator() -> t.AsyncGenerator[str, None]:
+    async for res in result_generator:
+      for output in res.outputs:
+        yield create_stream_response_json(index=output.index, text=output.text, is_finished=output.finish_reason is not None)
+
+  try:
+    # streaming case
+    if request.stream:
+      return StreamingResponse(generate_stream_generator(), media_type='text/event-stream')
+    # Non-streaming case
+    final_result: GenerationOutput | None = None
+    texts: list[list[str]] = [[] for _ in range(config['n'])]
+    token_ids: list[list[int]] = [[] for _ in range(config['n'])]
+    async for res in result_generator:
+      if await req.is_disconnected():
+        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
+      for output in res.outputs:
+        texts[output.index].append(output.text)
+        token_ids[output.index].extend(output.token_ids)
+      final_result = res
+    if final_result is None:
+      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
+    final_result = final_result.with_options(
+      outputs=[
+        output.with_options(text=''.join(texts[output.index]), token_ids=token_ids[output.index])
+        for output in final_result.outputs
+      ]
+    )
+    response = Generations(
+      id=request_id,
+      generations=[
+        Generation(id=request_id, text=output.text, prompt=prompt, finish_reason=output.finish_reason)
+        for output in final_result.outputs
+      ],
+    )
+    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
+  except Exception as err:
+    traceback.print_exc()
+    logger.error('Error generating completion: %s', err)
+    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+
+def _transpile_cohere_chat_messages(request: CohereChatRequest) -> list[dict[str, str]]:
+  def convert_role(role):
+    return {'User': 'user', 'Chatbot': 'assistant'}[role]
+
+  chat_history = request.chat_history
+  if chat_history:
+    messages = [{'role': convert_role(msg['role']), 'content': msg['message']} for msg in chat_history]
+  else:
+    messages = []
+  messages.append({'role': 'user', 'content': request.message})
+  return messages
+
+
+# @add_schema_definitions
+async def cohere_chat(req: Request, llm: openllm.LLM[M, T]) -> Response:
+  json_str = await req.body()
+  try:
+    request = converter.structure(orjson.loads(json_str), CohereChatRequest)
+  except orjson.JSONDecodeError as err:
+    logger.debug('Sent body: %s', json_str)
+    logger.error('Invalid JSON input received: %s', err)
+    return error_response(HTTPStatus.BAD_REQUEST, 'Invalid JSON input received (Check server log).')
+  logger.debug('Received chat completion request: %s', request)
+
+  err_check = await check_model(request, llm.llm_type)
+  if err_check is not None:
+    return err_check
+
+  request_id = gen_random_uuid('cohere-chat')
+  prompt: str = llm.tokenizer.apply_chat_template(
+    _transpile_cohere_chat_messages(request), tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
+  )
+  logger.debug('Prompt: %r', prompt)
+  config = llm.config.with_request(request)
+
+  try:
+    result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
+  except Exception as err:
+    traceback.print_exc()
+    logger.error('Error generating completion: %s', err)
+    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+  def create_stream_generation_json(index: int, text: str, is_finished: bool) -> str:
+    return f'{jsonify_attr(ChatStreamTextGeneration(index=index, text=text, is_finished=is_finished))}\n'
+
+  async def completion_stream_generator() -> t.AsyncGenerator[str, None]:
+    texts: list[str] = []
+    token_ids: list[int] = []
+    yield f'{jsonify_attr(ChatStreamStart(is_finished=False, index=0, generation_id=request_id))}\n'
+
+    it = None
+    async for res in result_generator:
+      yield create_stream_generation_json(index=res.outputs[0].index, text=res.outputs[0].text, is_finished=False)
+      texts.append(res.outputs[0].text)
+      token_ids.extend(res.outputs[0].token_ids)
+      it = res
+
+    if it is None:
+      raise ValueError('No response from model.')
+    num_prompt_tokens = len(t.cast(t.List[int], it.prompt_token_ids))
+    num_response_tokens = len(token_ids)
+    total_tokens = num_prompt_tokens + num_response_tokens
+
+    json_str = jsonify_attr(
+      ChatStreamEnd(
+        is_finished=True,
+        finish_reason='COMPLETE',
+        index=0,
+        response=Chat(
+          response_id=request_id,
+          message=request.message,
+          text=''.join(texts),
+          prompt=prompt,
+          chat_history=request.chat_history,
+          token_count={
+            'prompt_tokens': num_prompt_tokens,
+            'response_tokens': num_response_tokens,
+            'total_tokens': total_tokens,
+          },
+        ),
+      )
+    )
+    yield f'{json_str}\n'
+
+  try:
+    if request.stream:
+      return StreamingResponse(completion_stream_generator(), media_type='text/event-stream')
+    # Non-streaming case
+    final_result: GenerationOutput | None = None
+    texts: list[str] = []
+    token_ids: list[int] = []
+    async for res in result_generator:
+      if await req.is_disconnected():
+        return error_response(HTTPStatus.BAD_REQUEST, 'Client disconnected.')
+      texts.append(res.outputs[0].text)
+      token_ids.extend(res.outputs[0].token_ids)
+      final_result = res
+    if final_result is None:
+      return error_response(HTTPStatus.BAD_REQUEST, 'No response from model.')
+    final_result = final_result.with_options(
+      outputs=[final_result.outputs[0].with_options(text=''.join(texts), token_ids=token_ids)]
+    )
+    num_prompt_tokens = len(t.cast(t.List[int], final_result.prompt_token_ids))
+    num_response_tokens = len(token_ids)
+    total_tokens = num_prompt_tokens + num_response_tokens
+
+    response = Chat(
+      response_id=request_id,
+      message=request.message,
+      text=''.join(texts),
+      prompt=prompt,
+      chat_history=request.chat_history,
+      token_count={
+        'prompt_tokens': num_prompt_tokens,
+        'response_tokens': num_response_tokens,
+        'total_tokens': total_tokens,
+      },
+    )
+    return JSONResponse(converter.unstructure(response), status_code=HTTPStatus.OK.value)
+  except Exception as err:
+    traceback.print_exc()
+    logger.error('Error generating completion: %s', err)
+    return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+
+def openapi_schema(req: Request) -> Response:
+  return schemas.OpenAPIResponse(req)
diff --git a/openllm-python/src/openllm/entrypoints/hf.py b/openllm-python/src/openllm/entrypoints/hf.py
index 50629dea6..faa11acb8 100644
--- a/openllm-python/src/openllm/entrypoints/hf.py
+++ b/openllm-python/src/openllm/entrypoints/hf.py
@@ -48,9 +48,8 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
     ],
   )
   mount_path = '/hf'
-  generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
   svc.mount_asgi_app(app, path=mount_path)
-  return append_schemas(svc, generated_schema, tags_order='append')
+  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path), tags_order='append')


 def error_response(status_code: HTTPStatus, message: str) -> JSONResponse:
diff --git a/openllm-python/src/openllm/entrypoints/openai.py b/openllm-python/src/openllm/entrypoints/openai.py
index 5f3d20436..0f76f1c55 100644
--- a/openllm-python/src/openllm/entrypoints/openai.py
+++ b/openllm-python/src/openllm/entrypoints/openai.py
@@ -119,12 +119,12 @@ def mount_to_svc(svc: bentoml.Service, llm: openllm.LLM[M, T]) -> bentoml.Servic
       Route('/models', functools.partial(list_models, llm=llm), methods=['GET']),
      Route('/completions', functools.partial(completions, llm=llm), methods=['POST']),
       Route('/chat/completions', functools.partial(chat_completions, llm=llm), methods=['POST']),
+      Route('/schema', endpoint=openapi_schema, include_in_schema=False),
     ],
   )
   mount_path = '/v1'
-  generated_schema = schemas.get_schema(routes=app.routes, mount_path=mount_path)
   svc.mount_asgi_app(app, path=mount_path)
-  return append_schemas(svc, generated_schema)
+  return append_schemas(svc, schemas.get_schema(routes=app.routes, mount_path=mount_path))


 # GET /v1/models
@@ -157,7 +157,7 @@ async def chat_completions(req: Request, llm: openllm.LLM[M, T]) -> Response:
     request.messages, tokenize=False, add_generation_prompt=llm.config['add_generation_prompt']
   )
   logger.debug('Prompt: %r', prompt)
-  config = llm.config.with_openai_request(request)
+  config = llm.config.with_request(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -287,7 +287,7 @@ async def completions(req: Request, llm: openllm.LLM[M, T]) -> Response:

   model_name, request_id = request.model, gen_random_uuid('cmpl')
   created_time = int(time.monotonic())
-  config = llm.config.with_openai_request(request)
+  config = llm.config.with_request(request)

   try:
     result_generator = llm.generate_iterator(prompt, request_id=request_id, **config)
@@ -398,3 +398,7 @@ async def fake_stream_generator() -> t.AsyncGenerator[str, None]:
     traceback.print_exc()
     logger.error('Error generating completion: %s', err)
     return error_response(HTTPStatus.INTERNAL_SERVER_ERROR, f'Exception: {err!s} (check server log)')
+
+
+def openapi_schema(req: Request) -> Response:
+  return schemas.OpenAPIResponse(req)
diff --git a/openllm-python/src/openllm/protocol/cohere.py b/openllm-python/src/openllm/protocol/cohere.py
new file mode 100644
index 000000000..a080302fd
--- /dev/null
+++ b/openllm-python/src/openllm/protocol/cohere.py
@@ -0,0 +1,154 @@
+from __future__ import annotations
+import typing as t
+from enum import Enum
+
+import attr
+
+from openllm_core.utils import converter
+
+
+@attr.define
+class CohereErrorResponse:
+  text: str
+
+
+converter.register_unstructure_hook(CohereErrorResponse, lambda obj: obj.text)
+
+
+@attr.define
+class CohereGenerateRequest:
+  prompt: str
+  prompt_vars: t.Optional[t.Dict[str, t.Any]] = None
+  model: t.Optional[str] = None
+  preset: t.Optional[str] = None
+  num_generations: t.Optional[int] = None
+  max_tokens: t.Optional[int] = None
+  temperature: t.Optional[float] = None
+  k: t.Optional[int] = None
+  p: t.Optional[float] = None
+  frequency_penalty: t.Optional[float] = None
+  presence_penalty: t.Optional[float] = None
+  end_sequences: t.Optional[t.List[str]] = None
+  stop_sequences: t.Optional[t.List[str]] = None
+  return_likelihoods: t.Optional[t.Literal['GENERATION', 'ALL', 'NONE']] = None
+  truncate: t.Optional[str] = None
+  logit_bias: t.Optional[t.Dict[int, float]] = None
+  stream: bool = False
+
+
+@attr.define
+class TokenLikelihood:  # roughly equivalent to OpenAI-style token_logprobs
+  token: str
+  likelihood: float
+
+
+@attr.define
+class Generation:
+  id: str
+  text: str
+  prompt: str
+  likelihood: t.Optional[float] = None
+  token_likelihoods: t.List[TokenLikelihood] = attr.field(factory=list)
+  finish_reason: t.Optional[str] = None
+
+
+@attr.define
+class Generations:
+  id: str
+  generations: t.List[Generation]
+  meta: t.Optional[t.Dict[str, t.Any]] = None
+
+
+@attr.define
+class StreamingText:
+  index: int
+  text: str
+  is_finished: bool
+
+
+@attr.define
+class StreamingGenerations:
+  id: str
+  generations: Generations
+  texts: t.List[str]
+  meta: t.Optional[t.Dict[str, t.Any]] = None
+
+
+@attr.define
+class CohereChatRequest:
+  message: str
+  conversation_id: t.Optional[str] = ''
+  model: t.Optional[str] = None
+  return_chat_history: t.Optional[bool] = False
+  return_prompt: t.Optional[bool] = False
+  return_preamble: t.Optional[bool] = False
+  chat_history: t.Optional[t.List[t.Dict[str, str]]] = None
+  preamble_override: t.Optional[str] = None
+  user_name: t.Optional[str] = None
+  temperature: t.Optional[float] = 0.8
+  max_tokens: t.Optional[int] = None
+  stream: t.Optional[bool] = False
+  p: t.Optional[float] = None
+  k: t.Optional[int] = None
+  logit_bias: t.Optional[t.Dict[int, float]] = None
+  search_queries_only: t.Optional[bool] = None
+  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+  citation_quality: t.Optional[str] = None
+  prompt_truncation: t.Optional[str] = None
+  connectors: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+
+
+class StreamEvent(str, Enum):
+  STREAM_START = 'stream-start'
+  TEXT_GENERATION = 'text-generation'
+  STREAM_END = 'stream-end'
+  # TODO: The following are yet to be implemented
+  SEARCH_QUERIES_GENERATION = 'search-queries-generation'
+  SEARCH_RESULTS = 'search-results'
+  CITATION_GENERATION = 'citation-generation'
+
+
+@attr.define
+class Chat:
+  response_id: str
+  message: str
+  text: str
+  generation_id: t.Optional[str] = None
+  conversation_id: t.Optional[str] = None
+  meta: t.Optional[t.Dict[str, t.Any]] = None
+  prompt: t.Optional[str] = None
+  chat_history: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+  preamble: t.Optional[str] = None
+  token_count: t.Optional[t.Dict[str, int]] = None
+  is_search_required: t.Optional[bool] = None
+  citations: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+  documents: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+  search_results: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+  search_queries: t.Optional[t.List[t.Dict[str, t.Any]]] = None
+
+
+@attr.define
+class ChatStreamResponse:
+  is_finished: bool
+  event_type: StreamEvent
+  index: int
+
+
+@attr.define
+class ChatStreamStart(ChatStreamResponse):
+  generation_id: str
+  conversation_id: t.Optional[str] = None
+  event_type: StreamEvent = StreamEvent.STREAM_START
+
+
+@attr.define
+class ChatStreamTextGeneration(ChatStreamResponse):
+  text: str
+  event_type: StreamEvent = StreamEvent.TEXT_GENERATION
+
+
+@attr.define
+class ChatStreamEnd(ChatStreamResponse):
+  finish_reason: str
+  response: Chat
+  event_type: StreamEvent = StreamEvent.STREAM_END