From 45530f7d4bbff9c0408b5e556619baade590b71e Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 18:14:09 +0100 Subject: [PATCH 01/13] Refactor LLM manager so that users can configure an LLM provider per feature --- data/timesketch.conf | 56 ++++++++++++---------- timesketch/api/v1/resources/nl2q.py | 3 +- timesketch/api/v1/resources/settings.py | 14 ++++-- timesketch/api/v1/resources_test.py | 11 +++-- timesketch/lib/llms/interface.py | 62 +++++++++++-------------- timesketch/lib/llms/manager.py | 43 ++++++++++++----- timesketch/lib/testlib.py | 1 + 7 files changed, 111 insertions(+), 79 deletions(-) diff --git a/data/timesketch.conf b/data/timesketch.conf index 6cca7592e1..0952ec2f1d 100644 --- a/data/timesketch.conf +++ b/data/timesketch.conf @@ -353,36 +353,42 @@ CONTEXT_LINKS_CONFIG_PATH = '/etc/timesketch/context_links.yaml' # LLM provider configs LLM_PROVIDER_CONFIGS = { - # To use the Ollama provider you need to download and run an Ollama server. - # See instructions at: https://ollama.ai/ - 'ollama': { - 'server_url': 'http://localhost:11434', - 'model': 'gemma:7b', - }, - # To use the Vertex AI provider you need to: - # 1. Create and export a Service Account Key from the Google Cloud Console. - # 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path - # to your service account private key file by adding it to the docker-compose.yml - # under environment: - # GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/.json - # 3. Install the python libraries: $ pip3 install google-cloud-aiplatform + # Configure a LLM provider for a specific LLM enabled feature, or the + # default provider will be used. + # Supported LLM Providers: + # - ollama: Self-hosted, open-source. + # To use the Ollama provider you need to download and run an Ollama server. + # See instructions at: https://ollama.ai/ + # - vertexai: Google Cloud Vertex AI. Requires Google Cloud Project. + # To use the Vertex AI provider you need to: + # 1. Create and export a Service Account Key from the Google Cloud Console. + # 2. Set the GOOGLE_APPLICATION_CREDENTIALS environment variable to the full path + # to your service account private key file by adding it to the docker-compose.yml + # under environment: + # GOOGLE_APPLICATION_CREDENTIALS=/usr/local/src/timesketch/.json + # 3. Install the python libraries: $ pip3 install google-cloud-aiplatform # - # IMPORTANT: Private keys must be kept secret. If you expose your private key it is - # recommended to revoke it immediately from the Google Cloud Console. - 'vertexai': { - 'model': 'gemini-1.5-flash-001', - 'project_id': '', - }, - # To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/ - # pip3 install google-generativeai - 'aistudio': { - 'api_key': '', - 'model': 'gemini-2.0-flash-exp', + # IMPORTANT: Private keys must be kept secret. If you expose your private key it is + # recommended to revoke it immediately from the Google Cloud Console. + # - aistudio: Google AI Studio (API key). Get API key from Google AI Studio website. + # To use Google's AI Studio simply obtain an API key from https://aistudio.google.com/ + # $ pip3 install google-generativeai + 'nl2q': { + 'vertexai': { + 'model': 'gemini-1.5-flash-001', + 'project_id': '', + }, }, + 'default': { + 'aistudio': { + 'api_key': '', + 'model': 'gemini-2.0-flash-exp', + }, + } } + # LLM nl2q configuration DATA_TYPES_PATH = '/etc/timesketch/nl2q/data_types.csv' PROMPT_NL2Q = '/etc/timesketch/nl2q/prompt_nl2q' EXAMPLES_NL2Q = '/etc/timesketch/nl2q/examples_nl2q' -LLM_PROVIDER = '' diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index 23a122a5d3..b6928ff55a 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -205,8 +205,9 @@ def post(self, sketch_id): "query_string": None, "error": None, } + feature_name = "nl2q" try: - llm = manager.LLMManager().get_provider(llm_provider)() + llm = manager.LLMManager.create_provider(feature_name=feature_name) except Exception as e: # pylint: disable=broad-except logger.error("Error LLM Provider: {}".format(e)) result_schema["error"] = ( diff --git a/timesketch/api/v1/resources/settings.py b/timesketch/api/v1/resources/settings.py index 2f8dd1c311..3f36ea9bfd 100644 --- a/timesketch/api/v1/resources/settings.py +++ b/timesketch/api/v1/resources/settings.py @@ -13,8 +13,7 @@ # limitations under the License. """System settings.""" -from flask import current_app -from flask import jsonify +from flask import current_app, jsonify from flask_restful import Resource from flask_login import login_required @@ -30,10 +29,19 @@ def get(self): JSON object with system settings. """ # Settings from timesketch.conf to expose to the frontend clients. - settings_to_return = ["LLM_PROVIDER", "DFIQ_ENABLED"] + settings_to_return = ["DFIQ_ENABLED"] result = {} for setting in settings_to_return: result[setting] = current_app.config.get(setting) + # Derive the default LLM provider from the new configuration. + # Expecting the "default" config to be a dict with exactly one key: the provider name. + llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) + default_provider = None + default_conf = llm_configs.get("default") + if default_conf and isinstance(default_conf, dict) and len(default_conf) == 1: + default_provider = next(iter(default_conf)) + result["LLM_PROVIDER"] = default_provider + return jsonify(result) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 236b8e9bb1..f844cd787a 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1313,7 +1313,7 @@ def test_nl2q_no_question(self): def test_nl2q_wrong_llm_provider(self, mock_aggregator): """Test nl2q with llm provider that does not exist.""" - self.app.config["LLM_PROVIDER"] = "DoesNotExists" + self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}} self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() @@ -1333,9 +1333,11 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator): @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_no_llm_provider(self): - """Test nl2q with no llm provider configured.""" + """Test nl2q with no LLM provider configured.""" + if "LLM_PROVIDER_CONFIGS" in self.app.config: + del self.app.config["LLM_PROVIDER_CONFIGS"] + self.app.config["DFIQ_ENABLED"] = False - del self.app.config["LLM_PROVIDER"] self.login() data = dict(question="Question for LLM?") response = self.client.post( @@ -1405,6 +1407,9 @@ class SystemSettingsResourceTest(BaseTest): def test_system_settings_resource(self): """Authenticated request to get system settings.""" + self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"test": {}}} + self.app.config["DFIQ_ENABLED"] = False + self.login() response = self.client.get(self.resource_url) expected_response = {"DFIQ_ENABLED": False, "LLM_PROVIDER": "test"} diff --git a/timesketch/lib/llms/interface.py b/timesketch/lib/llms/interface.py index a54699fac8..cef1d73d0f 100644 --- a/timesketch/lib/llms/interface.py +++ b/timesketch/lib/llms/interface.py @@ -16,8 +16,6 @@ import string from typing import Optional -from flask import current_app - DEFAULT_TEMPERATURE = 0.1 DEFAULT_TOP_P = 0.1 DEFAULT_TOP_K = 1 @@ -27,12 +25,19 @@ class LLMProvider: - """Base class for LLM providers.""" + """ + Base class for LLM providers. + + The provider is instantiated with a configuration dictionary that + was extracted (by the manager) from timesketch.conf. + Subclasses should override the NAME attribute. + """ NAME = "name" def __init__( self, + config: dict, temperature: float = DEFAULT_TEMPERATURE, top_p: float = DEFAULT_TOP_P, top_k: int = DEFAULT_TOP_K, @@ -43,12 +48,13 @@ def __init__( """Initialize the LLM provider. Args: - temperature: The temperature to use for the response. - top_p: The top_p to use for the response. - top_k: The top_k to use for the response. - max_output_tokens: The maximum number of output tokens to generate. - stream: Whether to stream the response. - location: The cloud location/region to use for the provider. + config: A dictionary of provider-specific configuration options. + temperature: Temperature setting for text generation. + top_p: Top probability (p) value used for generation. + top_k: Top-k value used for generation. + max_output_tokens: Maximum number of tokens to generate in the output. + stream: Whether to enable streaming of the generated content. + location: An optional location parameter for the provider. Attributes: config: The configuration for the LLM provider. @@ -56,34 +62,18 @@ def __init__( Raises: Exception: If the LLM provider is not configured. """ - config = {} - config["temperature"] = temperature - config["top_p"] = top_p - config["top_k"] = top_k - config["max_output_tokens"] = max_output_tokens - config["stream"] = stream - config["location"] = location - - # Load the LLM provider config from the Flask app config - config_from_flask = current_app.config.get("LLM_PROVIDER_CONFIGS").get( - self.NAME - ) - if not config_from_flask: - raise Exception(f"{self.NAME} config not found") - - config.update(config_from_flask) - self.config = config + self.config = { + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "max_output_tokens": max_output_tokens, + "stream": stream, + "location": location, + } + self.config.update(config) def prompt_from_template(self, template: str, kwargs: dict) -> str: - """Format a prompt from a template. - - Args: - template: The template to format. - kwargs: The keyword arguments to format the template with. - - Returns: - The formatted prompt. - """ + """Format a prompt from a template.""" formatter = string.Formatter() return formatter.format(template, **kwargs) @@ -97,5 +87,7 @@ def generate(self, prompt: str, response_schema: Optional[dict] = None) -> str: Returns: The generated response. + + Subclasses must override this method. """ raise NotImplementedError() diff --git a/timesketch/lib/llms/manager.py b/timesketch/lib/llms/manager.py index a3d6b00614..4d0ce28744 100644 --- a/timesketch/lib/llms/manager.py +++ b/timesketch/lib/llms/manager.py @@ -13,12 +13,19 @@ # limitations under the License. """This file contains a class for managing Large Language Model (LLM) providers.""" +from flask import current_app +from timesketch.lib.llms.interface import LLMProvider -class LLMManager: - """The manager for LLM providers.""" +class LLMManager: _class_registry = {} + @classmethod + def register_provider(cls, provider_class: type) -> None: + """Register a provider class.""" + provider_name = provider_class.NAME.lower() + cls._class_registry[provider_name] = provider_class + @classmethod def get_providers(cls): """Get all registered providers. @@ -48,19 +55,31 @@ def get_provider(cls, provider_name: str) -> type: return provider_class @classmethod - def register_provider(cls, provider_class: type) -> None: - """Register a provider. + def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: + """ + Create an instance of the provider for the given feature. - Args: - provider_class: The provider class to register. + If a configuration exists for the feature in current_app.config["LLM_PROVIDER_CONFIGS"], + use it; otherwise, fall back to the configuration under the "default" key. - Raises: - ValueError: If the provider is already registered. + The configuration is expected to be a dict with exactly one key: the provider name. """ - provider_name = provider_class.NAME.lower() - if provider_name in cls._class_registry: - raise ValueError(f"Provider {provider_class.NAME} already registered") - cls._class_registry[provider_name] = provider_class + llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) + if feature_name and feature_name in llm_configs: + config_mapping = llm_configs[feature_name] + else: + config_mapping = llm_configs.get("default") + + if not config_mapping or len(config_mapping) != 1: + raise Exception( + "Configuration for the feature must specify exactly one provider." + ) + + provider_name = next(iter(config_mapping)) + provider_config = config_mapping[provider_name] + + provider_class = cls.get_provider(provider_name) + return provider_class(config=provider_config, **kwargs) @classmethod def clear_registration(cls): diff --git a/timesketch/lib/testlib.py b/timesketch/lib/testlib.py index 628f6b4e0d..7fca3289e3 100644 --- a/timesketch/lib/testlib.py +++ b/timesketch/lib/testlib.py @@ -81,6 +81,7 @@ class TestConfig(object): INTELLIGENCE_TAG_METADATA = "./data/intelligence_tag_metadata.yaml" CONTEXT_LINKS_CONFIG_PATH = "./tests/test_events/mock_context_links.yaml" LLM_PROVIDER = "test" + LLM_PROVIDER_CONFIGS = {"default": {"test": "test"}} DFIQ_ENABLED = False DATA_TYPES_PATH = "./test_data/nl2q/test_data_types.csv" PROMPT_NL2Q = "./test_data/nl2q/test_prompt_nl2q" From bfa1d5f4318fff76f96288b6f6c66ce3ba0a5991 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 18:20:13 +0100 Subject: [PATCH 02/13] Linter fix --- timesketch/api/v1/resources/settings.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/timesketch/api/v1/resources/settings.py b/timesketch/api/v1/resources/settings.py index 3f36ea9bfd..4ca98e2fdc 100644 --- a/timesketch/api/v1/resources/settings.py +++ b/timesketch/api/v1/resources/settings.py @@ -36,7 +36,8 @@ def get(self): result[setting] = current_app.config.get(setting) # Derive the default LLM provider from the new configuration. - # Expecting the "default" config to be a dict with exactly one key: the provider name. + # Expecting the "default" config to be a dict with exactly one key: + # the provider name. llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) default_provider = None default_conf = llm_configs.get("default") From e15e16cf32287dd12701027ad21d8b06dede7715 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 18:33:59 +0100 Subject: [PATCH 03/13] fix manager test --- timesketch/api/v1/resources/settings.py | 2 +- timesketch/lib/llms/manager_test.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/timesketch/api/v1/resources/settings.py b/timesketch/api/v1/resources/settings.py index 4ca98e2fdc..eac0c3b8af 100644 --- a/timesketch/api/v1/resources/settings.py +++ b/timesketch/api/v1/resources/settings.py @@ -36,7 +36,7 @@ def get(self): result[setting] = current_app.config.get(setting) # Derive the default LLM provider from the new configuration. - # Expecting the "default" config to be a dict with exactly one key: + # Expecting the "default" config to be a dict with exactly one key: # the provider name. llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) default_provider = None diff --git a/timesketch/lib/llms/manager_test.py b/timesketch/lib/llms/manager_test.py index 16c0c561a0..72f6746de9 100644 --- a/timesketch/lib/llms/manager_test.py +++ b/timesketch/lib/llms/manager_test.py @@ -18,18 +18,19 @@ class MockProvider: - """A mock LLM provider.""" - NAME = "mock" - def generate_text(self) -> str: - """Generate text.""" - return "This is a mock LLM provider." + def __init__(self, config: dict, **kwargs): + self.config = config + + def generate(self, prompt: str, response_schema: dict = None) -> str: + return "mock response" class TestLLMManager(BaseTest): """Tests for the functionality of the manager module.""" + # Clear the registry and register the mock provider. manager.LLMManager.clear_registration() manager.LLMManager.register_provider(MockProvider) @@ -51,7 +52,7 @@ def test_get_provider(self): self.assertRaises(KeyError, manager.LLMManager.get_provider, "no_such_provider") def test_register_provider(self): - """Test so we raise KeyError when provider is already registered.""" + """Test that registering a provider that is already registered raises ValueError.""" self.assertRaises( ValueError, manager.LLMManager.register_provider, MockProvider ) From 74f4461f96ad4b16ab2c1bdd1300fabd0c4007af Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 18:36:46 +0100 Subject: [PATCH 04/13] fix manager.py error throwing --- timesketch/lib/llms/manager.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/timesketch/lib/llms/manager.py b/timesketch/lib/llms/manager.py index 4d0ce28744..c9f9bda6d3 100644 --- a/timesketch/lib/llms/manager.py +++ b/timesketch/lib/llms/manager.py @@ -24,6 +24,8 @@ class LLMManager: def register_provider(cls, provider_class: type) -> None: """Register a provider class.""" provider_name = provider_class.NAME.lower() + if provider_name in cls._class_registry: + raise ValueError(f"Provider {provider_class.NAME} already registered") cls._class_registry[provider_name] = provider_class @classmethod From eeacfd6b6b2fa1152b87888ece1df6ca5e8c16c8 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 18:38:09 +0100 Subject: [PATCH 05/13] fix manager.py to throw specific error --- timesketch/lib/llms/manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/lib/llms/manager.py b/timesketch/lib/llms/manager.py index c9f9bda6d3..0ceb17878f 100644 --- a/timesketch/lib/llms/manager.py +++ b/timesketch/lib/llms/manager.py @@ -73,7 +73,7 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: config_mapping = llm_configs.get("default") if not config_mapping or len(config_mapping) != 1: - raise Exception( + raise ValueError( "Configuration for the feature must specify exactly one provider." ) From e2abfdbeb11e0b3cc593fb34888f1b78bddd7a49 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 20:38:36 +0100 Subject: [PATCH 06/13] fix unit test --- timesketch/api/v1/resources_test.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index f844cd787a..3b8ea1712f 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1191,12 +1191,11 @@ class TestNl2qResource(BaseTest): resource_url = "/api/v1/sketches/1/nl2q/" - @mock.patch("timesketch.lib.llms.manager.LLMManager") + @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager): + def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): """Test the prompt is created correctly.""" - self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() @@ -1205,9 +1204,13 @@ def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager): {"data_type": "test:data_type:2"}, ] mock_aggregator.return_value = (mock_AggregationResult, {}) + + # Create a mock provider that returns the expected query. mock_llm = mock.Mock() mock_llm.generate.return_value = "LLM generated query" - mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm + # Patch create_provider to return our mock provider. + mock_create_provider.return_value = mock_llm + response = self.client.post( self.resource_url, data=json.dumps(data), From 364ec10fe4647d53c5e1a38c5b45cdae298e77dc Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Mon, 3 Feb 2025 20:41:36 +0100 Subject: [PATCH 07/13] Fix LLMManager doc-string --- timesketch/lib/llms/manager.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/timesketch/lib/llms/manager.py b/timesketch/lib/llms/manager.py index 0ceb17878f..5412abcec6 100644 --- a/timesketch/lib/llms/manager.py +++ b/timesketch/lib/llms/manager.py @@ -18,6 +18,8 @@ class LLMManager: + """The manager for LLM providers.""" + _class_registry = {} @classmethod @@ -61,10 +63,12 @@ def create_provider(cls, feature_name: str = None, **kwargs) -> LLMProvider: """ Create an instance of the provider for the given feature. - If a configuration exists for the feature in current_app.config["LLM_PROVIDER_CONFIGS"], - use it; otherwise, fall back to the configuration under the "default" key. + If a configuration exists for the feature in + current_app.config["LLM_PROVIDER_CONFIGS"], use it; otherwise, + fall back to the configuration under the "default" key. - The configuration is expected to be a dict with exactly one key: the provider name. + The configuration is expected to be a dict with exactly one key: + the provider name. """ llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS", {}) if feature_name and feature_name in llm_configs: From f653145d3bbad888660cf9cf44ca965a017c6c22 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 11:59:18 +0100 Subject: [PATCH 08/13] Fix tests --- timesketch/api/v1/resources/nl2q.py | 38 ++++++++++------------------- timesketch/api/v1/resources_test.py | 19 +++++++-------- 2 files changed, 22 insertions(+), 35 deletions(-) diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index b6928ff55a..d6f07362a8 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -170,45 +170,34 @@ def concatenate_values(self, group): @login_required def post(self, sketch_id): - """Handles POST request to the resource. - - Args: - sketch_id: Sketch ID. - - Returns: - JSON representing the LLM prediction. - """ - llm_provider = current_app.config.get("LLM_PROVIDER", "") - if not llm_provider: - logger.error("No LLM provider was defined in the main configuration file") - abort( - HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, - "No LLM provider was defined in the main configuration file", - ) form = request.json if not form: - abort( - HTTP_STATUS_CODE_BAD_REQUEST, - "No JSON data provided", - ) + abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided") if "question" not in form: + abort(HTTP_STATUS_CODE_BAD_REQUEST, "The 'question' parameter is required!") + + llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS") + if not llm_configs: + logger.error("No LLM provider configuration defined.") abort( - HTTP_STATUS_CODE_BAD_REQUEST, - "The 'question' parameter is required!", + HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, + "No LLM provider was defined in the main configuration file", ) question = form.get("question") prompt = self.build_prompt(question, sketch_id) + result_schema = { "name": "AI generated search query", "query_string": None, "error": None, } + feature_name = "nl2q" try: llm = manager.LLMManager.create_provider(feature_name=feature_name) - except Exception as e: # pylint: disable=broad-except + except Exception as e: logger.error("Error LLM Provider: {}".format(e)) result_schema["error"] = ( "Error loading LLM Provider. Please try again later!" @@ -217,14 +206,13 @@ def post(self, sketch_id): try: prediction = llm.generate(prompt) - except Exception as e: # pylint: disable=broad-except + except Exception as e: logger.error("Error NL2Q prompt: {}".format(e)) result_schema["error"] = ( "An error occurred generating the query via the defined LLM. " "Please try again later!" ) return jsonify(result_schema) - # The model sometimes output triple backticks that needs to be removed. - result_schema["query_string"] = prediction.strip("```") + result_schema["query_string"] = prediction.strip("```") return jsonify(result_schema) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 3b8ea1712f..52d3eb82b0 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1196,6 +1196,7 @@ class TestNl2qResource(BaseTest): @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): """Test the prompt is created correctly.""" + self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() @@ -1204,13 +1205,9 @@ def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): {"data_type": "test:data_type:2"}, ] mock_aggregator.return_value = (mock_AggregationResult, {}) - - # Create a mock provider that returns the expected query. mock_llm = mock.Mock() mock_llm.generate.return_value = "LLM generated query" - # Patch create_provider to return our mock provider. mock_create_provider.return_value = mock_llm - response = self.client.post( self.resource_url, data=json.dumps(data), @@ -1318,6 +1315,7 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator): self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}} self.login() + self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() mock_AggregationResult.values = [ @@ -1337,10 +1335,9 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator): @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_no_llm_provider(self): """Test nl2q with no LLM provider configured.""" + if "LLM_PROVIDER_CONFIGS" in self.app.config: del self.app.config["LLM_PROVIDER_CONFIGS"] - self.app.config["DFIQ_ENABLED"] = False - self.login() data = dict(question="Question for LLM?") response = self.client.post( @@ -1376,10 +1373,10 @@ def test_nl2q_no_permission(self): ) self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN) - @mock.patch("timesketch.lib.llms.manager.LLMManager") + @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager): + def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider): """Test nl2q with llm error.""" self.login() @@ -1392,13 +1389,15 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager): mock_aggregator.return_value = (mock_AggregationResult, {}) mock_llm = mock.Mock() mock_llm.generate.side_effect = Exception("Test exception") - mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm + mock_create_provider.return_value = mock_llm response = self.client.post( self.resource_url, data=json.dumps(data), content_type="application/json", ) - self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK) + self.assertEqual( + response.status_code, HTTP_STATUS_CODE_OK + ) # Still expect 200 OK with error in JSON data = json.loads(response.get_data(as_text=True)) self.assertIsNotNone(data.get("error")) From da6aacd137a9614d325bb3cda44454959f228e0b Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 12:42:40 +0100 Subject: [PATCH 09/13] linter fix --- timesketch/api/v1/resources/nl2q.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index d6f07362a8..a1f1e37e97 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -170,6 +170,14 @@ def concatenate_values(self, group): @login_required def post(self, sketch_id): + """Handles POST request to the resource. + + Args: + sketch_id: Sketch ID. + + Returns: + JSON representing the LLM prediction. + """ form = request.json if not form: abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided") @@ -197,7 +205,7 @@ def post(self, sketch_id): feature_name = "nl2q" try: llm = manager.LLMManager.create_provider(feature_name=feature_name) - except Exception as e: + except Exception as e: # pylint: disable=broad-except logger.error("Error LLM Provider: {}".format(e)) result_schema["error"] = ( "Error loading LLM Provider. Please try again later!" From fd127082474587558a421384694938f36b159815 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 12:58:35 +0100 Subject: [PATCH 10/13] linter fix --- timesketch/api/v1/resources/nl2q.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index a1f1e37e97..d016a768f7 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -214,7 +214,7 @@ def post(self, sketch_id): try: prediction = llm.generate(prompt) - except Exception as e: + except Exception as e: # pylint: disable=broad-except logger.error("Error NL2Q prompt: {}".format(e)) result_schema["error"] = ( "An error occurred generating the query via the defined LLM. " From 84f593b561c83565b44823fd19185f4421d48228 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 13:33:15 +0100 Subject: [PATCH 11/13] update manager_test.py tests --- timesketch/lib/llms/manager_test.py | 128 +++++++++++++++++++++++----- 1 file changed, 108 insertions(+), 20 deletions(-) diff --git a/timesketch/lib/llms/manager_test.py b/timesketch/lib/llms/manager_test.py index 72f6746de9..ae600d7719 100644 --- a/timesketch/lib/llms/manager_test.py +++ b/timesketch/lib/llms/manager_test.py @@ -17,42 +17,130 @@ from timesketch.lib.llms import manager -class MockProvider: - NAME = "mock" +class MockAistudioProvider: + """A mock provider for Google AI Studio (using API key).""" - def __init__(self, config: dict, **kwargs): + NAME = "aistudio" + + def __init__(self, config, **kwargs): self.config = config + self.kwargs = kwargs + + def generate(self, prompt=None) -> str: + return "Generated response from AI Studio." + - def generate(self, prompt: str, response_schema: dict = None) -> str: - return "mock response" +class MockVertexAIProvider: + """A mock provider for Google Cloud Vertex AI.""" + + NAME = "vertexai" + + def __init__(self, config, **kwargs): + self.config = config + self.kwargs = kwargs + + def generate(self, prompt=None) -> str: + return "Generated response from Vertex AI." class TestLLMManager(BaseTest): - """Tests for the functionality of the manager module.""" + """Tests for the functionality of the LLMManager module.""" - # Clear the registry and register the mock provider. - manager.LLMManager.clear_registration() - manager.LLMManager.register_provider(MockProvider) + def setUp(self) -> None: + super().setUp() + manager.LLMManager.clear_registration() + manager.LLMManager.register_provider(MockAistudioProvider) + manager.LLMManager.register_provider(MockVertexAIProvider) + self.app.config["LLM_PROVIDER_CONFIGS"] = { + "default": { + "aistudio": { + "api_key": "AIzaSyTestDefaultKey", + "model": "gemini-2.0-flash-exp", + } + }, + } + + def tearDown(self) -> None: + manager.LLMManager.clear_registration() + super().tearDown() def test_get_providers(self): - """Test to get provider class objects.""" + """Test that get_providers returns the registered providers.""" providers = manager.LLMManager.get_providers() provider_list = list(providers) - first_provider_tuple = provider_list[0] - provider_name, provider_class = first_provider_tuple self.assertIsInstance(provider_list, list) - self.assertIsInstance(first_provider_tuple, tuple) - self.assertEqual(provider_class, MockProvider) - self.assertEqual(provider_name, "mock") + # Verify that both providers are registered. + found_aistudio = any( + provider_name == "aistudio" and provider_class == MockAistudioProvider + for provider_name, provider_class in provider_list + ) + found_vertexai = any( + provider_name == "vertexai" and provider_class == MockVertexAIProvider + for provider_name, provider_class in provider_list + ) + self.assertTrue(found_aistudio, "AI Studio provider not found.") + self.assertTrue(found_vertexai, "Vertex AI provider not found.") def test_get_provider(self): - """Test to get provider class from registry.""" - provider_class = manager.LLMManager.get_provider("mock") - self.assertEqual(provider_class, MockProvider) + """Test retrieval of a provider class from the registry.""" + provider_class = manager.LLMManager.get_provider("aistudio") + self.assertEqual(provider_class, MockAistudioProvider) self.assertRaises(KeyError, manager.LLMManager.get_provider, "no_such_provider") def test_register_provider(self): - """Test that registering a provider that is already registered raises ValueError.""" + """Test that re-registering an already registered provider raises ValueError.""" self.assertRaises( - ValueError, manager.LLMManager.register_provider, MockProvider + ValueError, manager.LLMManager.register_provider, MockAistudioProvider + ) + + def test_create_provider_default(self): + """Test create_provider using the default configuration.""" + provider_instance = manager.LLMManager.create_provider() + self.assertIsInstance(provider_instance, MockAistudioProvider) + self.assertEqual( + provider_instance.config, + { + "api_key": "AIzaSyTestDefaultKey", + "model": "gemini-2.0-flash-exp", + }, + ) + + def test_create_provider_feature(self): + """Test create_provider using a feature-specific configuration.""" + self.app.config["LLM_PROVIDER_CONFIGS"] = { + "nl2q": { + "vertexai": { + "model": "gemini-1.5-flash-001", + "project_id": "test_project_id", + }, + }, + } + provider_instance = manager.LLMManager.create_provider(feature_name="nl2q") + self.assertIsInstance(provider_instance, MockVertexAIProvider) + self.assertEqual( + provider_instance.config, + { + "model": "gemini-1.5-flash-001", + "project_id": "test_project_id", + }, ) + + def test_create_provider_invalid_config(self): + """Test that create_provider raises ValueError when configuration is invalid. + + Here, more than one provider is specified in the configuration. + """ + self.app.config["LLM_PROVIDER_CONFIGS"] = { + "default": { + "aistudio": {"api_key": "value"}, + "vertexai": {"model": "value"}, + } + } + with self.assertRaises(ValueError): + manager.LLMManager.create_provider() + + def test_create_provider_missing_config(self): + """Test that create_provider raises ValueError when configuration is missing.""" + self.app.config["LLM_PROVIDER_CONFIGS"] = {} + with self.assertRaises(ValueError): + manager.LLMManager.create_provider() From ad49e937cf9a69076c866ea9f47adb54443c0e76 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 13:47:01 +0100 Subject: [PATCH 12/13] update manager_test.py tests --- timesketch/lib/llms/manager_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/timesketch/lib/llms/manager_test.py b/timesketch/lib/llms/manager_test.py index ae600d7719..c850b6a75c 100644 --- a/timesketch/lib/llms/manager_test.py +++ b/timesketch/lib/llms/manager_test.py @@ -26,7 +26,7 @@ def __init__(self, config, **kwargs): self.config = config self.kwargs = kwargs - def generate(self, prompt=None) -> str: + def generate(self) -> str: return "Generated response from AI Studio." @@ -39,7 +39,7 @@ def __init__(self, config, **kwargs): self.config = config self.kwargs = kwargs - def generate(self, prompt=None) -> str: + def generate(self) -> str: return "Generated response from Vertex AI." From 77e770c9c96a03bc0ddab97be4dc355735111f67 Mon Sep 17 00:00:00 2001 From: Maarten van Dantzig Date: Tue, 4 Feb 2025 16:45:09 +0100 Subject: [PATCH 13/13] Add warning to settings.py for old configs --- timesketch/api/v1/resources/settings.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/timesketch/api/v1/resources/settings.py b/timesketch/api/v1/resources/settings.py index eac0c3b8af..39ec97e916 100644 --- a/timesketch/api/v1/resources/settings.py +++ b/timesketch/api/v1/resources/settings.py @@ -13,10 +13,13 @@ # limitations under the License. """System settings.""" +import logging from flask import current_app, jsonify from flask_restful import Resource from flask_login import login_required +logger = logging.getLogger("timesketch.system_settings") + class SystemSettingsResource(Resource): """Resource to get system settings.""" @@ -45,4 +48,18 @@ def get(self): default_provider = next(iter(default_conf)) result["LLM_PROVIDER"] = default_provider + # TODO(mvd): Remove by 2025/06/01 once all users have updated their config. + old_llm_provider = current_app.config.get("LLM_PROVIDER") + if ( + old_llm_provider and "default" not in llm_configs + ): # Basic check for old config + warning_message = ( + "Your LLM configuration in timesketch.conf is outdated and may cause " + "issues with LLM features. " + "Please update your LLM_PROVIDER_CONFIGS section to the new format. " + "Refer to the documentation for the updated configuration structure." + ) + result["llm_config_warning"] = warning_message + logger.warning(warning_message) + return jsonify(result)