diff --git a/timesketch/api/v1/resources_test.py b/timesketch/api/v1/resources_test.py index f844cd787a..3b8ea1712f 100644 --- a/timesketch/api/v1/resources_test.py +++ b/timesketch/api/v1/resources_test.py @@ -1191,12 +1191,11 @@ class TestNl2qResource(BaseTest): resource_url = "/api/v1/sketches/1/nl2q/" - @mock.patch("timesketch.lib.llms.manager.LLMManager") + @mock.patch("timesketch.lib.llms.manager.LLMManager.create_provider") @mock.patch("timesketch.api.v1.utils.run_aggregator") @mock.patch("timesketch.api.v1.resources.OpenSearchDataStore", MockDataStore) - def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager): + def test_nl2q_prompt(self, mock_aggregator, mock_create_provider): """Test the prompt is created correctly.""" - self.login() data = dict(question="Question for LLM?") mock_AggregationResult = mock.MagicMock() @@ -1205,9 +1204,13 @@ def test_nl2q_prompt(self, mock_aggregator, mock_llm_manager): {"data_type": "test:data_type:2"}, ] mock_aggregator.return_value = (mock_AggregationResult, {}) + + # Create a mock provider that returns the expected query. mock_llm = mock.Mock() mock_llm.generate.return_value = "LLM generated query" - mock_llm_manager.return_value.get_provider.return_value = lambda: mock_llm + # Patch create_provider to return our mock provider. + mock_create_provider.return_value = mock_llm + response = self.client.post( self.resource_url, data=json.dumps(data),