Skip to content

Commit

Permalink
feat: Add initial support for structured output
Browse files Browse the repository at this point in the history
  • Loading branch information
xmnlab committed Nov 20, 2024
1 parent b621e56 commit 9731c83
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 454 deletions.
664 changes: 228 additions & 436 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ pypdf = ">=5"
langchain = ">=0.3.7"
langchain-community = ">=0.3.7"
spacy = ">=3"
instructor = ">=1"
pydantic = ">=2"

[tool.poetry.extras]
cpu = ["torch", "torchvision"]
Expand Down
29 changes: 18 additions & 11 deletions src/rago/generation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Any
from typing import Any, Optional

import torch

from pydantic import BaseModel
from typeguard import typechecked


Expand All @@ -26,6 +27,7 @@ class GenerationBase:
prompt_template: str = (
'question: \n```\n{query}\n```\ncontext: ```\n{context}\n```'
)
structured_output: Optional[BaseModel] = None

# default parameters that can be overwritten by the derived class
default_device_name: str = 'cpu'
Expand All @@ -44,6 +46,7 @@ def __init__(
prompt_template: str = '',
output_max_length: int = 500,
device: str = 'auto',
structured_output: Optional[BaseModel] = None,
logs: dict[str, Any] = {},
) -> None:
"""Initialize Generation class.
Expand All @@ -58,29 +61,33 @@ def __init__(
output_max_length : int
Maximum length of the generated output.
device: str (default=auto)
structured_output: Optional[BaseModel] = None
logs: dict[str, Any] = {}
"""
self.api_key = api_key
self.model_name = model_name or self.default_model_name
self.output_max_length = (
self.api_key: str = api_key
self.model_name: str = model_name or self.default_model_name
self.output_max_length: int = (
output_max_length or self.default_output_max_length
)
self.temperature = temperature or self.default_temperature
self.temperature: float = temperature or self.default_temperature

self.prompt_template = prompt_template or self.default_prompt_template
self.prompt_template: str = (
prompt_template or self.default_prompt_template
)
self.structured_output: Optional[BaseModel] = None

if self.device_name not in ['cpu', 'cuda', 'auto']:
if device not in ['cpu', 'cuda', 'auto']:
raise Exception(
f'Device {self.device_name} not supported. '
'Options: cpu, cuda, auto.'
f'Device {device} not supported. ' 'Options: cpu, cuda, auto.'
)

cuda_available = torch.cuda.is_available()
self.device_name = (
self.device_name: str = (
'cpu' if device == 'cpu' or not cuda_available else 'cuda'
)
self.device = torch.device(self.device_name)

self.logs = logs
self.logs: dict[str, Any] = logs

self._validate()
self._setup()
Expand Down
7 changes: 6 additions & 1 deletion src/rago/generation/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import cast

import google.generativeai as genai
import instructor

from typeguard import typechecked

Expand All @@ -20,7 +21,11 @@ class GeminiGen(GenerationBase):
def _setup(self) -> None:
"""Set up the object with the initial parameters."""
genai.configure(api_key=self.api_key)
self.model = genai.GenerativeModel(self.model_name)
model = genai.GenerativeModel(self.model_name)

self.model = (
instructor.from_gemini(model) if self.structured_output else model
)

def generate(self, query: str, context: list[str]) -> str:
"""Generate text using Gemini model support."""
Expand Down
14 changes: 10 additions & 4 deletions src/rago/generation/hugging_face.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import warnings

import torch

from transformers import T5ForConditionalGeneration, T5Tokenizer
Expand All @@ -22,13 +24,17 @@ def _validate(self) -> None:
f'The given model {self.model_name} is not supported.'
)

if self.structured_output:
warnings.warn(
'Structured output is not supported yet in '
f'{self.__class__.__name__}.'
)

def _setup(self) -> None:
"""Set models to t5-small models."""
self.tokenizer = T5Tokenizer.from_pretrained(self.model_name)
self.model = T5ForConditionalGeneration.from_pretrained(
self.model_name
)
self.model = self.model.to(self.device)
model = T5ForConditionalGeneration.from_pretrained(self.model_name)
self.model = model.to(self.device)

def generate(self, query: str, context: list[str]) -> str:
"""Generate the text from the query and augmented context."""
Expand Down
8 changes: 8 additions & 0 deletions src/rago/generation/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

from __future__ import annotations

import warnings

import torch

from langdetect import detect
Expand All @@ -27,6 +29,12 @@ def _validate(self) -> None:
'by meta.'
)

if self.structured_output:
warnings.warn(
'Structured output is not supported yet in '
f'{self.__class__.__name__}.'
)

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
self.tokenizer = AutoTokenizer.from_pretrained(
Expand Down
7 changes: 6 additions & 1 deletion src/rago/generation/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from typing import cast

import instructor
import openai

from typeguard import typechecked
Expand All @@ -19,7 +20,11 @@ class OpenAIGen(GenerationBase):

def _setup(self) -> None:
"""Set up the object with the initial parameters."""
self.model = openai.OpenAI(api_key=self.api_key)
model = openai.OpenAI(api_key=self.api_key)

self.model = (
instructor.from_openai(model) if self.structured_output else model
)

def generate(
self,
Expand Down
5 changes: 4 additions & 1 deletion tests/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,10 @@ def test_aug_openai(animals_data: list[str], api_key: str) -> None:
aug_result = aug_openai.search(query, ret_result)

assert aug_openai.top_k == top_k
assert len(aug_result) == top_k
# note: openai as augmented doesn't work as expected
# it is returning a very poor result
# it needs to be revisited and improved
assert len(aug_result) >= 1
assert 'blue whale' in aug_result[0].lower()

# check if logs have been used
Expand Down

0 comments on commit 9731c83

Please sign in to comment.