Refactor LLM manager so that users can configure an LLM provider per feature #3278

Merged · 13 commits · Feb 4, 2025
56 changes: 31 additions & 25 deletions data/timesketch.conf
@@ -353,36 +353,42 @@ CONTEXT_LINKS_CONFIG_PATH = '/etc/timesketch/context_links.yaml'

# LLM provider configs
LLM_PROVIDER_CONFIGS = {
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
'ollama': {
'server_url': 'http://localhost:11434',
'model': 'gemma:7b',
},
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
# Configure an LLM provider for a specific LLM-enabled feature; otherwise the
# default provider will be used.
# Supported LLM Providers:
# - ollama: Self-hosted, open-source.
# To use the Ollama provider you need to download and run an Ollama server.
# See instructions at: https://ollama.ai/
# - vertexai: Google Cloud Vertex AI. Requires Google Cloud Project.
# To use the Vertex AI provider you need to:
# 1. Create and export a Service Account Key from the Google Cloud Console.
# 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path
# to your service account private key file by adding it to the docker-compose.yml
# under environment:
# GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/<key_file>.json
# 3. Install the python libraries: $ pip3 install google-cloud-aiplatform
#
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# pip3 install google-generativeai
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
# IMPORTANT: Private keys must be kept secret. If you expose your private key it is
# recommended to revoke it immediately from the Google Cloud Console.
# - aistudio: Google AI Studio (API key). Get API key from Google AI Studio website.
# To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/
# $ pip3 install google-generativeai
'nl2q': {
'vertexai': {
'model': 'gemini-1.5-flash-001',
'project_id': '',
},
},
'default': {
'aistudio': {
'api_key': '',
'model': 'gemini-2.0-flash-exp',
},
}
}


# LLM nl2q configuration
DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv'
PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q'
EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q'
LLM_PROVIDER = ''
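
For reference, a minimal sketch of the per-feature layout this PR introduces, assuming a hypothetical deployment that routes nl2q to Vertex AI and everything else to a local Ollama server (the project id is a placeholder, not a shipped default):

LLM_PROVIDER_CONFIGS = {
    'nl2q': {
        'vertexai': {
            'model': 'gemini-1.5-flash-001',
            'project_id': 'my-gcp-project',  # hypothetical project id
        },
    },
    'default': {
        'ollama': {
            'server_url': 'http://localhost:11434',
            'model': 'gemma:7b',
        },
    },
}

Each feature key maps to a dict with exactly one provider name; the "default" entry covers any feature without its own key.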
3 changes: 2 additions & 1 deletion timesketch/api/v1/resources/nl2q.py
@@ -205,8 +205,9 @@ def post(self, sketch_id):
"query_string": None,
"error": None,
}
feature_name = "nl2q"
try:
llm = manager.LLMManager().get_provider(llm_provider)()
llm = manager.LLMManager.create_provider(feature_name=feature_name)
except Exception as e: # pylint: disable=broad-except
logger.error("Error LLM Provider: {}".format(e))
result_schema["error"] = (
15 changes: 12 additions & 3 deletions timesketch/api/v1/resources/settings.py
@@ -13,8 +13,7 @@
# limitations under the License.
"""System settings."""

from flask import current_app
from flask import jsonify
from flask import current_app, jsonify
from flask_restful import Resource
from flask_login import login_required

@@ -30,10 +29,20 @@ def get(self):
JSON object with system settings.
"""
# Settings from timesketch.conf to expose to the frontend clients.
settings_to_return = ["LLM_PROVIDER", "DFIQ_ENABLED"]
settings_to_return = ["DFIQ_ENABLED"]
result = {}

for setting in settings_to_return:
result[setting] = current_app.config.get(setting)

# Derive the default LLM provider from the new configuration.
# Expecting the "default" config to be a dict with exactly one key:
# the provider name.
llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {})
default_provider = None
default_conf = llm_configs.get("default")
if default_conf and isinstance(default_conf, dict) and len(default_conf) == 1:
default_provider = next(iter(default_conf))
result["LLM_PROVIDER"] = default_provider

return jsonify(result)
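
To illustrate the derivation above, a short sketch (the config values are hypothetical):

# With a single-provider "default" entry...
current_app.config["LLM_PROVIDER_CONFIGS"] = {
    "default": {"aistudio": {"api_key": "", "model": "gemini-2.0-flash-exp"}}
}
# ...the settings resource reports {"LLM_PROVIDER": "aistudio"} alongside
# DFIQ_ENABLED. A missing or multi-key "default" entry yields
# {"LLM_PROVIDER": null} in the JSON response.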
22 changes: 15 additions & 7 deletions timesketch/api/v1/resources_test.py
@@ -1191,12 +1191,11 @@ class TestNl2qResource(BaseTest):

resource_url = "/api/v1/sketches/1/nl2q/"

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
def test_nl2q_prompt(self, mock_aggregator, mock_create_provider):
"""Test the prompt is created correctly."""

self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
@@ -1205,9 +1204,13 @@ def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager):
{"data_type": "test:data_type:2"},
]
mock_aggregator.return_value = (mock_AggregationResult, {})

# Create a mock provider that returns the expected query.
mock_llm = mock.Mock()
mock_llm.generate.return_value = "LLM generated query"
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
# Patch create_provider to return our mock provider.
mock_create_provider.return_value = mock_llm

response = self.client.post(
self.resource_url,
data=json.dumps(data),
@@ -1313,7 +1316,7 @@ def test_nl2q_no_question(self):
def test_nl2q_wrong_llm_provider(self, mock_aggregator):
"""Test nl2q with llm provider that does not exist."""

self.app.config["LLM_PROVIDER"] = "DoesNotExists"
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}}
self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
@@ -1333,9 +1336,11 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator):

@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_no_llm_provider(self):
"""Test nl2q with no llm provider configured."""
"""Test nl2q with no LLM provider configured."""
if "LLM_PROVIDER_CONFIGS" in self.app.config:
del self.app.config["LLM_PROVIDER_CONFIGS"]
self.app.config["DFIQ_ENABLED"] = False

del self.app.config["LLM_PROVIDER"]
self.login()
data = dict(question="Question for LLM?")
response = self.client.post(
@@ -1405,6 +1410,9 @@ class SystemSettingsResourceTest(BaseTest):

def test_system_settings_resource(self):
"""Authenticated request to get system settings."""
self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"test": {}}}
self.app.config["DFIQ_ENABLED"] = False

self.login()
response = self.client.get(self.resource_url)
expected_response = {"DFIQ_ENABLED": False, "LLM_PROVIDER": "test"}
62 changes: 27 additions & 35 deletions timesketch/lib/llms/interface.py
@@ -16,8 +16,6 @@
import string
from typing import Optional

from flask import current_app

DEFAULT_TEMPERATURE = 0.1
DEFAULT_TOP_P = 0.1
DEFAULT_TOP_K = 1
@@ -27,12 +25,19 @@


class LLMProvider:
"""Base class for LLM providers."""
"""
Base class for LLM providers.

The provider is instantiated with a configuration dictionary that
was extracted (by the manager) from timesketch.conf.
Subclasses should override the NAME attribute.
"""

NAME = "name"

def __init__(
self,
config: dict,
temperature: float = DEFAULT_TEMPERATURE,
top_p: float = DEFAULT_TOP_P,
top_k: int = DEFAULT_TOP_K,
@@ -43,47 +48,32 @@ def __init__(
"""Initialize the LLM provider.

Args:
temperature: The temperature to use for the response.
top_p: The top_p to use for the response.
top_k: The top_k to use for the response.
max_output_tokens: The maximum number of output tokens to generate.
stream: Whether to stream the response.
location: The cloud location/region to use for the provider.
config: A dictionary of provider-specific configuration options.
temperature: Temperature setting for text generation.
top_p: Top probability (p) value used for generation.
top_k: Top-k value used for generation.
max_output_tokens: Maximum number of tokens to generate in the output.
stream: Whether to enable streaming of the generated content.
location: An optional location parameter for the provider.

Attributes:
config: The configuration for the LLM provider.

Raises:
Exception: If the LLM provider is not configured.
"""
config = {}
config["temperature"] = temperature
config["top_p"] = top_p
config["top_k"] = top_k
config["max_output_tokens"] = max_output_tokens
config["stream"] = stream
config["location"] = location

# Load the LLM provider config from the Flask app config
config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get(
self.NAME
)
if not config_from_flask:
raise Exception(f"{self.NAME} config not found")

config.update(config_from_flask)
self.config = config
self.config = {
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"max_output_tokens": max_output_tokens,
"stream": stream,
"location": location,
}
self.config.update(config)

def prompt_from_template(self, template: str, kwargs: dict) -> str:
"""Format a prompt from a template.

Args:
template: The template to format.
kwargs: The keyword arguments to format the template with.

Returns:
The formatted prompt.
"""
"""Format a prompt from a template."""
formatter = string.Formatter()
return formatter.format(template, **kwargs)

@@ -97,5 +87,7 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:

Returns:
The generated response.

Subclasses must override this method.
"""
raise NotImplementedError()
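
For illustration, a hedged sketch of a concrete subclass under the refactored interface (EchoProvider and its behavior are hypothetical; the shipped providers live under timesketch/lib/llms/):

from typing import Optional

from timesketch.lib.llms import interface


class EchoProvider(interface.LLMProvider):
    """Toy provider that echoes prompts; illustration only."""

    NAME = "echo"

    def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str:
        # self.config merges the generation defaults (temperature, top_p, ...)
        # with the per-feature dict the manager extracted from
        # LLM_PROVIDER_CONFIGS, so provider options are plain lookups.
        model = self.config.get("model", "unknown-model")
        return f"[{model}] {prompt}"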
45 changes: 35 additions & 10 deletions timesketch/lib/llms/manager.py
@@ -13,12 +13,23 @@
# limitations under the License.
"""This file contains a class for managing Large Language Model (LLM) providers."""

from flask import current_app
from timesketch.lib.llms.interface import LLMProvider


class LLMManager:
"""The manager for LLM providers."""

_class_registry = {}

@classmethod
def register_provider(cls, provider_class: type) -> None:
"""Register a provider class."""
provider_name = provider_class.NAME.lower()
if provider_name in cls._class_registry:
raise ValueError(f"Provider {provider_class.NAME} already registered")
cls._class_registry[provider_name] = provider_class

@classmethod
def get_providers(cls):
"""Get all registered providers.
@@ -48,19 +59,33 @@ def get_provider(cls, provider_name: str) -> type:
return provider_class

@classmethod
def register_provider(cls, provider_class: type) -> None:
"""Register a provider.
def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider:
"""
Create an instance of the provider for the given feature.

Args:
provider_class: The provider class to register.
If a configuration exists for the feature in
current_app.config["LLM_PROVIDER_CONFIGS"], use it; otherwise,
fall back to the configuration under the "default" key.

Raises:
ValueError: If the provider is already registered.
The configuration is expected to be a dict with exactly one key:
the provider name.
"""
provider_name = provider_class.NAME.lower()
if provider_name in cls._class_registry:
raise ValueError(f"Provider {provider_class.NAME} already registered")
cls._class_registry[provider_name] = provider_class
llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {})
if feature_name and feature_name in llm_configs:
config_mapping = llm_configs[feature_name]
else:
config_mapping = llm_configs.get("default")

if not config_mapping or len(config_mapping) != 1:
raise ValueError(
"Configuration for the feature must specify exactly one provider."
)

provider_name = next(iter(config_mapping))
provider_config = config_mapping[provider_name]

provider_class = cls.get_provider(provider_name)
return provider_class(config=provider_config, **kwargs)

@classmethod
def clear_registration(cls):
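
A hedged usage sketch of the resolution order create_provider implements; the feature name mirrors the nl2q resource above:

from timesketch.lib.llms import manager

# Must run inside a Flask app/request context, since create_provider
# reads current_app.config.
llm = manager.LLMManager.create_provider(feature_name="nl2q")
# Uses LLM_PROVIDER_CONFIGS["nl2q"] when present, otherwise the "default"
# entry; raises ValueError unless the chosen entry names exactly one provider.
response = llm.generate("Convert this question into a search query")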
13 changes: 7 additions & 6 deletions timesketch/lib/llms/manager_test.py
@@ -18,18 +18,19 @@


class MockProvider:
"""A mock LLM provider."""

NAME = "mock"

def generate_text(self) -> str:
"""Generate text."""
return "This is a mock LLM provider."
def __init__(self, config: dict, **kwargs):
self.config = config

def generate(self, prompt: str, response_schema: dict = None) -> str:
return "mock response"


class TestLLMManager(BaseTest):
"""Tests for the functionality of the manager module."""

# Clear the registry and register the mock provider.
manager.LLMManager.clear_registration()
manager.LLMManager.register_provider(MockProvider)

@@ -51,7 +52,7 @@ def test_get_provider(self):
self.assertRaises(KeyError, manager.LLMManager.get_provider, "no_such_provider")

def test_register_provider(self):
"""Test so we raise KeyError when provider is already registered."""
"""Test that registering a provider that is already registered raises ValueError."""
self.assertRaises(
ValueError, manager.LLMManager.register_provider, MockProvider
)
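
Not part of this diff, but a hedged sketch of how create_provider could be exercised in this module, assuming BaseTest exposes a Flask app as self.app as the resource tests do:

def test_create_provider(self):
    """Sketch: create_provider falls back to the "default" entry."""
    self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"mock": {"model": "x"}}}
    with self.app.app_context():
        provider = manager.LLMManager.create_provider()
    self.assertIsInstance(provider, MockProvider)
    self.assertEqual(provider.config, {"model": "x"})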
1 change: 1 addition & 0 deletions timesketch/lib/testlib.py
@@ -81,6 +81,7 @@ class TestConfig(object):
INTELLIGENCE_TAG_METADATA = "./data/intelligence_tag_metadata.yaml"
CONTEXT_LINKS_CONFIG_PATH = "./tests/test_events/mock_context_links.yaml"
LLM_PROVIDER = "test"
LLM_PROVIDER_CONFIGS = {"default": {"test": "test"}}
DFIQ_ENABLED = False
DATA_TYPES_PATH = "./test_data/nl2q/test_data_types.csv"
PROMPT_NL2Q = "./test_data/nl2q/test_prompt_nl2q"