Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
itsmvd committed Feb 4, 2025
1 parent 364ec10 commit f653145
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 35 deletions.
38 changes: 13 additions & 25 deletions timesketch/api/v1/resources/nl2q.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,45 +170,34 @@ def concatenate_values(self, group):

@login_required
def post(self, sketch_id):
"""Handles POST request to the resource.
Args:
sketch_id: Sketch ID.
Returns:
JSON representing the LLM prediction.
"""
llm_provider = current_app.config.get("LLM_PROVIDER", "")
if not llm_provider:
logger.error("No LLM provider was defined in the main configuration file")
abort(
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)
form = request.json
if not form:
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"No JSON data provided",
)
abort(HTTP_STATUS_CODE_BAD_REQUEST, "No JSON data provided")

if "question" not in form:
abort(HTTP_STATUS_CODE_BAD_REQUEST, "The 'question' parameter is required!")

llm_configs = current_app.config.get("LLM_PROVIDER_CONFIGS")
if not llm_configs:
logger.error("No LLM provider configuration defined.")
abort(
HTTP_STATUS_CODE_BAD_REQUEST,
"The 'question' parameter is required!",
HTTP_STATUS_CODE_INTERNAL_SERVER_ERROR,
"No LLM provider was defined in the main configuration file",
)

question = form.get("question")
prompt = self.build_prompt(question, sketch_id)

result_schema = {
"name": "AI generated search query",
"query_string": None,
"error": None,
}

feature_name = "nl2q"
try:
llm = manager.LLMManager.create_provider(feature_name=feature_name)
except Exception as e: # pylint: disable=broad-except
except Exception as e:
logger.error("Error LLM Provider: {}".format(e))
result_schema["error"] = (
"Error loading LLM Provider. Please try again later!"
Expand All @@ -217,14 +206,13 @@ def post(self, sketch_id):

try:
prediction = llm.generate(prompt)
except Exception as e: # pylint: disable=broad-except
except Exception as e:
logger.error("Error NL2Q prompt: {}".format(e))
result_schema["error"] = (
"An error occurred generating the query via the defined LLM. "
"Please try again later!"
)
return jsonify(result_schema)
# The model sometimes output triple backticks that needs to be removed.
result_schema["query_string"] = prediction.strip("```")

result_schema["query_string"] = prediction.strip("```")
return jsonify(result_schema)
19 changes: 9 additions & 10 deletions timesketch/api/v1/resources_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,6 +1196,7 @@ class TestNl2qResource(BaseTest):
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_prompt(self, mock_aggregator, mock_create_provider):
"""Test the prompt is created correctly."""

self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
Expand All @@ -1204,13 +1205,9 @@ def test_nl2q_prompt(self, mock_aggregator, mock_create_provider):
{"data_type": "test:data_type:2"},
]
mock_aggregator.return_value = (mock_AggregationResult, {})

# Create a mock provider that returns the expected query.
mock_llm = mock.Mock()
mock_llm.generate.return_value = "LLM generated query"
# Patch create_provider to return our mock provider.
mock_create_provider.return_value = mock_llm

response = self.client.post(
self.resource_url,
data=json.dumps(data),
Expand Down Expand Up @@ -1318,6 +1315,7 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator):

self.app.config["LLM_PROVIDER_CONFIGS"] = {"default": {"DoesNotExists": {}}}
self.login()
self.login()
data = dict(question="Question for LLM?")
mock_AggregationResult = mock.MagicMock()
mock_AggregationResult.values = [
Expand All @@ -1337,10 +1335,9 @@ def test_nl2q_wrong_llm_provider(self, mock_aggregator):
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_no_llm_provider(self):
"""Test nl2q with no LLM provider configured."""

if "LLM_PROVIDER_CONFIGS" in self.app.config:
del self.app.config["LLM_PROVIDER_CONFIGS"]
self.app.config["DFIQ_ENABLED"] = False

self.login()
data = dict(question="Question for LLM?")
response = self.client.post(
Expand Down Expand Up @@ -1376,10 +1373,10 @@ def test_nl2q_no_permission(self):
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_FORBIDDEN)

@mock.patch("timesketch.lib.llms.manager.LLMManager")
@mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider")
@mock.patch("timesketch.api.v1.utils.run_aggregator")
@mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore)
def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager):
def test_nl2q_llm_error(self, mock_aggregator, mock_create_provider):
"""Test nl2q with llm error."""

self.login()
Expand All @@ -1392,13 +1389,15 @@ def test_nl2q_llm_error(self, mock_aggregator, mock_llm_manager):
mock_aggregator.return_value = (mock_AggregationResult, {})
mock_llm = mock.Mock()
mock_llm.generate.side_effect = Exception("Test exception")
mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm
mock_create_provider.return_value = mock_llm
response = self.client.post(
self.resource_url,
data=json.dumps(data),
content_type="application/json",
)
self.assertEqual(response.status_code, HTTP_STATUS_CODE_OK)
self.assertEqual(
response.status_code, HTTP_STATUS_CODE_OK
) # Still expect 200 OK with error in JSON
data = json.loads(response.get_data(as_text=True))
self.assertIsNotNone(data.get("error"))

Expand Down

0 comments on commit f653145

Please sign in to comment.