refactoring: updated llm and memory module
- split classes in various files
- integrated pydantic model validation
antoninoLorenzo committed Jan 22, 2025
1 parent 48a179a commit 7b20b8e
Showing 9 changed files with 258 additions and 300 deletions.
2 changes: 1 addition & 1 deletion src/core/__init__.py
@@ -2,7 +2,7 @@
from src.core.memory import (
Role,
Message,
Session,
Conversation,
Memory
)
from src.core.tools import (
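
Note: Session is renamed to Conversation here, and the commit description mentions pydantic model validation. The src/core/memory diff is not shown on this page, so the following is only a hypothetical sketch of what the validated model could look like, inferred from the model_dump() call in the new ollama.py further down:

from enum import StrEnum
from pydantic import BaseModel, RootModel


class Role(StrEnum):
    # member names appear in the old llm.py below; the string values are assumptions
    SYS = 'system'
    USER = 'user'
    ASSISTANT = 'assistant'
    TOOL = 'tool'


class Message(BaseModel):
    role: Role
    content: str


class Conversation(RootModel):
    # a RootModel would make model_dump() return the raw message list in the
    # [{'role': ..., 'content': ...}] format the Ollama client expects
    root: list[Message] = []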
2 changes: 1 addition & 1 deletion src/core/knowledge/store.py
@@ -10,7 +10,7 @@
from qdrant_client import QdrantClient, models
from qdrant_client.http.exceptions import UnexpectedResponse

from src.core.llm.llm import ProviderError
from src.core.llm import ProviderError
from src.core.knowledge.collections import Collection, Document, Topic


9 changes: 3 additions & 6 deletions src/core/llm/__init__.py
@@ -1,11 +1,8 @@
"""Exposes implemented LLM functionalities"""

from src.core.llm.llm import (
LLM,
Provider,
ProviderError,
Ollama,
)
from src.core.llm.llm import LLM
from src.core.llm.schema import Provider, ProviderError
from src.core.llm.ollama import Ollama

AVAILABLE_PROVIDERS = {
'ollama': {'class': Ollama, 'key_required': False},
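
Only the ollama entry of AVAILABLE_PROVIDERS is visible in this hunk. A hedged sketch of how calling code might resolve a provider from the registry (the lookup helper itself is hypothetical):

from src.core.llm import AVAILABLE_PROVIDERS, ProviderError


def resolve_provider(name: str, model: str, api_key: str | None = None):
    # hypothetical helper: pick the provider class registered in AVAILABLE_PROVIDERS
    entry = AVAILABLE_PROVIDERS.get(name)
    if entry is None:
        raise ProviderError(f'unknown provider: {name}')
    if entry['key_required'] and api_key is None:
        raise ProviderError(f'{name} requires an API key')
    return entry['class'](model=model, api_key=api_key)


provider = resolve_provider('ollama', model='mistral')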
194 changes: 6 additions & 188 deletions src/core/llm/llm.py
@@ -1,185 +1,9 @@
"""
Interfaces the AI Agent to the LLM provider. Model availability depends on the
implemented prompts; to use a new model, the corresponding prompts must be written.
"""
from abc import ABC, abstractmethod
from typing import Tuple
from dataclasses import dataclass, field
from dataclasses import dataclass

import httpx
from ollama import Client, ResponseError

from src.core.memory import Role
from src.utils import get_logger

AVAILABLE_MODELS = {
'mistral': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': True
},
'llama3.1': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': False
},
'gemma2:9b': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': False
}
}
logger = get_logger(__name__)


@dataclass
class Provider(ABC):
"""Represents a LLM Provider"""
model: str
inference_endpoint: str = 'http://localhost:11434'
api_key: str | None = None

@abstractmethod
def query(self, messages: list) -> Tuple[str, Tuple]:
"""Implement to makes query to the LLM provider"""

@abstractmethod
def tool_query(self, messages: list, tools: list | None = None):
"""Implement for LLM tool calling"""

@staticmethod
def verify_messages_format(messages: list[dict]):
"""Format validation for messages."""
# check types
message_types_dict = [isinstance(msg, dict) for msg in messages]
if not isinstance(messages, list) or \
len(messages) == 0 or \
False in message_types_dict:
raise TypeError(f'messages must be a list[dict]: \n {messages}')

# check format
roles = [Role.SYS, Role.USER, Role.ASSISTANT, Role.TOOL]
valid_roles = [str(role) for role in roles]
err_message = f'expected {{"role": "{valid_roles}", "content": "..."}}'

# check format - keys
message_keys = [list(msg.keys()) for msg in messages]
valid_keys = ['role' in keys and 'content' in keys and len(keys) == 2
for keys in message_keys]
if False in valid_keys:
raise ValueError(err_message + f'\nMessage Keys: {message_keys}')

# check format = values
message_roles = [msg['role'] in valid_roles for msg in messages]
message_content = [
len(msg['content']) != 0 and isinstance(msg['content'], str)
for msg in messages
]
if False in message_roles or False in message_content:
if False in message_roles:
invalid = messages[message_roles.index(False)]
else:
invalid = messages[message_content.index(False)]
logger.error(f'\t{err_message}. Found {invalid}')
raise ValueError(err_message)


class ProviderError(Exception):
"""Just a wrapper to Exception for error handling
when an error is caused by the LLM provider"""


@dataclass
class Ollama(Provider):
"""Ollama Interface"""
client: Client | None = field(init=False, default=None)

def __post_init__(self):
if self.__match_model() is None:
raise ValueError(f'Model {self.model} is not supported.')
try:
self.client = Client(host=self.inference_endpoint)
except Exception as err:
raise RuntimeError('Initialization Failed') from err

def query(
self,
messages: list
) -> Tuple[str, int]:
"""Generator that returns a tuple containing:
(response_chunk, token_usage)"""
try:
self.verify_messages_format(messages)
except (TypeError, ValueError) as input_err:
raise input_err from input_err

try:
options = AVAILABLE_MODELS[self.__match_model()]['options']
stream = self.client.chat(
model=self.model,
messages=messages,
stream=True,
options=options
)
for chunk in stream:
                if 'eval_count' in chunk and 'prompt_eval_count' in chunk:
yield "", chunk['prompt_eval_count']

yield chunk['message']['content'], None
except (ResponseError, httpx.ConnectError) as err:
raise ProviderError(err) from err

def tool_query(
self,
messages: list,
tools: list | None = None
):
"""Implements LLM tool calling.
:param messages:
The current conversation provided as a list of messages in the
format [{"role": "assistant/user/system", "content": "..."}, ...]
:param tools:
A list of tools in the format specified by `ollama-python`, the
conversion is managed by `ToolRegistry` from `tool-parse` library.
:return
Ollama response with "message" : {"tool_calls": ...} or None.
"""
base_model = self.__match_model()
if base_model is None:
raise ValueError(f'Model {self.model} is not supported.')
if not AVAILABLE_MODELS[base_model]['tools']:
            raise NotImplementedError(f'{self.model} does not support tool calling')

try:
self.verify_messages_format(messages)
except (TypeError, ValueError) as input_err:
raise input_err from input_err

if not tools:
raise ValueError('Empty tool list')

tool_response = self.client.chat(
model=self.model,
messages=messages,
tools=tools
)

return tool_response if tool_response['message'].get('tool_calls') \
else None

def __match_model(self) -> str | None:
"""Check if a model is supported, the model availability on Ollama
is upon the user; ProviderError is raised if not available."""
for model in list(AVAILABLE_MODELS.keys()):
if self.model.startswith(model):
return model
return None
from src.core.llm.schema import Provider
from src.core.llm.ollama import Ollama
from src.core.memory import Conversation


@dataclass
@@ -199,7 +23,7 @@ def __post_init__(self):

def query(
self,
messages: list
messages: Conversation
) -> Tuple[str, int]:
"""Generator that returns LLM response in a tuple containing:
(chunk, token_usage).
@@ -211,7 +35,7 @@ def query(

def tool_query(
self,
messages: list,
messages: Conversation,
tools: list | None = None
):
"""
@@ -223,9 +47,3 @@ def tool_query(
the conversion is managed by `tool-parse` library."""
return self.provider.tool_query(messages, tools)


if __name__ == "__main__":

Ollama(model='mistral', inference_endpoint='some')
Ollama(model='mistral:7b-instruct-v0.3-q8_0', inference_endpoint='some')
Ollama(model='gpt', inference_endpoint='some')
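
What remains of llm.py is the thin LLM wrapper that delegates to a Provider. Its field list sits outside this hunk, so the constructor call below is an assumption based on the self.provider delegation visible above:

from src.core.llm import LLM, Ollama, ProviderError
from src.core.memory import Conversation, Message, Role

# assumed: LLM wraps a Provider instance exposed as a `provider` field
llm = LLM(provider=Ollama(model='mistral'))

# assumed Conversation/Message construction, see the sketch near the top of this page
conversation = Conversation([Message(role=Role.USER, content='Hello!')])

try:
    for chunk, _ in llm.query(conversation):
        print(chunk, end='')
except ProviderError as err:
    print(f'LLM provider failed: {err}')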
115 changes: 115 additions & 0 deletions src/core/llm/ollama.py
@@ -0,0 +1,115 @@
from typing import Tuple
from dataclasses import dataclass

import httpx
from ollama import Client, ResponseError
from pydantic import validate_call

from src.core.llm.schema import Provider, ProviderError
from src.core.memory import Conversation
from src.utils import get_logger


AVAILABLE_MODELS = {
'mistral': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': True
},
'llama3.1': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': False
},
'gemma2:9b': {
'options': {
'temperature': 0.5,
'num_ctx': 8192
},
'tools': False
}
}
logger = get_logger(__name__)


@dataclass
class Ollama(Provider):
"""Client for Ollama."""
client: Client | None = None

def __post_init__(self):
if self.__match_model() is None:
raise ValueError(f'Model {self.model} is not supported.')
try:
self.client = Client(host=self.inference_endpoint)
except Exception as err:
raise RuntimeError('Initialization Failed') from err

@validate_call
def query(
self,
messages: Conversation
) -> Tuple[str, int]:
"""Generator that returns a tuple containing:
(response_chunk, token_usage)"""
try:
options = AVAILABLE_MODELS[self.__match_model()]['options']
stream = self.client.chat(
model=self.model,
messages=messages.model_dump(),
stream=True,
options=options
)
for chunk in stream:
                if 'eval_count' in chunk and 'prompt_eval_count' in chunk:
yield "", chunk['prompt_eval_count']

yield chunk['message']['content'], None
except (ResponseError, httpx.ConnectError) as err:
raise ProviderError(err) from err

@validate_call
def tool_query(
self,
messages: Conversation,
tools: list | None = None
):
"""Implements LLM tool calling.
:param messages:
            The current conversation as a `Conversation` object; it is serialized
            to the [{"role": "assistant/user/system", "content": "..."}, ...] format
            expected by Ollama.
        :param tools:
            A list of tools in the format specified by `ollama-python`; the
            conversion is managed by `ToolRegistry` from the `tool-parse` library.
:return
Ollama response with "message" : {"tool_calls": ...} or None.
"""
base_model = self.__match_model()
if base_model is None:
raise ValueError(f'Model {self.model} is not supported.')
if not AVAILABLE_MODELS[base_model]['tools']:
            raise NotImplementedError(f'{self.model} does not support tool calling')

if not tools:
raise ValueError('Empty tool list')

tool_response = self.client.chat(
model=self.model,
messages=messages.model_dump(),
tools=tools
)

return tool_response if tool_response['message'].get('tool_calls') \
else None

def __match_model(self) -> str | None:
"""Check if a model is supported, the model availability on Ollama
is upon the user; ProviderError is raised if not available."""
for model in list(AVAILABLE_MODELS.keys()):
if self.model.startswith(model):
return model
return None
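
A short usage sketch of the new provider class; the Conversation construction follows the hypothetical model sketched near the top of this page and is not confirmed by the diff:

from src.core.llm.ollama import Ollama
from src.core.llm.schema import ProviderError
from src.core.memory import Conversation, Message, Role

ollama = Ollama(model='mistral', inference_endpoint='http://localhost:11434')
conversation = Conversation([Message(role=Role.USER, content='Hi!')])  # assumed constructor

try:
    for chunk, prompt_tokens in ollama.query(conversation):
        if prompt_tokens is None:
            print(chunk, end='', flush=True)   # streamed response text
        else:
            print(f'\n[prompt tokens: {prompt_tokens}]')
except ProviderError as err:
    print(f'Ollama unreachable or model missing: {err}')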
27 changes: 27 additions & 0 deletions src/core/llm/schema.py
@@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from typing import Tuple
from dataclasses import dataclass

from src.core.memory import Conversation


@dataclass
class Provider(ABC):
"""Defines a common interface for all LLM providers.
Current implementation only supports Ollama as provider."""
model: str
inference_endpoint: str = 'http://localhost:11434'
api_key: str | None = None

@abstractmethod
def query(self, messages: Conversation) -> Tuple[str, Tuple]:
"""Implement to makes query to the LLM provider"""

@abstractmethod
def tool_query(self, messages: Conversation, tools: list | None = None):
"""Implement for LLM tool calling"""


class ProviderError(Exception):
"""Just a wrapper to Exception for error handling
when an error is caused by the LLM provider"""
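
The abstract base leaves room for additional backends. A purely hypothetical stub showing the contract a new provider has to satisfy (no such provider exists in this commit); a concrete implementation would then be registered in AVAILABLE_PROVIDERS in src/core/llm/__init__.py:

from dataclasses import dataclass
from typing import Tuple

from src.core.llm.schema import Provider, ProviderError
from src.core.memory import Conversation


@dataclass
class EchoProvider(Provider):
    """Hypothetical stub provider used only to illustrate the interface."""

    def query(self, messages: Conversation) -> Tuple[str, int]:
        # a real backend would stream response chunks and token usage here
        yield 'echo', None

    def tool_query(self, messages: Conversation, tools: list | None = None):
        raise ProviderError(f'{self.model} does not support tool calling')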