Skip to content

Commit

Permalink
Move LLM data models into folder
Browse files Browse the repository at this point in the history
create + test setup_llm() in utils
  • Loading branch information
pipliggins committed Dec 4, 2024
1 parent db05b32 commit f5f4e8c
Show file tree
Hide file tree
Showing 11 changed files with 102 additions and 81 deletions.
16 changes: 8 additions & 8 deletions src/adtl/autoparser/create_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,13 @@
import numpy as np
import pandas as pd

from .language_models.gemini import GeminiLanguageModel
from .language_models.openai import OpenAILanguageModel
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_json
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_json,
setup_llm,
)


class Mapper:
Expand Down Expand Up @@ -55,12 +59,8 @@ def __init__(
self.language = language
if llm is None:
self.model = None
elif llm == "openai": # pragma: no cover
self.model = OpenAILanguageModel(api_key)
elif llm == "gemini": # pragma: no cover
self.model = GeminiLanguageModel(api_key)
else:
raise ValueError(f"Unsupported LLM: {llm}")
self.model = setup_llm(llm, api_key)

self.config = read_config_schema(
config or Path(Path(__file__).parent, DEFAULT_CONFIG)
Expand Down
38 changes: 9 additions & 29 deletions src/adtl/autoparser/dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
import numpy as np
import pandas as pd

from .language_models.gemini import GeminiLanguageModel
from .language_models.openai import OpenAILanguageModel
from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_data
from .util import (
DEFAULT_CONFIG,
load_data_dict,
read_config_schema,
read_data,
setup_llm,
)


class DictWriter:
Expand Down Expand Up @@ -45,34 +49,10 @@ def __init__(
)

if llm and api_key:
self._setup_llm(api_key, llm)
self.model = setup_llm(llm, api_key)

Check warning on line 52 in src/adtl/autoparser/dict_writer.py

View check run for this annotation

Codecov / codecov/patch

src/adtl/autoparser/dict_writer.py#L52

Added line #L52 was not covered by tests
else:
self.model = None

def _setup_llm(self, key: str, name: str):
"""
Setup the LLM to use to generate descriptions.
Separate from the __init__ method to allow for extra barrier between raw data &
LLM.
Parameters
----------
key
API key
name
Name of the LLM to use (currently only OpenAI and Gemini are supported)
"""
if key is None:
raise ValueError("API key required for generating descriptions")

if name == "openai": # pragma: no cover
self.model = OpenAILanguageModel(api_key=key)
elif name == "gemini": # pragma: no cover
self.model = GeminiLanguageModel(api_key=key)
else:
raise ValueError(f"Unsupported LLM: {name}")

def create_dict(self, data: pd.DataFrame | str) -> pd.DataFrame:
"""
Create a basic data dictionary from a dataset.
Expand Down Expand Up @@ -182,7 +162,7 @@ def generate_descriptions(
df = load_data_dict(self.config, data_dict)

if not self.model:
self._setup_llm(key, llm)
self.model = setup_llm(llm, key)

headers = df.source_field

Expand Down
39 changes: 39 additions & 0 deletions src/adtl/autoparser/language_models/data_structures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""Stores the data structures for using with LLM API's"""

from pydantic import BaseModel

# target classes for generating descriptions


class SingleField(BaseModel):
field_name: str
translation: str | None


class ColumnDescriptionRequest(BaseModel):
field_descriptions: list[SingleField]


# target classes for matching fields
class SingleMapping(BaseModel):
target_field: str
source_description: str | None


class MappingRequest(BaseModel):
targets_descriptions: list[SingleMapping]


# target classes for matching values to enum/boolean options
class ValueMapping(BaseModel):
source_value: str
target_value: str | None


class FieldMapping(BaseModel):
field_name: str
mapped_values: list[ValueMapping]


class ValuesRequest(BaseModel):
values: list[FieldMapping]
3 changes: 1 addition & 2 deletions src/adtl/autoparser/language_models/gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@

import google.generativeai as gemini

from adtl.autoparser.util import ColumnDescriptionRequest, MappingRequest, ValuesRequest

from .base_llm import LLMBase
from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest


class GeminiLanguageModel(LLMBase):
Expand Down
3 changes: 1 addition & 2 deletions src/adtl/autoparser/language_models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

from openai import OpenAI

from adtl.autoparser.util import ColumnDescriptionRequest, MappingRequest, ValuesRequest

from .base_llm import LLMBase
from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest


class OpenAILanguageModel(LLMBase):
Expand Down
58 changes: 23 additions & 35 deletions src/adtl/autoparser/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@

import pandas as pd
import tomli
from pydantic import BaseModel

from adtl.autoparser.language_models.gemini import GeminiLanguageModel
from adtl.autoparser.language_models.openai import OpenAILanguageModel

DEFAULT_CONFIG = "config/autoparser.toml"

Expand Down Expand Up @@ -106,40 +108,26 @@ def load_data_dict(
return data_dict


# Data structures for llm calls --------------------------

# target classes for generating descriptions


class SingleField(BaseModel):
field_name: str
translation: str | None


class ColumnDescriptionRequest(BaseModel):
field_descriptions: list[SingleField]

def setup_llm(provider, api_key):
"""
Setup the LLM to use to generate descriptions.
# target classes for matching fields
class SingleMapping(BaseModel):
target_field: str
source_description: str | None
Separate from the __init__ method to allow for extra barrier between raw data &
LLM.
Parameters
----------
key
API key
name
Name of the LLM to use (currently only OpenAI and Gemini are supported)
"""
if api_key is None:
raise ValueError("API key required to set up an LLM")

class MappingRequest(BaseModel):
targets_descriptions: list[SingleMapping]


# target classes for matching values to enum/boolean options
class ValueMapping(BaseModel):
source_value: str
target_value: str | None


class FieldMapping(BaseModel):
field_name: str
mapped_values: list[ValueMapping]


class ValuesRequest(BaseModel):
values: list[FieldMapping]
if provider == "openai": # pragma: no cover
return OpenAILanguageModel(api_key=api_key)
elif provider == "gemini": # pragma: no cover
return GeminiLanguageModel(api_key=api_key)
else:
raise ValueError(f"Unsupported LLM provider: {provider}")
2 changes: 1 addition & 1 deletion tests/test_autoparser/test_dict_writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def test_missing_key_error():


def test_wrong_llm_error():
with pytest.raises(ValueError, match="Unsupported LLM: fish"):
with pytest.raises(ValueError, match="Unsupported LLM provider: fish"):
DictWriter(config=Path(CONFIG_PATH)).generate_descriptions(
"fr", SOURCES + "animals_dd.csv", key="a12b3c", llm="fish"
)
3 changes: 2 additions & 1 deletion tests/test_autoparser/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,11 +215,12 @@ def test_common_values_mapped_fields_error():


def test_mapper_class_init_raises():
with pytest.raises(ValueError, match="Unsupported LLM: fish"):
with pytest.raises(ValueError, match="Unsupported LLM provider: fish"):
Mapper(
Path("tests/test_autoparser/schemas/animals.schema.json"),
"tests/test_autoparser/sources/animals_dd_described.csv",
"fr",
api_key="1234",
llm="fish",
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_autoparser/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
)
from testing_data_animals import get_definitions, map_fields, map_values

from adtl.autoparser.language_models.data_structures import ColumnDescriptionRequest
from adtl.autoparser.language_models.openai import OpenAILanguageModel
from adtl.autoparser.util import ColumnDescriptionRequest


def test_init():
Expand Down
17 changes: 16 additions & 1 deletion tests/test_autoparser/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pandas as pd
import pytest

from adtl.autoparser.util import load_data_dict, parse_choices, read_config_schema
from adtl.autoparser.util import (
load_data_dict,
parse_choices,
read_config_schema,
setup_llm,
)

CONFIG = read_config_schema(Path("tests/test_autoparser/test_config.toml"))

Expand Down Expand Up @@ -86,3 +91,13 @@ def test_load_data_dict():

with pytest.raises(ValueError, match="Unsupported format"):
load_data_dict(CONFIG, "tests/test_autoparser/sources/animals.txt")


def test_setup_llm_no_key():
with pytest.raises(ValueError, match="API key required to set up an LLM"):
setup_llm("openai", None)


def test_setup_llm_bad_provider():
with pytest.raises(ValueError, match="Unsupported LLM provider: fish"):
setup_llm("fish", "abcd")
2 changes: 1 addition & 1 deletion tests/test_autoparser/testing_data_animals.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

from adtl.autoparser.language_models.base_llm import LLMBase
from adtl.autoparser.util import (
from adtl.autoparser.language_models.data_structures import (
ColumnDescriptionRequest,
FieldMapping,
MappingRequest,
Expand Down

0 comments on commit f5f4e8c

Please sign in to comment.