From cb1ddbd8ce7ad769e8c76050323e8f5744eefada Mon Sep 17 00:00:00 2001
From: Antonino Lorenzo <94693967+antoninoLorenzo@users.noreply.github.com>
Date: Mon, 24 Feb 2025 19:50:35 +0100
Subject: [PATCH] test: updated testing for Ollama

---
 src/core/llm/ollama.py          |  44 +++++----
 src/core/memory/schema.py       |   6 +-
 test/api/llm/__init__.py        |   0
 test/api/llm/test_ollama.py     | 154 ++++++++++++++++++++++++++++++++
 test/mock/mock_ollama_client.py |  30 ++++++-
 5 files changed, 212 insertions(+), 22 deletions(-)
 create mode 100644 test/api/llm/__init__.py
 create mode 100644 test/api/llm/test_ollama.py

diff --git a/src/core/llm/ollama.py b/src/core/llm/ollama.py
index 8bd7dc6..d00bdde 100644
--- a/src/core/llm/ollama.py
+++ b/src/core/llm/ollama.py
@@ -5,8 +5,8 @@
 from ollama import Client, ResponseError
 from pydantic import validate_call
 
+from src.core.memory import Conversation, Role
 from src.core.llm.schema import Provider, ProviderError
-from src.core.memory import Conversation
 from src.utils import get_logger
 
 
@@ -60,35 +60,49 @@ def __post_init__(self):
         try:
             self.client = Client(host=self.inference_endpoint)
         except Exception as err:
-            raise RuntimeError('Initialization Failed') from err
+            raise RuntimeError('Ollama: invalid endpoint') from err
 
     @staticmethod
     def user_message_token_length(
         conversation: Conversation,
         full_input_tokens: int,
-        system_prompt_tokens: int
     ) -> int:
-        # assumes conversation contains at least system prompt and usr message
-        if len(conversation) < 2:
-            return 0
-
+        print(f'conversation: {conversation}, ({len(conversation)})')
+        print(f'counting tokens, full input length: {full_input_tokens}')
+        # if conversation contains a system prompt its length must be subtracted
+        subtract = 0
+        if conversation.messages[0].role == Role.SYS:
+            print('system prompt found')
+            subtract = int(len(conversation.messages[0].content) / 4)
+
+        # considering User + Assistant
         if len(conversation) == 2:
-            return full_input_tokens - system_prompt_tokens
+            return full_input_tokens - subtract
+
         else:
-            user_message_tokens = full_input_tokens - system_prompt_tokens
-            # currently conversation.messages contains all messages
-            # except the assistant response, so last message must be excluded
-            for message in conversation.messages[:-1]:
+            user_message_tokens = full_input_tokens - subtract
+            print(f'{full_input_tokens} - {subtract} = {user_message_tokens}')
+            # exclude system prompt
+            for message in conversation.messages[1:]:
+                prev = user_message_tokens
                 user_message_tokens -= message.get_tokens()
+                print(f'{prev} - {message.get_tokens()} = {user_message_tokens}')
             return user_message_tokens
 
     @validate_call
     def query(
         self,
         conversation: Conversation
-    ) -> Tuple[str, int, int]:
+    ):
         """Generator that returns a tuple containing:
         (response_chunk, user message tokens, assistant message tokens)"""
+        # validation: conversation should contain messages
+        if not conversation.messages:
+            raise ProviderError('Error: empty conversation')
+        last_message = conversation.messages[-1]
+        if not last_message.role == Role.USER or not last_message.content:
+            raise ProviderError('Error: last message is not user message')
+
         try:
             options = AVAILABLE_MODELS[self.__match_model()]['options']
             stream = self.client.chat(
@@ -106,11 +120,9 @@ def query(
                 # - `prompt_eval_count` -> input prompt tokens
                 # - `eval_count` -> output tokens
                 # The input prompt contains system prompt + entire conversation.
-                system_prompt_length_estimate = int(len(conversation.messages[0].content) / 4)
                 user_msg_tokens = Ollama.user_message_token_length(
                     conversation=conversation,
-                    full_input_tokens=c['prompt_eval_count'],
-                    system_prompt_tokens=system_prompt_length_estimate
+                    full_input_tokens=c['prompt_eval_count']
                 )
                 assistant_msg_tokens = c['eval_count']
 
diff --git a/src/core/memory/schema.py b/src/core/memory/schema.py
index 29f2878..00548ea 100644
--- a/src/core/memory/schema.py
+++ b/src/core/memory/schema.py
@@ -26,17 +26,17 @@ class Message(BaseModel):
     """Message object"""
     role: Role
     content: str
-    __token_length: int = 0
+    token_length: int = 0
 
     def model_dump(self, **kwargs):
         return {'role': str(self.role), 'content': self.content}
 
     # using @property causes issues with BaseModel
     def get_tokens(self) -> int:
-        return self.__token_length
+        return self.token_length
 
     def set_tokens(self, val: int):
-        self.__token_length = val
+        self.token_length = val
 
 
 class Conversation(BaseModel):
diff --git a/test/api/llm/__init__.py b/test/api/llm/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/test/api/llm/test_ollama.py b/test/api/llm/test_ollama.py
new file mode 100644
index 0000000..913c100
--- /dev/null
+++ b/test/api/llm/test_ollama.py
@@ -0,0 +1,154 @@
+# TODO:
+# **query**
+# - test streaming
+# - test token counting (**user_message_token_length**)
+#
+# **tool_query**
+#
+import functools
+from pytest import mark, raises, fixture
+
+from src.core.llm import Ollama, ProviderError
+from src.core.memory import Conversation, Message, Role
+from test.mock.mock_ollama_client import MockOllamaClient
+
+
+@mark.parametrize('parameters', [
+    {'host_is_valid': False, 'model': 'mistral', 'exception': RuntimeError},
+    {'host_is_valid': True, 'model': '', 'exception': ValueError},
+    {'host_is_valid': True, 'model': 'chatgipiti', 'exception': ValueError},
+])
+def test_ollama_init(
+    monkeypatch,
+    parameters
+):
+    monkeypatch.setattr(
+        'src.core.llm.ollama.Client',
+        functools.partial(MockOllamaClient, valid_host=parameters['host_is_valid'])
+    )
+
+    with raises(parameters['exception']):
+        Ollama(model=parameters['model'], inference_endpoint='')
+
+
+# Since Ollama.query uses pydantic.validate_call, these tests focus on
+# messing with Conversation.messages order and content.
+@mark.parametrize('parameters', [
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[]
+        ),
+        'exception': ProviderError
+    },
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.SYS, content='you are an helpful 41 ass')
+            ]
+        ),
+        'exception': ProviderError
+    },
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.USER, content='')
+            ]
+        ),
+        'exception': ProviderError
+    },
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.ASSISTANT, content='Shouldn\'t be an assistant message')
+            ]
+        ),
+        'exception': ProviderError
+    }
+])
+def test_query_message_validation(
+    monkeypatch,
+    parameters
+):
+    monkeypatch.setattr(
+        'src.core.llm.ollama.Client',
+        functools.partial(MockOllamaClient, valid_host=True)
+    )
+
+    ollama = Ollama(model='mistral', inference_endpoint='')
+    with raises(parameters['exception']):
+        for _ in ollama.query(parameters['conversation']):
+            pass
+
+
+@mark.parametrize('parameters', [
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.USER, content='Can I run LLM in a lemon?')
+            ]
+        ),
+        'expected_length_user': 6,
+        'expected_length_response': 10
+    },
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.SYS, content='You are a large lemon engineer'),
+                Message(role=Role.USER, content='Can I run LLM in a lemon?')
+            ]
+        ),
+        'expected_length_user': 6,
+        'expected_length_response': 10
+    },
+    {
+        'conversation': Conversation(
+            conversation_id=1,
+            name='untitled',
+            messages=[
+                Message(role=Role.SYS, content='You are a large lemon engineer'),
+                Message(role=Role.USER, content='Can I run LLM in a lemon?', token_length=6),
+                Message(role=Role.ASSISTANT, content='hold on you might be onto something\n', token_length=9),
+                Message(role=Role.USER, content='cool')
+            ]
+        ),
+        'expected_length_user': 1,
+        'expected_length_response': 4
+    },
+])
+def test_query_response(
+    monkeypatch,
+    parameters
+):
+    monkeypatch.setattr(
+        'src.core.llm.ollama.Client',
+        functools.partial(MockOllamaClient, valid_host=True)
+    )
+
+    ollama = Ollama(model='mistral', inference_endpoint='')
+
+    response = ''
+    last_user, last_assistant = 0, 0
+    for chunk, tk_user, tk_assistant in ollama.query(parameters['conversation']):
+        response += chunk
+        last_user = tk_user
+        last_assistant = tk_assistant
+
+    print(f'user: {last_user}; assistant: {last_assistant}')
+
+    # expected lengths are approximated assuming a message token length of len(content) / 4
+    # => we expect len(user_message.content) / 4 and len(response.content) / 4 respectively
+    assert last_user == parameters['expected_length_user']
+    assert last_assistant == parameters['expected_length_response']
diff --git a/test/mock/mock_ollama_client.py b/test/mock/mock_ollama_client.py
index 4f09fd4..e23d8ec 100644
--- a/test/mock/mock_ollama_client.py
+++ b/test/mock/mock_ollama_client.py
@@ -1,3 +1,5 @@
+from ollama import ResponseError
+
 from src.core.llm import ProviderError
 
 
@@ -9,10 +11,32 @@ def __init__(self, host: str, valid_host = True):
         raise ProviderError('Ollama: invalid endpoint')
 
     def chat(
+        self,
         model: str,
         messages: list,
-        stream: bool,
-        options: dict,
+        stream: bool = True,
+        options: dict | None = None,
         tools: list | None = None
     ):
-        pass
+        if not model or not messages:
+            raise ResponseError("Model and messages are required")
+
+        last_message = messages[-1]['content']
+        response_message = f'response for: {last_message}\n'
+
+        full_input_tokens, eval_count = 0, 0
+        for i, char in enumerate(response_message):
+            # ollama sends token count only for last chunk in a stream
+            if len(response_message) == i + 1:
+                # count "tokens" (considering an average of 4 characters per token)
+                for msg in messages:
+                    full_input_tokens += int(len(msg['content']) / 4)
+
+                # count output "tokens"
+                eval_count += int(len(response_message) / 4)
+
+            yield {
+                'message': {'content': char},
+                'prompt_eval_count': full_input_tokens,
+                'eval_count': eval_count
+            }
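
Note on the token arithmetic the new tests rely on: below is a minimal standalone sketch (not part of the patch) that reproduces expected_length_user == 1 for the third test_query_response case, assuming the same rough 4-characters-per-token estimate that MockOllamaClient and the system-prompt heuristic use; the rough_tokens helper and the variable names are illustrative only.

    # Illustrative sketch, assuming ~4 characters per "token".
    def rough_tokens(text: str) -> int:
        return int(len(text) / 4)

    system = 'You are a large lemon engineer'        # 30 chars -> 7 "tokens"
    history = [
        # (content, token_length already stored on the Message)
        ('Can I run LLM in a lemon?', 6),
        ('hold on you might be onto something\n', 9),
        ('cool', 0),                                 # new user message, not counted yet
    ]

    # prompt_eval_count reported by the (mock) server: every message, system prompt included
    full_input = rough_tokens(system) + sum(rough_tokens(c) for c, _ in history)  # 7 + 6 + 9 + 1 = 23

    # Ollama.user_message_token_length: subtract the system prompt estimate, then the
    # token counts already stored on the other messages (the new one still holds 0)
    user_tokens = full_input - rough_tokens(system)  # 23 - 7 = 16
    for _, stored in history:
        user_tokens -= stored                        # 16 - 6 - 9 - 0 = 1

    assert user_tokens == 1                          # matches expected_length_user

The assistant side follows the same heuristic: the mock replies 'response for: cool\n' (19 characters), so eval_count comes out as int(19 / 4) == 4, which is the expected_length_response for that case.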