diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 20214ea..f863dce 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -11,7 +11,6 @@ jobs:
   build-on-ubuntu:
     runs-on: ubuntu-latest
     strategy:
-      fail-fast: false
       max-parallel: 3
       matrix:
         python-version: ['3.9', '3.10', '3.11']
diff --git a/MANIFEST.in b/MANIFEST.in
index eea6d6b..77c2e5a 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,4 +1,5 @@
 include LICENSE
 include README.md
 include CHANGELOG.md
-recursive-include gemma_template docs Makefile *.md *.rst
+include docs/*.md
+recursive-include gemma_template
diff --git a/Makefile b/Makefile
index 6f7a1b3..2e4f91c 100644
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,7 @@ PROJECT=gemma_template

 DEFAULT_GOAL: help

-.PHONY: help all install lint readme docs
+.PHONY: help all install lint readme docs test tox

 # Colors for echos
 bold = \033[0;32m\033[1m***
@@ -14,6 +14,8 @@ end_bold = *** \033[0m\033[0;0m
 all: ##@target >> Run all.
 	@make install
 	@make lint
+	@make readme
+	@make test

 # And add help text after each target name starting with '\#\#'
 # A category can be added with @category
@@ -50,11 +52,14 @@ lint: ##@target >> Run Lint.
 	@$(PYTHON) -m flake8 $(PROJECT)

 test:
+	python -m pytest tests
+
+tox:
 	tox -p
 	rm -rf *.egg-info

 readme:
-	python setup.py check --restructuredtext --strict && ([ $$? -eq 0 ] && echo "README.rst and CHANGELOG.md ok") || echo "Invalid markup in README.md or CHANGELOG.md!"
+	python setup.py check --restructuredtext --strict && ([ $$? -eq 0 ] && echo "README.md and CHANGELOG.md ok") || echo "Invalid markup in README.md or CHANGELOG.md!"

 docs:
 	mkdocs serve
diff --git a/gemma_template/__init__.py b/gemma_template/__init__.py
index d239922..f56ead9 100644
--- a/gemma_template/__init__.py
+++ b/gemma_template/__init__.py
@@ -1,5 +1,7 @@
+from .__version__ import (__author__, __description__, __license__, __title__,
+                          __version__)
 from .constants import *
 from .exceptions import *
 from .models import (StructureField, Template, gemma_template,
-                     vietnamese_template)
+                     vietnamese_gemma_template)
 from .utils import get_frequently_words, get_language, get_n_grams
diff --git a/gemma_template/__version__.py b/gemma_template/__version__.py
index 2c697a0..6393b3a 100644
--- a/gemma_template/__version__.py
+++ b/gemma_template/__version__.py
@@ -3,5 +3,5 @@
 __url__ = "https://github.com/thewebscraping/gemma-template"
 __author__ = "Tu Pham"
 __author_email__ = "thetwofarm@gmail.com"
-__version__ = "0.1.0"
+__version__ = "0.1.1"
 __license__ = "Apache-2.0"
diff --git a/gemma_template/models.py b/gemma_template/models.py
index ae8c7dc..84c6ea5 100644
--- a/gemma_template/models.py
+++ b/gemma_template/models.py
@@ -4,7 +4,8 @@ import json
 from pathlib import Path
 from string import punctuation
-from typing import Callable, ClassVar, Literal, Optional, Union, get_origin
+from typing import (Callable, ClassVar, Literal, Optional, Sequence, Union,
+                    get_origin)

 import nest_asyncio
 from datasets import Dataset, DatasetDict, load_dataset
@@ -248,17 +249,17 @@ class Template(BaseTemplate):
     # Response Structure Format:
     You must follow the response structure:

-    **Custom Title (Title):** Rewrite the title to make it concise, memorable, and optimized for SEO.
-    **Custom Description (Description):** Write description of the article in one or two sentences while focusing on reader benefits and engage curiosity.
-    **Custom Article (Article):** Transform this text into a formal, professional tone suitable for business communication or an academic audience. Focus on improving vocabulary, grammar, and structure.
-    **Custom Main Points (Main Points):** Summarize the main ideas into concise, actionable key points for added context to make them more engaging.
-    **Custom Categories (Categories):** Rewrite categories to align with industry standards or popular topics.
-    **Custom Tags (Tags):** Add trending keyword terms or phrases to the tags for increased visibility.
+    **Custom Title (Title):** Rewrite the title to reflect the main keyword and topic.
+    **Custom Description (Description):** Rewrite the description with a bold claim or statistic to grab attention.
+    **Custom Article (Article):** Rewrite this content to be SEO-friendly. Include relevant tags, optimize the title and subheadings, and ensure the text flows naturally for search engines and readers.
+    **Custom Main Points (Main Points):** Simplify the original key points to make them clearer and more reader-friendly.
+    **Custom Categories (Categories):** Assign appropriate categories to the article based on the text or target audience.
+    **Custom Tags (Tags):** Focus on tags that reflect the article’s subtopics or themes for better SEO.

     By adhering to this format, the response will maintain linguistic integrity while enhancing professionalism, structure and alignment with user expectations.

     # Text:
-    Gemma open models are built from the same research _____ technology as Gemini models. Gemma 2 comes in 2B, 9B _____ 27B and Gemma 1 comes in 2B _____ 7B sizes.
+    Gemma open models are _____ from the same research and technology as Gemini models. Gemma 2 comes in 2B, 9B _____ 27B and Gemma 1 comes in 2B _____ 7B sizes.

     model
@@ -394,12 +395,13 @@ def load_dataset(
     instruction_template: Optional[TemplateTypes] = None,
     structure_template: Optional[TemplateTypes] = None,
     output_format: Union[str, Literal["text", "alpaca", "gpt"]] = "text",
-    eos_token_str: Optional[str] = "",
+    excluded_fields: Optional[Sequence[str]] = (),
     max_hidden_ratio: Union[float] = 0,
     max_hidden_words: Union[int, float] = 0,
     min_chars_length: int = 2,
     max_chars_length: int = 0,
     max_concurrency: int = 4,
+    is_close_async_loop: bool = True,
     **kwargs,
 ) -> Union[Dataset, DatasetDict]:
     """
@@ -421,8 +423,8 @@ def load_dataset(
             Template for structuring the user prompt.
         output_format (Union[str, Literal["text", "alpaca", "gpt"]]):
             Specifies the format for the generated prompts. Default is "text".
-        eos_token_str (Optional[str]):
-            Append eos token to the end of the model output.
+        excluded_fields (Optional[Sequence[str]]):
+            Field names to exclude from the model response. Default is an empty sequence.
         max_hidden_ratio (Union[float]):
             Percentage of documents that need to be word masked. Min: 0, Max: 1. Default: 0.
         max_hidden_words (Optional[str]):
@@ -434,6 +436,9 @@ def load_dataset(
             Maximum character of a word, used to create unigrams, bigrams and trigrams. Default is 0.
         max_concurrency (int):
             Maximum number of concurrent threads for processing data. Default is 4.
+        is_close_async_loop (bool):
+            Whether to close the asyncio event loop once the dataset has been processed. Default is True.
+            The `RuntimeError` raised by an already-closed loop is handled, but set this to False when running on Kaggle Notebooks or Google Colab.
         **kwargs: Additional parameters, including:
             - `token` (Optional[str]): Hugging Face authentication token.
             - `split` (Optional[list[str]]): Dataset split for Hugging Face Dataset loading.
@@ -468,7 +473,13 @@ def load_dataset(
     async def create_task(config, hidden_count: int = 0):
         async with semaphore:
             config.update(kwargs)
-            config.update(dict(min_chars_length=min_chars_length, max_chars_length=max_chars_length))
+            config.update(
+                dict(
+                    min_chars_length=min_chars_length,
+                    max_chars_length=max_chars_length,
+                    excluded_fields=excluded_fields,
+                )
+            )
             if max_hidden_ratio > 0 and hidden_count < max_hidden_count:
                 config["max_hidden_words"] = max_hidden_words
             else:
@@ -480,7 +491,6 @@ async def create_task(config, hidden_count: int = 0):
                         user_template,
                         instruction_template,
                         structure_template,
-                        eos_token_str,
                         **config,
                     )
                 )
@@ -490,7 +500,6 @@ async def create_task(config, hidden_count: int = 0):
                         user_template,
                         instruction_template,
                         structure_template,
-                        eos_token_str,
                         **config,
                     )
                 )
@@ -501,7 +510,6 @@ async def create_task(config, hidden_count: int = 0):
                         user_template,
                         instruction_template,
                         structure_template,
-                        eos_token_str,
                         **config,
                     )
                 )
@@ -510,18 +518,26 @@ async def create_task(config, hidden_count: int = 0):
            hidden_count += 1

     async def run_task(ds):
-        await asyncio.wait([loop.create_task(create_task(config, idx)) for idx, config in enumerate(ds)])
+        await asyncio.wait(
+            [
+                loop.create_task(create_task(config, idx))
+                for idx, config in enumerate(ds)
+            ]
+        )

     def _close():
-        """Notebook Error"""
-        try:
-            loop.close()
-        except RuntimeError:
-            pass
+        """Close the asyncio event loop, if requested."""
+        if is_close_async_loop:
+            try:
+                loop.close()
+            except RuntimeError:
+                pass

     if max_hidden_ratio:
-        if 0 > max_hidden_ratio > 1:
-            raise MaxHiddenRatioError("Maximum hidden ratio must be between 0 and 1.")
+        if max_hidden_ratio < 0 or max_hidden_ratio > 1:
+            raise MaxHiddenRatioError(
+                "Maximum hidden ratio must be between 0 and 1."
+            )

     dataset = fp
     if isinstance(dataset, (str, Path)):
@@ -619,25 +635,38 @@ def get_user_kwargs(
         system_template_str, prompt_template_str, structure_template_str, document = (
             self._get_prompts(structure_template, **kwargs)
         )
-        language_code, language = get_language(document)
+        language_code = "auto"
+        language = kwargs.get("language")
+        if language is None:
+            language_code, language = get_language(document)
+
         document = mask_hidden(language_code=language_code, **kwargs)
-        unigrams = self._get_frequently_words(
-            n=1, response_n=n_words, language_code=language_code, **kwargs
-        )
-        bigrams = self._get_frequently_words(
-            document,
-            n=2,
-            response_n=n_words,
-            language_code=language_code,
-            excluded_words=unigrams,
-        )
-        trigrams = self._get_frequently_words(
-            document,
-            n=3,
-            response_n=n_words,
-            language_code=language_code,
-            excluded_words=unigrams,
-        )
+
+        unigrams = kwargs.get("unigrams")
+        if unigrams is None:
+            unigrams = self._get_frequently_words(
+                n=1, response_n=n_words, language_code=language_code, **kwargs
+            )
+
+        bigrams = kwargs.get("bigrams")
+        if bigrams is None:
+            bigrams = self._get_frequently_words(
+                document,
+                n=2,
+                response_n=n_words,
+                language_code=language_code,
+                excluded_words=unigrams,
+            )
+        trigrams = kwargs.get("trigrams")
+        if trigrams is None:
+            trigrams = self._get_frequently_words(
+                document,
+                n=3,
+                response_n=n_words,
+                language_code=language_code,
+                excluded_words=unigrams,
+            )
+
         instruction_kwargs = dict(
             document=document,
             topic_values=", ".join(kwargs.get("categories", []) or []),
@@ -645,15 +674,14 @@ def get_user_kwargs(
             unigrams=unigrams,
             bigrams=bigrams,
             trigrams=trigrams,
-            n_words=n_words,
+            language_code=language_code,
             language=language,
             bullet_style=bullet_style,
             is_masked=bool(kwargs.get("max_hidden_words")),
         )
         if isinstance(instruction_template, Callable):
-            instruction_template_str = instruction_template(
-                self, **instruction_kwargs, **kwargs
-            )
+            kwargs.update(**instruction_kwargs)
+            instruction_template_str = instruction_template(self, **kwargs)
         else:
             instruction_template_str = self._formatting_instruction_fn(
                 instruction_template, **instruction_kwargs
@@ -661,9 +689,10 @@ def get_user_kwargs(
         if structure_template_str:
             if isinstance(structure_template, Callable):
-                structure_template_str = structure_template(
-                    self, self._get_structure_attrs(**kwargs), **kwargs
+                kwargs.setdefault(
+                    "structure_attrs", self._get_structure_attrs(**kwargs)
                 )
+                structure_template_str = structure_template(self, **kwargs)
             else:
                 structure_template_str = self._formatting_structure_user_fn(
                     structure_template,
@@ -712,7 +741,6 @@ def template(
         user_template: Optional[TemplateTypes] = USER_TEMPLATE,
         instruction_template: Optional[TemplateTypes] = None,
         structure_template: Optional[TemplateTypes] = None,
-        eos_token_str: Optional[str] = "",
         **kwargs,
     ):
         """
@@ -723,7 +751,6 @@ def template(
             user_template (Optional[Union[str, Callable]]): User Template for user prompt.
             instruction_template (Optional[Union[str, Callable]]): Instruction template for instruction prompt, if applicable.
             structure_template (Optional[Union[str, Callable]]): Structuring template for structuring prompt, if applicable.
-            eos_token_str (Optional[str]): Append eos token to the end of the model output.
             **kwargs: Additional parameters including:
                 - output: Optional[str] = Model response output.
                 - title: Optional[list[str]] = List of title to include in the prompt.
@@ -747,12 +774,9 @@ def template(
             >>> print(response)
         """  # noqa: E501
-        user_template = self.generate_user_prompt(
+        user_template, model_template, _ = self._get_template(
             user_template, instruction_template, structure_template, **kwargs
         )
-        model_template = self.generate_model_prompt(
-            structure_template, eos_token_str, **kwargs
-        )

         if isinstance(template, Callable):
             return template(user_template=user_template, model_template=model_template)
@@ -767,11 +791,12 @@ def generate_prompt(
         user_template: Optional[TemplateTypes] = USER_TEMPLATE,
         instruction_template: Optional[TemplateTypes] = None,
         structure_template: Optional[TemplateTypes] = None,
-        eos_token_str: Optional[str] = "",
         **kwargs,
     ):
         """Generates a prompt to predict."""
-        return self.template(template, user_template, instruction_template, structure_template, eos_token_str, **kwargs)
+        return self.template(
+            template, user_template, instruction_template, structure_template, **kwargs
+        )

     def generate_user_prompt(
         self,
@@ -803,24 +828,15 @@ def generate_user_prompt(
             >>> print(response)
         """  # noqa: E501
-        if instruction_template is not None:
-            user_kwargs = self.get_user_kwargs(
-                instruction_template, structure_template, **kwargs
-            )
-            return user_template.format(**user_kwargs)
-
-        return "\n\n".join(
-            [
-                p.strip()
-                for p in self._get_prompts(structure_template, **kwargs)
-                if p.strip()
-            ]
+        user_template, *_ = self._get_template(
+            user_template, instruction_template, structure_template, **kwargs
         )
+        return user_template

     def generate_model_prompt(
         self,
         structure_template: Optional[TemplateTypes] = None,
-        eos_token_str: Optional[str] = "",
+        excluded_fields: Optional[Sequence[str]] = (),
         bullet_style: Optional[Union[str, Literal["dash", "number"]]] = "dash",
         **kwargs,
     ) -> str:
         """
@@ -832,7 +848,7 @@ def generate_model_prompt(
         Args:
             structure_template (Optional[Union[str, Callable]]):
                 A structure template defining the generating structure prompt.
-            eos_token_str (Optional[str]): Append eos token to the end of the model output.
+            excluded_fields (Sequence[str]): Field names to exclude from the model response. Default is an empty sequence.
             bullet_style (Optional[str]): Bullet list style start dash or number. Default is dash.
             **kwargs: See also `Template.template`.
@@ -849,18 +865,25 @@ def generate_model_prompt(
         """  # noqa: E501
         output_document = kwargs.get("output", "")

+        if excluded_fields:
+            for excluded_field in excluded_fields:
+                if excluded_field in kwargs:
+                    kwargs.pop(excluded_field)
+
         if isinstance(structure_template, (str, Callable)):
             kwargs["document"] = output_document
             if isinstance(structure_template, Callable):
-                if isinstance(structure_template, Callable):
-                    self._structure_items = structure_template(structure_data, **kwargs)
+                kwargs.setdefault(
+                    "structure_attrs", self._get_structure_attrs(**kwargs)
+                )
+                output_document = structure_template(self, **kwargs)
             else:
                 output_document = self._formatting_structure_model_fn(
                     self._structure_items, bullet_style, **kwargs
                 )
-        return output_document.strip() + eos_token_str
+        return output_document.strip()

     def to_text(
         self,
@@ -868,30 +891,12 @@ def to_text(
         user_template: Optional[TemplateTypes] = USER_TEMPLATE,
         instruction_template: Optional[TemplateTypes] = INSTRUCTION_TEMPLATE,
         structure_template: Optional[TemplateTypes] = STRUCTURE_TEMPLATE,
-        eos_token_str: Optional[str] = "",
         **kwargs,
     ) -> dict:
         """Generate SFT Text Template format"""
-
-        user_kwargs = {}
-        if instruction_template is not None:
-            user_kwargs = self.get_user_kwargs(
-                instruction_template, structure_template, **kwargs
-            )
-            user_template = user_template.format(**user_kwargs)
-        else:
-            user_template = "\n\n".join(
-                [
-                    p.strip()
-                    for p in self._get_prompts(structure_template, **kwargs)
-                    if p.strip()
-                ]
-            )
-
-        model_template = self.generate_model_prompt(
-            structure_template, eos_token_str, **kwargs
+        user_template, model_template, user_kwargs = self._get_template(
+            user_template, instruction_template, structure_template, **kwargs
         )
-
         if isinstance(template, Callable):
             text = template(user_template=user_template, model_template=model_template)
         else:
@@ -907,6 +912,7 @@ def to_text(
             unigrams=user_kwargs.get("unigrams", []) or [],
             bigrams=user_kwargs.get("bigrams", []) or [],
             trigrams=user_kwargs.get("trigrams", []) or [],
+            language_code=user_kwargs.get("language_code", "auto"),
             language=user_kwargs.get("language"),
             is_masked=bool(user_kwargs.get("is_masked")),
         )
@@ -916,26 +922,22 @@ def to_alpaca(
         user_template: Optional[TemplateTypes] = USER_TEMPLATE,
         instruction_template: Optional[TemplateTypes] = INSTRUCTION_TEMPLATE,
         structure_template: Optional[TemplateTypes] = STRUCTURE_TEMPLATE,
-        eos_token_str: Optional[str] = "",
         **kwargs,
     ) -> dict:
         """Generate Alpaca Template format"""
-        user_kwargs = self.get_user_kwargs(
-            instruction_template, structure_template, **kwargs
-        )
-        instruction = user_kwargs["instruction_template"]
-        model_template = self.generate_model_prompt(
-            structure_template, eos_token_str, **kwargs
+        user_template, model_template, user_kwargs = self._get_template(
+            user_template, instruction_template, structure_template, **kwargs
         )
         return dict(
-            instruction=instruction,
-            input=instruction_kwargs.get("document", ""),
+            instruction=user_kwargs.get("instruction_template", "") or "",
+            input=user_kwargs.get("document", "") or "",
             output=model_template,
             is_instructed=bool(instruction_template is not None),
             is_structured=bool(structure_template is not None),
             unigrams=user_kwargs.get("unigrams", []) or [],
             bigrams=user_kwargs.get("bigrams", []) or [],
             trigrams=user_kwargs.get("trigrams", []) or [],
+            language_code=user_kwargs.get("language_code", "auto"),
             language=user_kwargs.get("language"),
             is_masked=bool(user_kwargs.get("is_masked")),
         )
@@ -945,30 +947,39 @@ def to_openai(
         user_template: Optional[TemplateTypes] = USER_TEMPLATE,
         instruction_template: Optional[TemplateTypes] = INSTRUCTION_TEMPLATE,
         structure_template: Optional[TemplateTypes] = STRUCTURE_TEMPLATE,
-        eos_token_str: Optional[str] = "",
         **kwargs,
     ) -> dict:
         """Generate Open AI Template format"""
-
-        user_kwargs = self.get_user_kwargs(
-            instruction_template, structure_template, **kwargs
-        )
-        human = user_template.format(**user_kwargs)
-        gpt = self.generate_model_prompt(
-            structure_template, eos_token_str, **kwargs
+        user_template, model_template, user_kwargs = self._get_template(
+            user_template, instruction_template, structure_template, **kwargs
         )
         return dict(
-            human=human,
-            gpt=gpt,
+            human=user_template,
+            gpt=model_template,
             is_instructed=bool(instruction_template is not None),
             is_structured=bool(structure_template is not None),
             unigrams=user_kwargs.get("unigrams", []) or [],
             bigrams=user_kwargs.get("bigrams", []) or [],
             trigrams=user_kwargs.get("trigrams", []) or [],
+            language_code=user_kwargs.get("language_code", "auto"),
             language=user_kwargs.get("language"),
             is_masked=bool(user_kwargs.get("is_masked")),
         )

+    def _get_template(
+        self,
+        user_template: Optional[TemplateTypes] = "",
+        instruction_template: Optional[TemplateTypes] = "",
+        structure_template: Optional[TemplateTypes] = "",
+        **kwargs,
+    ) -> tuple[str, str, dict]:
+        user_kwargs = self.get_user_kwargs(
+            instruction_template, structure_template, **kwargs
+        )
+        user_template = user_template.format(**user_kwargs)
+        model_template = self.generate_model_prompt(structure_template, **kwargs)
+        return user_template, model_template, user_kwargs
+
     def _get_prompts(
         self,
         structure_template: TemplateTypes = None,
@@ -1042,6 +1053,7 @@ def _formatting_instruction_fn(
         def _ftm_template(word):
             return f"{word} => {language}"

+        instruction_template = instruction_template or ""
         return instruction_template.format(
             document=document,
             topic_values=topic_values,
@@ -1118,7 +1130,7 @@ def _get_structure_attrs(self, **kwargs):

 gemma_template = Template()

-vietnamese_template = Template(
+vietnamese_gemma_template = Template(
     end_sep="và",
     system_prompts=[
         (
diff --git a/gemma_template/utils.py b/gemma_template/utils.py
index dd13486..052850a 100644
--- a/gemma_template/utils.py
+++ b/gemma_template/utils.py
@@ -12,7 +12,7 @@
 EMAIL_RE = re.compile(r"[\w\-.]+@([\w-]+\.)+[\w-]{2,4}")
 URL_RE = re.compile(r"\w+://([A-Za-z_0-9.-]+).*")
-MARKDOWN_RE = re.compile(r'(!|)\[[^]]*]\((.*?)\s*(\".*[^\"]\")?\s*\)')
+MARKDOWN_RE = re.compile(r"(!|)\[[^]]*]\((.*?)\s*(\".*[^\"]\")?\s*\)")
 INVALID_WORD_RE = re.compile(r"[\d\W\-_]")
@@ -165,7 +165,12 @@ def get_frequently_words(
     return outputs


-def mask_hidden(document: str, max_hidden_words: Union[int, float] = 0, language_code: str = None, **kwargs) -> str:
+def mask_hidden(
+    document: str,
+    max_hidden_words: Union[int, float] = 0,
+    language_code: str = None,
+    **kwargs,
+) -> str:
     """Replace words in the document with '____'.
     Args:
@@ -198,7 +203,9 @@ def mask_sentence(sentence: str, max_words: int) -> str:
             return sentence

         words = sentence.split()
-        valid_word_indices = [idx for idx, word in enumerate(words) if is_valid_word(word)]
+        valid_word_indices = [
+            idx for idx, word in enumerate(words) if is_valid_word(word)
+        ]

         hidden_count = min(len(valid_word_indices), max_words)
         if hidden_count == 0:
@@ -212,8 +219,14 @@ def mask_sentence(sentence: str, max_words: int) -> str:
     sentences = document.splitlines()
     word_count = len(document.split())
-    max_hidden_count = max_hidden_words if isinstance(max_hidden_words, int) else int(max_hidden_words * word_count)
+    max_hidden_count = (
+        max_hidden_words
+        if isinstance(max_hidden_words, int)
+        else int(max_hidden_words * word_count)
+    )

     avg_max_words_in_sentence = max(1, max_hidden_count // max(1, len(sentences)))
-    masked_sentences = [mask_sentence(sentence, avg_max_words_in_sentence) for sentence in sentences]
+    masked_sentences = [
+        mask_sentence(sentence, avg_max_words_in_sentence) for sentence in sentences
+    ]

     return "\n".join(masked_sentences)
diff --git a/pyproject.toml b/pyproject.toml
index d2839f6..4811f99 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,8 @@
 [build-system]
-requires = ['setuptools>=40.8.0']
-build-backend = 'setuptools.build_meta'
+requires = [
+    'setuptools>=40.8.0'
+]
+build-backend = "setuptools.build_meta"

 [tool.pytest.ini_options]
 asyncio_mode = "auto"
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 1eae8d1..feae626 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -20,3 +20,12 @@ black==24.3.0
 flake8==7.0.0
 pre-commit==3.7.0
 isort==5.13.2
+mypy==1.11.2
+coverage[toml]==7.6.1
+
+# Tests
+# ------------------------------------------------------------------------------
+pytest==8.3.3
+pytest-asyncio==0.24.0
+pytest-cov==6.0.0
+tox==4.23.2
diff --git a/requirements.txt b/requirements.txt
index ea85fc1..86536e7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-pydantic==2.10.4
-langdetect==1.0.9
-datasets==3.2.0
-tqdm==4.67.1
-nest-asyncio==1.6.0
+pydantic>=2.10.4
+langdetect>=1.0.9
+datasets>=3.2.0
+tqdm>=4.67.1
+nest-asyncio>=1.6.0
diff --git a/setup.cfg b/setup.cfg
index 25a821c..0df7b52 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -4,12 +4,6 @@ long_description_content_type = text/markdown
 license = Apache2
 license_file = LICENSE
 python_requires = >=3.9
-install_requires =
-    pydantic >= 2.10.4
-    langdetect>= 1.0.9
-    datasets >= 3.2.0
-    nest-asyncio >= 1.6.0
-    tqdm >= 4.67.1
 classifiers =
     Development Status :: 4 - Beta
     Intended Audience :: Developers
diff --git a/setup.py b/setup.py
index eb11b2a..7b16df0 100644
--- a/setup.py
+++ b/setup.py
@@ -51,11 +51,18 @@ def normalize(name) -> str:

 if __name__ == "__main__":
     setup(
-        name=version["__title__"],
+        name=normalize(version["__title__"]),
         version=version["__version__"],
         description=version["__description__"],
         long_description_content_type="text/markdown",
         author=version["__author__"],
         author_email=version["__author_email__"],
         url=version["__url__"],
+        install_requires=[
+            "pydantic",
+            "langdetect",
+            "datasets",
+            "tqdm",
+            "nest-asyncio",
+        ]
     )
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/conftest.py b/tests/conftest.py
new file mode 100644
index 0000000..5855f0a
--- /dev/null
+++ b/tests/conftest.py
@@ -0,0 +1,37 @@
+import pytest  # noqa
+
+
+@pytest.fixture
+def data_items():
+    return [
+        {
+            "id": "JnZJolR76_u2",
+            "title": "Gemma open models",
"description": "Gemma: Introducing new state-of-the-art open models", + "document": "Gemma open models are built from the same research and technology as Gemini models. Gemma 2 comes in 2B, 9B and 27B and Gemma 1 comes in 2B and 7B sizes.", + "categories": ["Topic 1", "Topic 2"], + "tags": ["Tag 1", "Tag 2"], + "output": "Sample output", + "main_points": ["Main point 1", "Main point 2"], + }, + { + "id": "JnZJolR76_u2", + "title": "Gemma open models", + "description": "Gemma: Introducing new state-of-the-art open models", + "document": "Gemma open models are built from the same research and technology as Gemini models. Gemma 2 comes in 2B, 9B and 27B and Gemma 1 comes in 2B and 7B sizes.", + "categories": ["Topic 1", "Topic 2"], + "tags": ["Tag 1", "Tag 2"], + "output": "Sample output", + "main_points": ["Main point 1", "Main point 2"], + }, + ] + + +@pytest.fixture +def config(): + return dict( + max_hidden_ratio=.5, + max_hidden_words=.1, + min_chars_length=2, + max_chars_length=8, + ) diff --git a/tests/test_instruction_template.py b/tests/test_instruction_template.py new file mode 100644 index 0000000..1765ff2 --- /dev/null +++ b/tests/test_instruction_template.py @@ -0,0 +1,52 @@ +from gemma_template import gemma_template + +INSTRUCTION_TEMPLATE = """# Role: +You are a highly skilled professional content writer, linguistic analyst, and multilingual expert specializing in structured writing and advanced text processing. + +# Task: +Your primary objectives are: +1. Your primary task is to rewrite the provided content into a more structured, professional format that maintains its original intent and meaning. +2. Enhance vocabulary comprehension by analyzing text with unigrams (single words), bigrams (two words), and trigrams (three words). +3. Ensure your response adheres strictly to the prescribed structure format. +4. Respond in the primary language of the input text unless alternative instructions are explicitly given. + +# Additional Expectations: +1. Provide a rewritten, enhanced version of the input text, ensuring professionalism, clarity, and improved structure. +2. Focus on multilingual proficiency, using complex vocabulary, grammar to improve your responses. +3. Preserve the context and cultural nuances of the original text when rewriting. + +Topics: {topic_values} +Keywords: {keyword_values} + +# Text Analysis: +Example 1: Unigrams (single words) +{unigrams} +Text Analysis 3: These are common {language} words, indicating the text is in {language}. + +Example 2: Bigrams (two words) +{bigrams} +Text Analysis 2: Frequent bigrams in Vietnamese confirm the language context. + +Example 3: Trigrams (three words) +{trigrams} +Text Analysis 3: Trigrams further validate the linguistic analysis and the necessity to respond in {language}. + +# Conclusion of Text Analysis: +The linguistic analysis confirms the text is predominantly in {language}. Consequently, the response should be structured and written in {language} to align with the original text and context. +""" # noqa: E501 + + +def test_instruction_template(data_items, config): + template = gemma_template.template(instruction_template=INSTRUCTION_TEMPLATE, **data_items[0], **config) + assert "You are a highly skilled professional content writer, linguistic analyst, and multilingual expert specializing in structured writing and advanced text processing." 
+
+
+def test_instruction_template_function(data_items, config):
+    def instruction_fn(
+        fn,
+        **instruction_kwargs,
+    ):
+        return "### INSTRUCTION TEST"
+
+    template_fn = gemma_template.template(instruction_template=instruction_fn, **data_items[0], **config)
+    assert "### INSTRUCTION TEST" in template_fn
diff --git a/tests/test_load_dataset.py b/tests/test_load_dataset.py
new file mode 100644
index 0000000..eea55e0
--- /dev/null
+++ b/tests/test_load_dataset.py
@@ -0,0 +1,48 @@
+from datasets import Dataset
+
+from gemma_template import gemma_template
+
+
+def assert_dataset_equal(ds, input_field: str = 'text', output_field: str = 'text'):
+    a, b = ds[0], ds[1]
+    assert input_field in a
+    assert output_field in a
+    assert input_field in b
+    assert output_field in b
+    assert a["is_masked"] is True
+    assert b["is_masked"] is False
+    assert "_____" in a[input_field]
+    assert len(ds) == 2
+
+
+def test_load_dataset_from_dict(data_items, config):
+    text_ds = gemma_template.load_dataset(data_items, output_format='text', **config)
+    assert_dataset_equal(text_ds, "text", "text")
+    alpaca_ds = gemma_template.load_dataset(data_items, output_format='alpaca', **config)
+    assert_dataset_equal(alpaca_ds, "input", "output")
+    gpt_ds = gemma_template.load_dataset(data_items, output_format='gpt', **config)
+    assert_dataset_equal(gpt_ds, "human", "gpt")
+
+
+def test_load_dataset_from_Dataset(data_items, config):
+    dataset = Dataset.from_list(data_items)
+    text_ds = gemma_template.load_dataset(dataset, output_format='text', **config)
+    assert_dataset_equal(text_ds, "text", "text")
+    alpaca_ds = gemma_template.load_dataset(dataset, output_format='alpaca', **config)
+    assert_dataset_equal(alpaca_ds, "input", "output")
+    gpt_ds = gemma_template.load_dataset(dataset, output_format='gpt', **config)
+    assert_dataset_equal(gpt_ds, "human", "gpt")
+
+
+def test_load_dataset_from_DatasetDict(data_items, config):
+    dataset = Dataset.from_list(data_items * 2)
+    dataset = dataset.train_test_split(test_size=0.5)
+    text_ds = gemma_template.load_dataset(dataset, output_format='text', **config)
+    assert_dataset_equal(text_ds["train"], "text", "text")
+    assert_dataset_equal(text_ds["test"], "text", "text")
+    alpaca_ds = gemma_template.load_dataset(dataset, output_format='alpaca', **config)
+    assert_dataset_equal(alpaca_ds["train"], "input", "output")
+    assert_dataset_equal(alpaca_ds["test"], "input", "output")
+    gpt_ds = gemma_template.load_dataset(dataset, output_format='gpt', **config)
+    assert_dataset_equal(gpt_ds["train"], "human", "gpt")
+    assert_dataset_equal(gpt_ds["test"], "human", "gpt")
diff --git a/tests/test_structure_template.py b/tests/test_structure_template.py
new file mode 100644
index 0000000..7e4ef8d
--- /dev/null
+++ b/tests/test_structure_template.py
@@ -0,0 +1,54 @@
+from gemma_template import StructureField, Template, gemma_template
+
+STRUCTURE_TEMPLATE = """# Response Structure Format:
+You must follow the response structure:
+{structure_template}
+
+By adhering to this format, the response will maintain linguistic integrity while enhancing professionalism, structure and alignment with user expectations.
+""" # noqa: E501 + + +def test_structure_template(data_items, config): + template = gemma_template.template(structure_template=STRUCTURE_TEMPLATE, **data_items[0], **config) + assert "# Response Structure Format" in template + + +def test_structure_template_function(data_items, config): + def structure_fn( + fn, + **instruction_kwargs, + ): + return "### STRUCTURE TEST" + + template_fn = gemma_template.template(structure_template=structure_fn, **data_items[0], **config) + assert "### STRUCTURE TEST" in template_fn + + +def test_fully_custom_structure_template(data_items, config): + def instruction_fn( + fn, + **instruction_kwargs, + ): + return "### INSTRUCTION TEST" + + prompt_instance = Template( + structure_field=StructureField( + title=["Custom Title"], + description=["Custom Description"], + document=["Custom Article"], + main_points=["Custom Main Points"], + categories=["Custom Categories"], + tags=["Custom Tags"], + ), + ) + + response = prompt_instance.template( + instruction_template=instruction_fn, + structure_template=STRUCTURE_TEMPLATE, + **data_items[0], + **config + ) + + assert "### INSTRUCTION TEST" in response + assert "Custom Title" in response + assert "Custom Description" in response