diff --git a/timesketch/api/v1/resources/nl2q.py b/timesketch/api/v1/resources/nl2q.py index b6928ff55a..d6f07362a8 100644 --- a/timesketch/api/v1/resources/nl2q.py +++ b/timesketch/api/v1/resources/nl2q.py @@ -170,45 +170,34 @@ def concatenate_values(self, group): @login_required def post(self, sketch_id): - """Handles POST request to the resource. - - Args: - sketch_id: Sketch ID. - - Returns: - JSON representing the LLM prediction. - """ - llm_provider = current_app.config.get("LLM_PROVIDER", "") - if not llm_provider: - logger.error("No LLM provider was defined in the main configuration file") - abort( - HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, - "No LLM provider was defined in the main configuration file", - ) form = request.json if not form: - abort( - HTTP_STATUS_CODE_BAD_REQUEST, - "No JSON data provided", - ) + abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided") if "question" not in form: + abort(HTTP_STATUS_CODE_BAD_REQUEST, "The 'question' parameter is required!") + + llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS") + if not llm_configs: + logger.error("No LLM provider configuration defined.") abort( - HTTP_STATUS_CODE_BAD_REQUEST, - "The 'question' parameter is required!", + HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR, + "No LLM provider was defined in the main configuration file", ) question = form.get("question") prompt = self.build_prompt(question, sketch_id) + result_schema = { "name": "AI generated search query", "query_string": None, "error": None, } + feature_name = "nl2q" try: llm = manager.LLMManager.create_provider(feature_name=feature_name) - except Exception as e: # pylint: disable=broad-except + except Exception as e: logger.error("Error LLM Provider: {}".format(e)) result_schema["error"] = ( "Error loading LLM Provider. Please try again later!" 
@@ -217,14 +206,13 @@ def post(self, sketch_id): try: prediction = llm.generate(prompt) - except Exception as e: # pylint: disable=broad-except + except Exception as e: logger.error("Error NL2Q prompt: {}".format(e)) result_schema["error"] = ( "An error occurred generating the query via the defined LLM. " "Please try again later!" ) return jsonify(result_schema) - # The model sometimes output triple backticks that needs to be removed. - result_schema["query_string"] = prediction.strip("```") + result_schema["query_string"] = prediction.strip("```") return jsonify(result_schema) diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index 3b8ea1712f..52d3eb82b0 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1196,6 +1196,7 @@ class TestNl2qResource(BaseTest): @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): """Test the prompt is created correctly.""" + self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() @@ -1204,13 +1205,9 @@ def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): {"data_type": "test:data_type:2"}, ] mock_aggregator.return_value = (mock_AggregationResult, {}) - - # Create a mock provider that returns the expected query. mock_llm = mock.Mock() mock_llm.generate.return_value = "LLM generated query" - # Patch create_provider to return our mock provider. 
mock_create_provider.return_value = mock_llm - response = self.client.post( self.resource_url, data=json.dumps(data), content_type="application/json", ) @@ -1337,10 +1334,9 @@ @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) def test_nl2q_no_llm_provider(self): """Test nl2q with no LLM provider configured.""" + if "LLM_PROVIDER_CONFIGS" in self.app.config: del self.app.config["LLM_PROVIDER_CONFIGS"] - self.app.config["DFIQ_ENABLED"] = False - self.login() data = dict(question="Question for LLM?") response = self.client.post( @@ -1376,10 +1372,10 @@ ) self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN) - @mock.patch("timesketch.lib.llms.manager.LLMManager") + @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager): + def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider): """Test nl2q with llm error.""" self.login() @@ -1392,13 +1388,15 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager): mock_aggregator.return_value = (mock_AggregationResult, {}) mock_llm = mock.Mock() mock_llm.generate.side_effect = Exception("Test exception") - mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm + mock_create_provider.return_value = mock_llm response = self.client.post( self.resource_url, data=json.dumps(data), content_type="application/json", ) - self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK) + 
self.assertEqual( + response.status_code, HTTP_STATUS_CODE_OK + ) # Still expect 200 OK with error in JSON data = json.loads(response.get_data(as_text=True)) self.assertIsNotNone(data.get("error"))