From cbac991bd9abf399780ecc0d9c4594567d61042b Mon Sep 17 00:00:00 2001 From: Nicholas Albion Date: Wed, 4 Oct 2023 15:23:36 +1100 Subject: [PATCH] avoid getting stuck in a loop if LLM can't conform to schema. --- pilot/helpers/agents/test_Developer.py | 18 +++++++++--------- pilot/utils/llm_connection.py | 8 +++++++- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/pilot/helpers/agents/test_Developer.py b/pilot/helpers/agents/test_Developer.py index 36d70ede5..5aa8069ad 100644 --- a/pilot/helpers/agents/test_Developer.py +++ b/pilot/helpers/agents/test_Developer.py @@ -144,7 +144,7 @@ def test_test_code_changes_invalid_json(self, mock_get_saved_user_input, self.project.developer = self.developer # we send a GET_TEST_TYPE spec, but the 1st response is invalid - types_in_response = ['command', 'command_test'] + types_in_response = ['command', 'wrong_again', 'command_test'] json_received = [] def generate_response(*args, **kwargs): @@ -169,12 +169,12 @@ def generate_response(*args, **kwargs): mock_questionary = MockQuestionary(['']) - # with patch('utils.questionary.questionary', mock_questionary): - # When - result = self.developer.test_code_changes(monkey, convo) + with patch('utils.questionary.questionary', mock_questionary): + # When + result = self.developer.test_code_changes(monkey, convo) - # Then - assert result == {'success': True, 'cli_response': 'stdout:\n```\n\n```'} - assert mock_requests_post.call_count == 2 - assert "The JSON is invalid at $.type - 'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']" in json_received[1]['messages'][3]['content'] - assert mock_execute.call_count == 1 + # Then + assert result == {'success': True, 'cli_response': 'stdout:\n```\n\n```'} + assert mock_requests_post.call_count == 3 + assert "The JSON is invalid at $.type - 'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']" in json_received[1]['messages'][3]['content'] + assert 
mock_execute.call_count == 1 diff --git a/pilot/utils/llm_connection.py b/pilot/utils/llm_connection.py index e07362d21..260f13677 100644 --- a/pilot/utils/llm_connection.py +++ b/pilot/utils/llm_connection.py @@ -163,13 +163,19 @@ def wrapper(*args, **kwargs): args[0]['function_buffer'] = e.doc continue elif isinstance(e, ValidationError): + function_error_count = 1 if 'function_error_count' not in args[0] else args[0]['function_error_count'] + 1 + args[0]['function_error_count'] = function_error_count + logger.warning('Received invalid JSON response from LLM. Asking to retry...') logger.info(f' at {e.json_path} {e.message}') # eg: # json_path: '$.type' # message: "'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']" args[0]['function_error'] = f'at {e.json_path} - {e.message}' - continue + + # Attempt retry if the JSON schema is invalid, but avoid getting stuck in a loop + if function_error_count < 3: + continue if "context_length_exceeded" in err_str: # spinner_stop(spinner) raise TokenLimitError(get_tokens_in_messages_from_openai_error(err_str), MAX_GPT_MODEL_TOKENS)