From 648a38d41c9bb16ac8ce4a7a0e1c4f0aa65fd58e Mon Sep 17 00:00:00 2001 From: Tilman Griesel <griesel@me.com> Date: Wed, 29 Jan 2025 09:16:53 +0100 Subject: [PATCH] WIP: Mirror Ollama chat API --- services/api/src/main.py | 32 ++++++++++---------------------- 1 file changed, 10 insertions(+), 22 deletions(-) diff --git a/services/api/src/main.py b/services/api/src/main.py index cb4590c..3bad9b0 100755 --- a/services/api/src/main.py +++ b/services/api/src/main.py @@ -9,9 +9,9 @@ from pathlib import Path import elasticsearch +from core.ollama_proxy import OllamaProxy from core.pipeline_config import ModelProvider, QueryPipelineConfig from core.rag_pipeline import RAGQueryPipeline -from core.ollama_proxy import OllamaProxy from dotenv import load_dotenv from flask import Flask, Response, abort, jsonify, request, stream_with_context from flask_limiter import Limiter @@ -110,6 +110,7 @@ def load_systemprompt(base_path: str) -> str: system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd())) + def get_env_param(param_name, converter=None, default=None): value = os.getenv(param_name) if value is None: @@ -262,9 +263,11 @@ def tags(): def pull(): return proxy.pull() + ollama_proxy = OllamaProxy(os.getenv("OLLAMA_URL", "http://localhost:11434")) register_ollama_routes(app, ollama_proxy) + @app.before_request def before_request(): logger.info(f"Request {request.method} {request.path} from {request.remote_addr}") @@ -441,9 +444,7 @@ def generate(): def handle_standard_response( - config: QueryPipelineConfig, - query: str, - conversation: list + config: QueryPipelineConfig, query: str, conversation: list ) -> Response: rag = RAGQueryPipeline(config=config) @@ -451,9 +452,7 @@ def handle_standard_response( result = None try: result = rag.run_query( - query=query, - conversation=conversation, - print_response=False + query=query, conversation=conversation, print_response=False ) if result: @@ -474,30 +473,18 @@ def handle_standard_response( "prompt_eval_count": 26, "prompt_eval_duration": 342546000, "eval_count": 282, - "eval_duration": 4535599000 + "eval_duration": 4535599000, } except Exception as e: success = False logger.error(f"Error in RAG pipeline: {e}", exc_info=True) - if success and result: - latest_message = { - "role": "assistant", - "content": result["llm"]["replies"][0], - "timestamp": datetime.now().isoformat(), - } - conversation.append(latest_message) - return jsonify( - { - "success": success, - "timestamp": datetime.now().isoformat(), - "result": result, - "messages": conversation, - } + {"success": success, "timestamp": datetime.now().isoformat(), "result": result} ) + @app.route("/", methods=["GET"]) @app.route("/health", methods=["GET"]) def health_check(): @@ -511,6 +498,7 @@ def health_check(): } ) + @app.errorhandler(404) def not_found_error(error): return "", 404