test: updated testing for Ollama
antoninoLorenzo committed Feb 24, 2025
1 parent 37ee61c commit cb1ddbd
Showing 5 changed files with 212 additions and 22 deletions.
44 changes: 28 additions & 16 deletions src/core/llm/ollama.py
@@ -5,8 +5,8 @@
 from ollama import Client, ResponseError
 from pydantic import validate_call
 
+from src.core.memory import Conversation, Role
 from src.core.llm.schema import Provider, ProviderError
-from src.core.memory import Conversation
 from src.utils import get_logger
 
 
@@ -60,35 +60,49 @@ def __post_init__(self):
         try:
             self.client = Client(host=self.inference_endpoint)
         except Exception as err:
-            raise RuntimeError('Initialization Failed') from err
+            raise RuntimeError('Ollama: invalid endpoint') from err
 
     @staticmethod
     def user_message_token_length(
             conversation: Conversation,
             full_input_tokens: int,
-            system_prompt_tokens: int
     ) -> int:
+        # assumes conversation contains at least system prompt and usr message
+        if len(conversation) < 2:
+            return 0
+
+        print(f'conversation: {conversation}, ({len(conversation)})')
+        print(f'counting tokens, full input length: {full_input_tokens}')
+        # if conversation contains a system prompt its length must be subtracted
+        subtract = 0
+        if conversation.messages[0].role == Role.SYS:
+            print('system prompt found')
+            subtract = int(len(conversation.messages[0].content) / 4)
+
         # considering User + Assistant
         if len(conversation) == 2:
-            return full_input_tokens - system_prompt_tokens
+            return full_input_tokens - subtract
 
         else:
-            user_message_tokens = full_input_tokens - system_prompt_tokens
-            # currently conversation.messages contains all messages
-            # except the assistant response, so last message must be excluded
-            for message in conversation.messages[:-1]:
+            user_message_tokens = full_input_tokens - subtract
+            print(f'{full_input_tokens} - {subtract} = {user_message_tokens}')
+            # exclude system prompt
+            for message in conversation.messages[1:]:
+                prev = user_message_tokens
                 user_message_tokens -= message.get_tokens()
+                print(f'{prev} - {message.get_tokens()} = {user_message_tokens}')
             return user_message_tokens
 
     @validate_call
     def query(
             self,
             conversation: Conversation
-    ) -> Tuple[str, int, int]:
+    ):
         """Generator that returns a tuple containing:
         (response_chunk, user message tokens, assistant message tokens)"""
+        # validation: conversation should contain messages
+        if not conversation.messages:
+            raise ProviderError('Error: empty conversation')
+        last_message = conversation.messages[-1]
+        if not last_message.role == Role.USER or not last_message.content:
+            raise ProviderError('Error: last message is not user message')
 
         try:
             options = AVAILABLE_MODELS[self.__match_model()]['options']
             stream = self.client.chat(
@@ -106,11 +120,9 @@ def query(
                 # - `prompt_eval_count` -> input prompt tokens
                 # - `eval_count` -> output tokens
                 # The input prompt contains system prompt + entire conversation.
-                system_prompt_length_estimate = int(len(conversation.messages[0].content) / 4)
                 user_msg_tokens = Ollama.user_message_token_length(
                     conversation=conversation,
-                    full_input_tokens=c['prompt_eval_count'],
-                    system_prompt_tokens=system_prompt_length_estimate
+                    full_input_tokens=c['prompt_eval_count']
                 )
                 assistant_msg_tokens = c['eval_count']
 
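To make the arithmetic in `user_message_token_length` easier to follow, here is a small worked sketch of the intended bookkeeping, reusing the example messages from the new tests further down and the rough 4-characters-per-token estimate the code relies on. The `estimate_tokens` helper is illustrative, not part of the commit; in the real code the per-message counts come from `Message.get_tokens()`, the sketch simply recomputes them with the same heuristic.

# Rough heuristic used throughout the commit: ~4 characters per token.
def estimate_tokens(text: str) -> int:
    return int(len(text) / 4)

system_prompt = 'You are a large lemon engineer'        # 30 chars -> 7 "tokens"
history = [
    'Can I run LLM in a lemon?',                        # earlier user message, 25 chars -> 6
    'hold on you might be onto something\n',            # assistant reply, 36 chars -> 9
]
latest_user_message = 'cool'                            # 4 chars -> 1

# Ollama's prompt_eval_count covers the system prompt plus the whole conversation.
prompt_eval_count = (
    estimate_tokens(system_prompt)
    + sum(estimate_tokens(m) for m in history)
    + estimate_tokens(latest_user_message)
)  # 7 + 6 + 9 + 1 = 23

# Subtract the system prompt estimate and the counts of earlier messages;
# what remains is attributed to the latest user message.
user_msg_tokens = prompt_eval_count - estimate_tokens(system_prompt)
for m in history:
    user_msg_tokens -= estimate_tokens(m)

assert user_msg_tokens == estimate_tokens(latest_user_message)  # 1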
6 changes: 3 additions & 3 deletions src/core/memory/schema.py
@@ -26,17 +26,17 @@ class Message(BaseModel):
     """Message object"""
     role: Role
     content: str
-    __token_length: int = 0
+    token_length: int = 0
 
     def model_dump(self, **kwargs):
         return {'role': str(self.role), 'content': self.content}
 
     # using @property causes issues with BaseModel
     def get_tokens(self) -> int:
-        return self.__token_length
+        return self.token_length
 
     def set_tokens(self, val: int):
-        self.__token_length = val
+        self.token_length = val
 
 
 class Conversation(BaseModel):
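The schema change above is what makes `token_length` injectable from the new tests: as a public pydantic field it can be passed to the constructor, whereas the old name-mangled `__token_length` attribute was treated as private state rather than a field. A minimal sketch, assuming pydantic v2 and a toy model rather than the repository's `Message` class:

from pydantic import BaseModel


class ToyMessage(BaseModel):
    role: str
    content: str
    token_length: int = 0  # public field: settable via the constructor, defaults to 0

    def model_dump(self, **kwargs):
        # keep the wire format limited to role/content, as schema.py does
        return {'role': self.role, 'content': self.content}


msg = ToyMessage(role='user', content='cool', token_length=1)
assert msg.token_length == 1
assert msg.model_dump() == {'role': 'user', 'content': 'cool'}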
Empty file added test/api/llm/__init__.py
154 changes: 154 additions & 0 deletions test/api/llm/test_ollama.py
@@ -0,0 +1,154 @@
# TODO:
# **query**
# - test streaming
# - test token counting (**user_message_token_length**)
#
# **tool_query**
#
import functools
from pytest import mark, raises, fixture

from src.core.llm import Ollama, ProviderError
from src.core.memory import Conversation, Message, Role
from test.mock.mock_ollama_client import MockOllamaClient


@mark.parametrize('parameters', [
    {'host_is_valid': False, 'model': 'mistral', 'exception': RuntimeError},
    {'host_is_valid': True, 'model': '', 'exception': ValueError},
    {'host_is_valid': True, 'model': 'chatgipiti', 'exception': ValueError},
])
def test_ollama_init(
        monkeypatch,
        parameters
):
    monkeypatch.setattr(
        'src.core.llm.ollama.Client',
        functools.partial(MockOllamaClient.__init__, valid_host=parameters['host_is_valid'])
    )

    with raises(parameters['exception']):
        Ollama(model=parameters['model'], inference_endpoint='')


# since Ollama.query uses pydantic.validate_call here the tests
# are oriented to mess with Conversation.messages order and content.
@mark.parametrize('parameters', [
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[]
        ),
        'exception': ProviderError
    },
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.SYS, content='you are an helpful 41 ass')
            ]
        ),
        'exception': ProviderError
    },
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.USER, content='')
            ]
        ),
        'exception': ProviderError
    },
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.ASSISTANT, content='Shouldn\'t be an assistant message')
            ]
        ),
        'exception': ProviderError
    }
])
def test_query_message_validation(
        monkeypatch,
        parameters
):
    monkeypatch.setattr(
        'src.core.llm.ollama.Client',
        functools.partial(MockOllamaClient, valid_host=True)
    )

    ollama = Ollama(model='mistral', inference_endpoint='')
    with raises(parameters['exception']):
        for _ in ollama.query(parameters['conversation']):
            pass


@mark.parametrize('parameters', [
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.USER, content='Can I run LLM in a lemon?')
            ]
        ),
        'expected_length_user': 6,
        'expected_length_response': 10
    },
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.SYS, content='You are a large lemon engineer'),
                Message(role=Role.USER, content='Can I run LLM in a lemon?')
            ]
        ),
        'expected_length_user': 6,
        'expected_length_response': 10
    },
    {
        'conversation': Conversation(
            conversation_id=1,
            name='untitled',
            messages=[
                Message(role=Role.SYS, content='You are a large lemon engineer'),
                Message(role=Role.USER, content='Can I run LLM in a lemon?', token_length=6),
                Message(role=Role.ASSISTANT, content='hold on you might be onto something\n', token_length=9),
                Message(role=Role.USER, content='cool')
            ]
        ),
        'expected_length_user': 1,
        'expected_length_response': 4
    },
])
def test_query_response(
        monkeypatch,
        parameters
):
    monkeypatch.setattr(
        'src.core.llm.ollama.Client',
        functools.partial(MockOllamaClient, valid_host=True)
    )

    ollama = Ollama(model='mistral', inference_endpoint='')

    response = ''
    last_user, last_assistant = 0, 0
    for chunk, tk_user, tk_assistant in ollama.query(parameters['conversation']):
        response += chunk
        last_user = tk_user
        last_assistant = tk_assistant

    print(f'user: {last_user}; assistant: {last_assistant}')

    # the expected_length is approximated considering message token length is len(content) / 4
    # => we expect len(user_message.content) / 4, len(response.content) / 4
    assert last_user == parameters['expected_length_user']
    assert last_assistant == int(len(response) / 4)
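The expected values in these cases follow from the same 4-characters-per-token estimate, combined with the `'response for: ...'` reply format of the mock client shown in the next file; a quick arithmetic check:

# expected_length_user / expected_length_response, case by case
assert int(len('Can I run LLM in a lemon?') / 4) == 6                   # 25 chars
assert int(len('response for: Can I run LLM in a lemon?\n') / 4) == 10  # 40 chars
assert int(len('cool') / 4) == 1                                        # 4 chars
assert int(len('response for: cool\n') / 4) == 4                        # 19 chars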
30 changes: 27 additions & 3 deletions test/mock/mock_ollama_client.py
@@ -1,3 +1,5 @@
+from ollama import ResponseError
+
 from src.core.llm import ProviderError
 
 
@@ -9,10 +11,32 @@ def __init__(self, host: str, valid_host = True):
             raise ProviderError('Ollama: invalid endpoint')
 
     def chat(
             self,
             model: str,
             messages: list,
-            stream: bool,
-            options: dict,
+            stream: bool = True,
+            options: dict | None = None,
+            tools: list | None = None
     ):
-        pass
+        if not model or not messages:
+            raise ResponseError("Model and messages are required")
+
+        last_message = messages[-1]['content']
+        response_message = f'response for: {last_message}\n'
+
+        full_input_tokens, eval_count = 0, 0
+        for i, char in enumerate(response_message):
+            # ollama sends token count only for last chunk in a stream
+            if len(response_message) == i + 1:
+                # count "tokens" (considering an average of 4 characters per token)
+                for msg in messages:
+                    full_input_tokens += int(len(msg['content']) / 4)
+
+                # count output "tokens"
+                eval_count += int(len(response_message) / 4)
+
+            yield {
+                'message': {'content': char},
+                'prompt_eval_count': full_input_tokens,
+                'eval_count': eval_count
+            }

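As a usage sketch (hypothetical, not part of the commit), the mock's streaming contract looks like this: it yields the reply one character per chunk, and only the final chunk carries non-zero `prompt_eval_count` / `eval_count`, mirroring the comment above about Ollama reporting token counts only on the last chunk of a stream. The endpoint value below is arbitrary.

from test.mock.mock_ollama_client import MockOllamaClient

client = MockOllamaClient(host='http://localhost:11434', valid_host=True)
chunks = list(client.chat(model='mistral', messages=[{'role': 'user', 'content': 'cool'}]))

text = ''.join(c['message']['content'] for c in chunks)
assert text == 'response for: cool\n'

# Intermediate chunks report zero tokens; only the last chunk is populated.
assert chunks[0]['eval_count'] == 0
assert chunks[-1]['prompt_eval_count'] == int(len('cool') / 4)  # 1
assert chunks[-1]['eval_count'] == int(len(text) / 4)           # 4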