From 648a38d41c9bb16ac8ce4a7a0e1c4f0aa65fd58e Mon Sep 17 00:00:00 2001
From: Tilman Griesel <griesel@me.com>
Date: Wed, 29 Jan 2025 09:16:53 +0100
Subject: [PATCH] WIP: Mirror Ollama chat API

---
 services/api/src/main.py | 32 ++++++++++----------------------
 1 file changed, 10 insertions(+), 22 deletions(-)

diff --git a/services/api/src/main.py b/services/api/src/main.py
index cb4590c..3bad9b0 100755
--- a/services/api/src/main.py
+++ b/services/api/src/main.py
@@ -9,9 +9,9 @@
 from pathlib import Path
 
 import elasticsearch
+from core.ollama_proxy import OllamaProxy
 from core.pipeline_config import ModelProvider, QueryPipelineConfig
 from core.rag_pipeline import RAGQueryPipeline
-from core.ollama_proxy import OllamaProxy
 from dotenv import load_dotenv
 from flask import Flask, Response, abort, jsonify, request, stream_with_context
 from flask_limiter import Limiter
@@ -110,6 +110,7 @@ def load_systemprompt(base_path: str) -> str:
 
 system_prompt_value = load_systemprompt(os.getenv("SYSTEM_PROMPT_PATH", os.getcwd()))
 
+
 def get_env_param(param_name, converter=None, default=None):
     value = os.getenv(param_name)
     if value is None:
@@ -262,9 +263,11 @@ def tags():
     def pull():
         return proxy.pull()
 
+
 ollama_proxy = OllamaProxy(os.getenv("OLLAMA_URL", "http://localhost:11434"))
 register_ollama_routes(app, ollama_proxy)
 
+
 @app.before_request
 def before_request():
     logger.info(f"Request {request.method} {request.path} from {request.remote_addr}")
@@ -441,9 +444,7 @@ def generate():
 
 
 def handle_standard_response(
-    config: QueryPipelineConfig,
-    query: str,
-    conversation: list
+    config: QueryPipelineConfig, query: str, conversation: list
 ) -> Response:
     rag = RAGQueryPipeline(config=config)
 
@@ -451,9 +452,7 @@ def handle_standard_response(
     result = None
     try:
         result = rag.run_query(
-            query=query,
-            conversation=conversation,
-            print_response=False
+            query=query, conversation=conversation, print_response=False
         )
 
         if result:
@@ -474,30 +473,18 @@ def handle_standard_response(
                 "prompt_eval_count": 26,
                 "prompt_eval_duration": 342546000,
                 "eval_count": 282,
-                "eval_duration": 4535599000
+                "eval_duration": 4535599000,
             }
 
     except Exception as e:
         success = False
         logger.error(f"Error in RAG pipeline: {e}", exc_info=True)
 
-    if success and result:
-        latest_message = {
-            "role": "assistant",
-            "content": result["llm"]["replies"][0],
-            "timestamp": datetime.now().isoformat(),
-        }
-        conversation.append(latest_message)
-
     return jsonify(
-        {
-            "success": success,
-            "timestamp": datetime.now().isoformat(),
-            "result": result,
-            "messages": conversation,
-        }
+        {"success": success, "timestamp": datetime.now().isoformat(), "result": result}
     )
 
+
 @app.route("/", methods=["GET"])
 @app.route("/health", methods=["GET"])
 def health_check():
@@ -511,6 +498,7 @@ def health_check():
         }
     )
 
+
 @app.errorhandler(404)
 def not_found_error(error):
     return "", 404