From f5f4e8cbe4293066ffddd3a0632fd7befd478ad2 Mon Sep 17 00:00:00 2001 From: Pip Liggins Date: Wed, 4 Dec 2024 17:14:15 +0000 Subject: [PATCH] Move LLM data models into folder create + test setup_llm() in utils --- src/adtl/autoparser/create_mapping.py | 16 ++--- src/adtl/autoparser/dict_writer.py | 38 +++--------- .../language_models/data_structures.py | 39 +++++++++++++ src/adtl/autoparser/language_models/gemini.py | 3 +- src/adtl/autoparser/language_models/openai.py | 3 +- src/adtl/autoparser/util.py | 58 ++++++++----------- tests/test_autoparser/test_dict_writer.py | 2 +- tests/test_autoparser/test_mapper.py | 3 +- tests/test_autoparser/test_openai.py | 2 +- tests/test_autoparser/test_utils.py | 17 +++++- tests/test_autoparser/testing_data_animals.py | 2 +- 11 files changed, 102 insertions(+), 81 deletions(-) create mode 100644 src/adtl/autoparser/language_models/data_structures.py diff --git a/src/adtl/autoparser/create_mapping.py b/src/adtl/autoparser/create_mapping.py index 4a8471e..977e132 100644 --- a/src/adtl/autoparser/create_mapping.py +++ b/src/adtl/autoparser/create_mapping.py @@ -12,9 +12,13 @@ import numpy as np import pandas as pd -from .language_models.gemini import GeminiLanguageModel -from .language_models.openai import OpenAILanguageModel -from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_json +from .util import ( + DEFAULT_CONFIG, + load_data_dict, + read_config_schema, + read_json, + setup_llm, +) class Mapper: @@ -55,12 +59,8 @@ def __init__( self.language = language if llm is None: self.model = None - elif llm == "openai": # pragma: no cover - self.model = OpenAILanguageModel(api_key) - elif llm == "gemini": # pragma: no cover - self.model = GeminiLanguageModel(api_key) else: - raise ValueError(f"Unsupported LLM: {llm}") + self.model = setup_llm(llm, api_key) self.config = read_config_schema( config or Path(Path(__file__).parent, DEFAULT_CONFIG) diff --git a/src/adtl/autoparser/dict_writer.py 
b/src/adtl/autoparser/dict_writer.py index 0d8d09d..c6ae712 100644 --- a/src/adtl/autoparser/dict_writer.py +++ b/src/adtl/autoparser/dict_writer.py @@ -10,9 +10,13 @@ import numpy as np import pandas as pd -from .language_models.gemini import GeminiLanguageModel -from .language_models.openai import OpenAILanguageModel -from .util import DEFAULT_CONFIG, load_data_dict, read_config_schema, read_data +from .util import ( + DEFAULT_CONFIG, + load_data_dict, + read_config_schema, + read_data, + setup_llm, +) class DictWriter: @@ -45,34 +49,10 @@ def __init__( ) if llm and api_key: - self._setup_llm(api_key, llm) + self.model = setup_llm(llm, api_key) else: self.model = None - def _setup_llm(self, key: str, name: str): - """ - Setup the LLM to use to generate descriptions. - - Separate from the __init__ method to allow for extra barrier between raw data & - LLM. - - Parameters - ---------- - key - API key - name - Name of the LLM to use (currently only OpenAI and Gemini are supported) - """ - if key is None: - raise ValueError("API key required for generating descriptions") - - if name == "openai": # pragma: no cover - self.model = OpenAILanguageModel(api_key=key) - elif name == "gemini": # pragma: no cover - self.model = GeminiLanguageModel(api_key=key) - else: - raise ValueError(f"Unsupported LLM: {name}") - def create_dict(self, data: pd.DataFrame | str) -> pd.DataFrame: """ Create a basic data dictionary from a dataset. 
@@ -182,7 +162,7 @@ def generate_descriptions( df = load_data_dict(self.config, data_dict) if not self.model: - self._setup_llm(key, llm) + self.model = setup_llm(llm, key) headers = df.source_field diff --git a/src/adtl/autoparser/language_models/data_structures.py b/src/adtl/autoparser/language_models/data_structures.py new file mode 100644 index 0000000..0f55e99 --- /dev/null +++ b/src/adtl/autoparser/language_models/data_structures.py @@ -0,0 +1,39 @@ +"""Stores the data structures for use with LLM APIs""" + +from pydantic import BaseModel + +# target classes for generating descriptions + + +class SingleField(BaseModel): + field_name: str + translation: str | None + + +class ColumnDescriptionRequest(BaseModel): + field_descriptions: list[SingleField] + + +# target classes for matching fields +class SingleMapping(BaseModel): + target_field: str + source_description: str | None + + +class MappingRequest(BaseModel): + targets_descriptions: list[SingleMapping] + + +# target classes for matching values to enum/boolean options +class ValueMapping(BaseModel): + source_value: str + target_value: str | None + + +class FieldMapping(BaseModel): + field_name: str + mapped_values: list[ValueMapping] + + +class ValuesRequest(BaseModel): + values: list[FieldMapping] diff --git a/src/adtl/autoparser/language_models/gemini.py b/src/adtl/autoparser/language_models/gemini.py index ebf58bb..81b10d3 100644 --- a/src/adtl/autoparser/language_models/gemini.py +++ b/src/adtl/autoparser/language_models/gemini.py @@ -6,9 +6,8 @@ import google.generativeai as gemini -from adtl.autoparser.util import ColumnDescriptionRequest, MappingRequest, ValuesRequest - from .base_llm import LLMBase +from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest class GeminiLanguageModel(LLMBase): diff --git a/src/adtl/autoparser/language_models/openai.py b/src/adtl/autoparser/language_models/openai.py index 9a33f3c..2fb2525 100644 --- 
a/src/adtl/autoparser/language_models/openai.py +++ b/src/adtl/autoparser/language_models/openai.py @@ -4,9 +4,8 @@ from openai import OpenAI -from adtl.autoparser.util import ColumnDescriptionRequest, MappingRequest, ValuesRequest - from .base_llm import LLMBase +from .data_structures import ColumnDescriptionRequest, MappingRequest, ValuesRequest class OpenAILanguageModel(LLMBase): diff --git a/src/adtl/autoparser/util.py b/src/adtl/autoparser/util.py index ed53ad4..a106698 100644 --- a/src/adtl/autoparser/util.py +++ b/src/adtl/autoparser/util.py @@ -10,7 +10,9 @@ import pandas as pd import tomli -from pydantic import BaseModel + +from adtl.autoparser.language_models.gemini import GeminiLanguageModel +from adtl.autoparser.language_models.openai import OpenAILanguageModel DEFAULT_CONFIG = "config/autoparser.toml" @@ -106,40 +108,26 @@ def load_data_dict( return data_dict -# Data structures for llm calls -------------------------- - -# target classes for generating descriptions - - -class SingleField(BaseModel): - field_name: str - translation: str | None - - -class ColumnDescriptionRequest(BaseModel): - field_descriptions: list[SingleField] - +def setup_llm(provider, api_key): + """ + Setup the LLM to use to generate descriptions. -# target classes for matching fields -class SingleMapping(BaseModel): - target_field: str - source_description: str | None + Separate from the __init__ method to allow for extra barrier between raw data & + LLM. 
+ Parameters + ---------- + provider + Name of the LLM to use (currently only OpenAI and Gemini are supported) + api_key + API key + """ + if api_key is None: + raise ValueError("API key required to set up an LLM") -class MappingRequest(BaseModel): - targets_descriptions: list[SingleMapping] - - -# target classes for matching values to enum/boolean options -class ValueMapping(BaseModel): - source_value: str - target_value: str | None - - -class FieldMapping(BaseModel): - field_name: str - mapped_values: list[ValueMapping] - - -class ValuesRequest(BaseModel): - values: list[FieldMapping] + if provider == "openai": # pragma: no cover + return OpenAILanguageModel(api_key=api_key) + elif provider == "gemini": # pragma: no cover + return GeminiLanguageModel(api_key=api_key) + else: + raise ValueError(f"Unsupported LLM provider: {provider}") diff --git a/tests/test_autoparser/test_dict_writer.py b/tests/test_autoparser/test_dict_writer.py index 79f14d3..b0af931 100644 --- a/tests/test_autoparser/test_dict_writer.py +++ b/tests/test_autoparser/test_dict_writer.py @@ -81,7 +81,7 @@ def test_missing_key_error(): def test_wrong_llm_error(): - with pytest.raises(ValueError, match="Unsupported LLM: fish"): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): DictWriter(config=Path(CONFIG_PATH)).generate_descriptions( "fr", SOURCES + "animals_dd.csv", key="a12b3c", llm="fish" ) diff --git a/tests/test_autoparser/test_mapper.py b/tests/test_autoparser/test_mapper.py index fca7392..7addf53 100644 --- a/tests/test_autoparser/test_mapper.py +++ b/tests/test_autoparser/test_mapper.py @@ -215,11 +215,12 @@ def test_common_values_mapped_fields_error(): def test_mapper_class_init_raises(): - with pytest.raises(ValueError, match="Unsupported LLM: fish"): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): Mapper( Path("tests/test_autoparser/schemas/animals.schema.json"), "tests/test_autoparser/sources/animals_dd_described.csv", "fr", + api_key="1234", 
llm="fish", ) diff --git a/tests/test_autoparser/test_openai.py b/tests/test_autoparser/test_openai.py index 015de2b..b3b1a73 100644 --- a/tests/test_autoparser/test_openai.py +++ b/tests/test_autoparser/test_openai.py @@ -9,8 +9,8 @@ ) from testing_data_animals import get_definitions, map_fields, map_values +from adtl.autoparser.language_models.data_structures import ColumnDescriptionRequest from adtl.autoparser.language_models.openai import OpenAILanguageModel -from adtl.autoparser.util import ColumnDescriptionRequest def test_init(): diff --git a/tests/test_autoparser/test_utils.py b/tests/test_autoparser/test_utils.py index 955f63b..3cbb3d1 100644 --- a/tests/test_autoparser/test_utils.py +++ b/tests/test_autoparser/test_utils.py @@ -6,7 +6,12 @@ import pandas as pd import pytest -from adtl.autoparser.util import load_data_dict, parse_choices, read_config_schema +from adtl.autoparser.util import ( + load_data_dict, + parse_choices, + read_config_schema, + setup_llm, +) CONFIG = read_config_schema(Path("tests/test_autoparser/test_config.toml")) @@ -86,3 +91,13 @@ def test_load_data_dict(): with pytest.raises(ValueError, match="Unsupported format"): load_data_dict(CONFIG, "tests/test_autoparser/sources/animals.txt") + + +def test_setup_llm_no_key(): + with pytest.raises(ValueError, match="API key required to set up an LLM"): + setup_llm("openai", None) + + +def test_setup_llm_bad_provider(): + with pytest.raises(ValueError, match="Unsupported LLM provider: fish"): + setup_llm("fish", "abcd") diff --git a/tests/test_autoparser/testing_data_animals.py b/tests/test_autoparser/testing_data_animals.py index 9e55456..74251ae 100644 --- a/tests/test_autoparser/testing_data_animals.py +++ b/tests/test_autoparser/testing_data_animals.py @@ -3,7 +3,7 @@ from __future__ import annotations from adtl.autoparser.language_models.base_llm import LLMBase -from adtl.autoparser.util import ( +from adtl.autoparser.language_models.data_structures import ( ColumnDescriptionRequest, 
FieldMapping, MappingRequest,