Refactor LLM manager so that users can configure an LLM provider per feature #3278

Merged (13 commits) on Feb 4, 2025
56 changes: 31 additions & 25 deletions data/timesketch.conf
@@ -353,36 +353,42 @@ CONTEXT_LINKS_CONFIG_PATH = '/etc/timesketch/context_links.yaml'

# LLM provider configs
LLM_PROVIDER_CONFIGS = {
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
'ollama': {
'server_url': 'http://localhost:11434',
'model': 'gemma:7b',
},
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
# Configure an LLM provider for a specific LLM-enabled feature; otherwise the
# default provider will be used.
# Supported LLM Providers:
# - ollama: Self-hosted, open-source.
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
# - vertexai: Google Cloud Vertex AI. Requires Google Cloud Project.
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
#
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# pip3 install google-generativeai
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
# - aistudio: Google AI Studio (API key). Get API key from Google AI Studio website.
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# $ pip3 install google-generativeai
'nl2q': {
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
},
'default': {
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
},
}
}


# LLM nl2q configuration
DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q'
LLM_PROVIDER = ''
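The new layout keys `LLM_PROVIDER_CONFIGS` by feature name, with a `default` entry used as a fallback when a feature has no dedicated provider. The sketch below shows that lookup as a hypothetical standalone helper, not the actual manager code:

```python
from typing import Tuple


def resolve_provider_config(llm_configs: dict, feature_name: str) -> Tuple[str, dict]:
    """Return (provider_name, provider_config) for a feature, falling back to 'default'."""
    feature_conf = llm_configs.get(feature_name) or llm_configs.get("default")
    if not feature_conf or len(feature_conf) != 1:
        raise ValueError(f"No valid LLM provider configured for feature '{feature_name}'")
    provider_name = next(iter(feature_conf))
    return provider_name, feature_conf[provider_name]


configs = {
    "nl2q": {"vertexai": {"model": "gemini-1.5-flash-001", "project_id": ""}},
    "default": {"aistudio": {"api_key": "", "model": "gemini-2.0-flash-exp"}},
}
print(resolve_provider_config(configs, "nl2q"))       # ('vertexai', {...})
print(resolve_provider_config(configs, "summarize"))  # falls back to 'default'
```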
29 changes: 13 additions & 16 deletions timesketch/api/v1/resources/nl2q.py
@@ -178,35 +178,33 @@ def post(self, sketch_id):
Returns:
JSON representing the LLM prediction.
"""
llm_provider = current_app.config.get("LLM_PROVIDER", "")
if not llm_provider:
logger.error("No LLM provider was defined in the main configuration file")
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)
form = request.json
if not form:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"No JSON data provided",
)
abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided")

if "question" not in form:
abort(HTTP_STATUS_CODE_BAD_REQUEST, "The 'question' parameter is required!")

llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS")
if not llm_configs:
logger.error("No LLM provider configuration defined.")
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"The 'question' parameter is required!",
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)

question = form.get("question")
prompt = self.build_prompt(question, sketch_id)

result_schema = {
"name": "AI generated search query",
"query_string": None,
"error": None,
}

feature_name = "nl2q"
try:
llm = manager.LLMManager().get_provider(llm_provider)()
llm = manager.LLMManager.create_provider(feature_name=feature_name)
except Exception as e: # pylint: disable=broad-except
logger.error("Error LLM Provider: {}".format(e))
result_schema["error"] = (
@@ -223,7 +221,6 @@ def post(self, sketch_id):
"Please try again later!"
)
return jsonify(result_schema)
# The model sometimes outputs triple backticks that need to be removed.
result_schema["query_string"] = prediction.strip("```")

result_schema["query_string"] = prediction.strip("```")
return jsonify(result_schema)
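With this change the resource no longer reads `LLM_PROVIDER` directly; it asks the manager for a provider bound to the feature name. A hedged usage sketch (run inside an application context, using the classmethod as called above):

```python
from timesketch.lib.llms import manager

# Resolve the provider configured for the "nl2q" feature (or the "default"
# entry) and run a prediction; signature as used in the resource above.
llm = manager.LLMManager.create_provider(feature_name="nl2q")
prediction = llm.generate("List all failed ssh logins")
query_string = prediction.strip("```")  # models sometimes wrap output in backticks
```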
32 changes: 29 additions & 3 deletions timesketch/api/v1/resources/settings.py
@@ -13,11 +13,13 @@
# limitations under the License.
"""System settings."""

from flask import current_app
from flask import jsonify
import logging
from flask import current_app, jsonify
from flask_restful import Resource
from flask_login import login_required

logger = logging.getLogger("timesketch.system_settings")


class SystemSettingsResource(Resource):
"""Resource to get system settings."""
@@ -30,10 +32,34 @@ def get(self):
JSON object with system settings.
"""
# Settings from timesketch.conf to expose to the frontend clients.
settings_to_return = ["LLM_PROVIDER", "DFIQ_ENABLED"]
settings_to_return = ["DFIQ_ENABLED"]
result = {}

for setting in settings_to_return:
result[setting] = current_app.config.get(setting)

# Derive the default LLM provider from the new configuration.
# Expecting the "default" config to be a dict with exactly one key:
# the provider name.
llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {})
default_provider = None
default_conf = llm_configs.get("default")
if default_conf and isinstance(default_conf, dict) and len(default_conf) == 1:
default_provider = next(iter(default_conf))
result["LLM_PROVIDER"] = default_provider

# TODO(mvd): Remove by 2025/06/01 once all users have updated their config.
old_llm_provider = current_app.config.get("LLM_PROVIDER")
if (
old_llm_provider and "default" not in llm_configs
): # Basic check for old config
warning_message = (
"Your LLM configuration in timesketch.conf is outdated and may cause "
"issues with LLM features. "
"Please update your LLM_PROVIDER_CONFIGS section to the new format. "
"Refer to the documentation for the updated configuration structure."
)
result["llm_config_warning"] = warning_message
logger.warning(warning_message)

return jsonify(result)
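For reference, the derived values surface to the frontend roughly as follows (illustrative payloads, not taken from the PR):

```python
# Updated config: the "default" entry is {"aistudio": {...}}, so the provider
# name can be derived from its single key.
{"DFIQ_ENABLED": False, "LLM_PROVIDER": "aistudio"}

# Legacy config: LLM_PROVIDER is still set but no "default" entry exists, so no
# provider can be derived and the deprecation warning is included instead.
{"DFIQ_ENABLED": False, "LLM_PROVIDER": None, "llm_config_warning": "Your LLM configuration in timesketch.conf is outdated ..."}
```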
27 changes: 17 additions & 10 deletions timesketch/api/v1/resources_test.py
@@ -1191,10 +1191,10 @@ class TestNl2qResource(BaseTest):

resource_url = "/api/v1/sketches/1/nl2q/"

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
def test_nl2q_prompt(self, mock_aggregator, mock_create_provider):
"""Test the prompt is created correctly."""

self.login()
@@ -1207,7 +1207,7 @@ def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
mock_aggregator.return_value = (mock_AggregationResult, {})
mock_llm = mock.Mock()
mock_llm.generate.return_value = "LLM generated query"
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
mock_create_provider.return_value = mock_llm
response = self.client.post(
self.resource_url,
data=json.dumps(data),
@@ -1313,7 +1313,8 @@ def test_nl2q_no_question(self):
def test_nl2q_wrong_llm_provider(self, mock_aggregator):
"""Test nl2q with llm provider that does not exist."""

self.app.config["LLM_PROVIDER"] = "DoesNotExists"
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}}
self.login()
self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
@@ -1333,9 +1334,10 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator):

@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_no_llm_provider(self):
"""Test nl2q with no llm provider configured."""
"""Test nl2q with no LLM provider configured."""

del self.app.config["LLM_PROVIDER"]
if "LLM_PROVIDER_CONFIGS" in self.app.config:
del self.app.config["LLM_PROVIDER_CONFIGS"]
self.login()
data = dict(question="Question for LLM?")
response = self.client.post(
@@ -1371,10 +1373,10 @@ def test_nl2q_no_permission(self):
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN)

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager):
def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider):
"""Test nl2q with llm error."""

self.login()
@@ -1387,13 +1389,15 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager):
mock_aggregator.return_value = (mock_AggregationResult, {})
mock_llm = mock.Mock()
mock_llm.generate.side_effect = Exception("Test exception")
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
mock_create_provider.return_value = mock_llm
response = self.client.post(
self.resource_url,
data=json.dumps(data),
content_type="application/json",
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
self.assertEqual(
response.status_code, HTTP_STATUS_CODE_OK
) # Still expect 200 OK with error in JSON
data = json.loads(response.get_data(as_text=True))
self.assertIsNotNone(data.get("error"))

@@ -1405,6 +1409,9 @@ class SystemSettingsResourceTest(BaseTest):

def test_system_settings_resource(self):
"""Authenticated request to get system settings."""
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"test": {}}}
self.app.config["DFIQ_ENABLED"] = False

self.login()
response = self.client.get(self.resource_url)
expected_response = {"DFIQ_ENABLED": False, "LLM_PROVIDER": "test"}
62 changes: 27 additions & 35 deletions timesketch/lib/llms/interface.py
@@ -16,8 +16,6 @@
import string
from typing import Optional

from flask import current_app

DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 0.1
DEFAULT_TOP_K = 1
@@ -27,12 +25,19 @@


class LLMProvider:
"""Base class for LLM providers."""
"""
Base class for LLM providers.

The provider is instantiated with a configuration dictionary that
was extracted (by the manager) from timesketch.conf.
Subclasses should override the NAME attribute.
"""

NAME = "name"

def __init__(
self,
config: dict,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
@@ -43,47 +48,32 @@ def __init__(
"""Initialize the LLM provider.

Args:
temperature: The temperature to use for the response.
top_p: The top_p to use for the response.
top_k: The top_k to use for the response.
max_output_tokens: The maximum number of output tokens to generate.
stream: Whether to stream the response.
location: The cloud location/region to use for the provider.
config: A dictionary of provider-specific configuration options.
temperature: Temperature setting for text generation.
top_p: Top probability (p) value used for generation.
top_k: Top-k value used for generation.
max_output_tokens: Maximum number of tokens to generate in the output.
stream: Whether to enable streaming of the generated content.
location: An optional location parameter for the provider.

Attributes:
config: The configuration for the LLM provider.

Raises:
Exception: If the LLM provider is not configured.
"""
config = {}
config["temperature"] = temperature
config["top_p"] = top_p
config["top_k"] = top_k
config["max_output_tokens"] = max_output_tokens
config["stream"] = stream
config["location"] = location

# Load the LLM provider config from the Flask app config
config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get(
self.NAME
)
if not config_from_flask:
raise Exception(f"{self.NAME} config not found")

config.update(config_from_flask)
self.config = config
self.config = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"max_output_tokens": max_output_tokens,
"stream": stream,
"location": location,
}
self.config.update(config)

def prompt_from_template(self, template: str, kwargs: dict) -> str:
"""Format a prompt from a template.

Args:
template: The template to format.
kwargs: The keyword arguments to format the template with.

Returns:
The formatted prompt.
"""
"""Format a prompt from a template."""
formatter = string.Formatter()
return formatter.format(template, **kwargs)

@@ -97,5 +87,7 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:

Returns:
The generated response.

Subclasses must override this method.
"""
raise NotImplementedError()
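Because the base class no longer reaches into the Flask config itself, a provider subclass only has to accept the per-feature config dict the manager hands it and implement `generate()`. A minimal hypothetical subclass (not part of Timesketch) to illustrate the contract:

```python
from timesketch.lib.llms.interface import LLMProvider


class EchoProvider(LLMProvider):
    """Toy provider used for illustration only."""

    NAME = "echo"

    def generate(self, prompt: str, response_schema=None) -> str:
        # A real provider would call its model API here, using values such as
        # self.config["model"] or self.config["temperature"].
        return f"echo: {prompt}"


provider = EchoProvider(config={"model": "demo-model"}, temperature=0.2)
print(provider.generate("hello"))
```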