
Commit

avoid getting stuck in a loop if LLM can't conform to schema.
nalbion committed Oct 4, 2023
1 parent 0d8a4c7 commit cbac991
Showing 2 changed files with 16 additions and 10 deletions.
18 changes: 9 additions & 9 deletions pilot/helpers/agents/test_Developer.py
@@ -144,7 +144,7 @@ def test_test_code_changes_invalid_json(self, mock_get_saved_user_input,
         self.project.developer = self.developer
 
         # we send a GET_TEST_TYPE spec, but the 1st response is invalid
-        types_in_response = ['command', 'command_test']
+        types_in_response = ['command', 'wrong_again', 'command_test']
         json_received = []
 
         def generate_response(*args, **kwargs):
@@ -169,12 +169,12 @@ def generate_response(*args, **kwargs):
 
         mock_questionary = MockQuestionary([''])
 
-        # with patch('utils.questionary.questionary', mock_questionary):
-        # When
-        result = self.developer.test_code_changes(monkey, convo)
+        with patch('utils.questionary.questionary', mock_questionary):
+            # When
+            result = self.developer.test_code_changes(monkey, convo)
 
-        # Then
-        assert result == {'success': True, 'cli_response': 'stdout:\n```\n\n```'}
-        assert mock_requests_post.call_count == 2
-        assert "The JSON is invalid at $.type - 'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']" in json_received[1]['messages'][3]['content']
-        assert mock_execute.call_count == 1
+            # Then
+            assert result == {'success': True, 'cli_response': 'stdout:\n```\n\n```'}
+            assert mock_requests_post.call_count == 3
+            assert "The JSON is invalid at $.type - 'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']" in json_received[1]['messages'][3]['content']
+            assert mock_execute.call_count == 1
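
For context, the updated test drives the retry path by feeding the mocked LLM one response per entry in types_in_response, so two schema-invalid types ('command', 'wrong_again') now precede the valid 'command_test' and three POST calls are expected. A minimal sketch of that kind of sequenced fake, assuming a requests.post-style mock (the names below are illustrative stand-ins, not the repository's fixtures):

from unittest.mock import MagicMock

def make_fake_post(types_in_response):
    """Return a fake requests.post whose nth call answers with the nth type."""
    responses = iter(types_in_response)

    def fake_post(*args, **kwargs):
        reply = MagicMock()
        reply.status_code = 200
        # Each call emits a JSON body whose "type" is the next queued value.
        reply.json.return_value = {
            'choices': [{'message': {'content': '{"type": "%s"}' % next(responses)}}]
        }
        return reply

    return fake_post

# Two invalid types followed by a valid one -> two retries, three POSTs in total.
fake_post = make_fake_post(['command', 'wrong_again', 'command_test'])
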
8 changes: 7 additions & 1 deletion pilot/utils/llm_connection.py
@@ -163,13 +163,19 @@ def wrapper(*args, **kwargs):
                     args[0]['function_buffer'] = e.doc
                     continue
                 elif isinstance(e, ValidationError):
+                    function_error_count = 1 if 'function_error' not in args[0] else args[0]['function_error_count'] + 1
+                    args[0]['function_error_count'] = function_error_count
+
                     logger.warning('Received invalid JSON response from LLM. Asking to retry...')
                     logger.info(f' at {e.json_path} {e.message}')
                     # eg:
                     # json_path: '$.type'
                     # message: "'command' is not one of ['automated_test', 'command_test', 'manual_test', 'no_test']"
                     args[0]['function_error'] = f'at {e.json_path} - {e.message}'
-                    continue
+
+                    # Attempt retry if the JSON schema is invalid, but avoid getting stuck in a loop
+                    if function_error_count < 3:
+                        continue
                 if "context_length_exceeded" in err_str:
                     # spinner_stop(spinner)
                     raise TokenLimitError(get_tokens_in_messages_from_openai_error(err_str), MAX_GPT_MODEL_TOKENS)
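
The effect of the llm_connection.py change, reduced to its core: count consecutive schema-validation failures, feed the validator's message back to the model, and stop retrying after the third failure instead of looping forever. A standalone sketch of that pattern, assuming a jsonschema-style validator and a call_llm(feedback) callable (both hypothetical stand-ins, not the project's actual wrapper):

import json
from jsonschema import validate, ValidationError

MAX_SCHEMA_RETRIES = 3  # assumed constant; the commit hard-codes the equivalent `< 3` check

def request_json(call_llm, schema):
    """Ask the model for JSON until it validates, or give up after MAX_SCHEMA_RETRIES failures."""
    feedback = None
    error_count = 0
    while True:
        raw = call_llm(feedback)
        try:
            data = json.loads(raw)
            validate(instance=data, schema=schema)
            return data
        except (json.JSONDecodeError, ValidationError) as e:
            error_count += 1
            if error_count >= MAX_SCHEMA_RETRIES:
                # Avoid getting stuck in a loop if the LLM can't conform to the schema.
                raise
            # Tell the model what was wrong and try again.
            feedback = f'The JSON is invalid: {e}'

In the actual wrapper the count lives in args[0]['function_error_count'] and the error text in args[0]['function_error'], so once the cap is reached the except block falls through to the remaining error handling instead of continuing the retry loop.
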
