diff --git a/.github/workflows/cpu-tests.yml b/.github/workflows/cpu-tests.yml
index 2f23c0f8d6..59a3ce3075 100644
--- a/.github/workflows/cpu-tests.yml
+++ b/.github/workflows/cpu-tests.yml
@@ -16,7 +16,6 @@ defaults:
 
 env:
   HF_TOKEN: ${{ secrets.HF_TOKEN }}
-  UV_HTTP_TIMEOUT: 500
 
 jobs:
   cpu-tests:
diff --git a/README.md b/README.md
index 6f1e740375..be4cc38696 100644
--- a/README.md
+++ b/README.md
@@ -170,6 +170,7 @@ Use, Finetune, pretrain, deploy over 20+ LLMs ([full list](tutorials/download_mo
 
 | Model | Model size | Author | Reference |
 |----|----|----|----|
+| CodeGemma | 7B | Google | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
 | Code Llama | 7B, 13B, 34B, 70B | Meta AI | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
 | Dolly | 3B, 7B, 12B | Databricks | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
 | Falcon | 7B, 40B, 180B | TII UAE | [TII 2023](https://falconllm.tii.ae) |
diff --git a/extensions/thunder/pretrain.py b/extensions/thunder/pretrain.py
index fe9c798dc0..24f140c9df 100644
--- a/extensions/thunder/pretrain.py
+++ b/extensions/thunder/pretrain.py
@@ -26,6 +26,7 @@
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    capture_hparams,
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
@@ -97,7 +98,7 @@ def setup(
         executors: If using Thunder, the executors to enable.
         strategy: If desired, the strategy to use.
     """
-    hparams = locals()
+    hparams = capture_hparams()
     data = TinyLlama() if data is None else data
     if model_config is not None and model_name is not None:
         raise ValueError("Only one of `model_name` or `model_config` can be set.")
diff --git a/litgpt/config.py b/litgpt/config.py
index caad1454b9..0a4234222d 100644
--- a/litgpt/config.py
+++ b/litgpt/config.py
@@ -888,6 +888,32 @@ def norm_class(self) -> Type:
     copy["hf_config"]["name"] = f"{c['hf_config']['name']}-it"
     configs.append(copy)
 
+##################
+# Google CodeGemma
+##################
+codegemma = [
+    # https://huggingface.co/google/codegemma-7b-it/blob/main/config.json
+    dict(
+        name="CodeGemma-7b-it",
+        hf_config=dict(org="google", name="codegemma-7b-it"),
+        scale_embeddings=True,
+        vocab_size=256000,
+        padding_multiple=64,
+        n_embd=3072,
+        n_layer=28,
+        n_head=16,
+        head_size=256,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="GemmaMLP",
+        gelu_approximate="tanh",
+        intermediate_size=24576,
+    ),
+]
+configs.extend(codegemma)
+
 
 ##########################
 # Stability AI FreeWilly2
diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index 4ab31b414e..f75a93d8c6 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -26,6 +26,7 @@
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    capture_hparams,
     choose_logger,
     chunked_cross_entropy,
     copy_config_files,
@@ -87,7 +88,7 @@ def setup(
         logger_name: The name of the logger to send metrics to.
         seed: The random seed to use for reproducibility.
     """
-    hparams = locals()
+    hparams = capture_hparams()
     data = TinyLlama() if data is None else data
     if model_config is not None and model_name is not None:
         raise ValueError("Only one of `model_name` or `model_config` can be set.")
diff --git a/litgpt/prompts.py b/litgpt/prompts.py
index d1266c731b..df1a7150b6 100644
--- a/litgpt/prompts.py
+++ b/litgpt/prompts.py
@@ -251,7 +251,7 @@ def stop_tokens(self, tokenizer: "Tokenizer") -> Tuple[List[int], ...]:
 
 class Phi2(PromptStyle):
     def apply(self, prompt: str, **kwargs: str) -> str:
-        return f"Instruct:{prompt}\nOutput:"
+        return f"Instruct: {prompt}\nOutput:"
 
 
 class TinyLlama(PromptStyle):
@@ -330,7 +330,7 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Phi2()
     if re.search(r"tiny-llama.*chat", model_name):
         return TinyLlama()
-    if re.search(r"Gemma.*-it", model_name):
+    if re.search(r"(Code)?Gemma.*-it", model_name):
         return Gemma()
     return Default()
 
diff --git a/litgpt/utils.py b/litgpt/utils.py
index fb6a86c107..37ebdfd6f9 100644
--- a/litgpt/utils.py
+++ b/litgpt/utils.py
@@ -1,11 +1,12 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
 """Utility functions for training and inference."""
+import inspect
 import math
 import pickle
 import shutil
 import sys
-from dataclasses import asdict
+from dataclasses import asdict, is_dataclass
 from io import BytesIO
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Literal, Mapping, Optional, TypeVar, Union
@@ -404,6 +405,21 @@ def CLI(*args: Any, **kwargs: Any) -> Any:
     return CLI(*args, **kwargs)
 
 
+def capture_hparams() -> Dict[str, Any]:
+    """Captures the local variables ('hyperparameters') from where this function gets called."""
+    caller_frame = inspect.currentframe().f_back
+    locals_of_caller = caller_frame.f_locals
+    hparams = {}
+    for name, value in locals_of_caller.items():
+        if value is None or isinstance(value, (int, float, str, bool, Path)):
+            hparams[name] = value
+        elif is_dataclass(value):
+            hparams[name] = asdict(value)
+        else:
+            hparams[name] = str(value)
+    return hparams
+
+
 def save_hyperparameters(function: callable, checkpoint_dir: Path) -> None:
     """Captures the CLI parameters passed to `function` without running `function` and saves them to the checkpoint."""
     from jsonargparse import capture_parser
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 63caf4158a..d76ae98056 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,4 +1,5 @@
 # Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
 
+from dataclasses import asdict
 import os
 from contextlib import redirect_stderr
@@ -18,9 +19,11 @@
 from lightning_utilities.core.imports import RequirementCache
 
 from litgpt import GPT
+from litgpt.args import TrainArgs
 from litgpt.utils import (
     CLI,
     CycleIterator,
+    capture_hparams,
     check_valid_checkpoint_dir,
     choose_logger,
     chunked_cross_entropy,
@@ -219,6 +222,26 @@ def test_copy_config_files(fake_checkpoint_dir, tmp_path):
     assert expected.issubset(contents)
 
 
+def test_capture_hparams():
+    integer = 1
+    string = "string"
+    boolean = True
+    none = None
+    path = Path("/path")
+    dataclass = TrainArgs()
+    other = torch.nn.Linear(1, 1)
+    hparams = capture_hparams()
+    assert hparams == {
+        "integer": integer,
+        "string": string,
+        "boolean": boolean,
+        "none": none,
+        "path": path,
+        "dataclass": asdict(dataclass),
+        "other": str(other),
+    }
+
+
 def _test_function(out_dir: Path, foo: bool = False, bar: int = 1):
     save_hyperparameters(_test_function, out_dir)
 
diff --git a/tutorials/download_model_weights.md b/tutorials/download_model_weights.md
index 55c214a01c..b91afa5929 100644
--- a/tutorials/download_model_weights.md
+++ b/tutorials/download_model_weights.md
@@ -5,6 +5,7 @@ LitGPT supports a variety of LLM architectures with publicly available weights.
 
 | Model | Model size | Reference |
 |----------------------------------------------|------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|
+| CodeGemma by Google | 7B | [Google Team, Google Deepmind](https://ai.google.dev/gemma/docs/codegemma) |
 | Code Llama by Meta AI | 7B, 13B, 34B, 70B | [Rozière et al. 2023](https://arxiv.org/abs/2308.12950) |
 | Dolly by Databricks | 3B, 7B, 12B | [Conover et al. 2023](https://www.databricks.com/blog/2023/04/12/dolly-first-open-commercially-viable-instruction-tuned-llm) |
 | Falcon by TII UAE | 7B, 40B, 180B | [TII 2023](https://falconllm.tii.ae) |
@@ -84,6 +85,7 @@ garage-bAInd/Platypus2-70B
 garage-bAInd/Platypus2-70B-instruct
 garage-bAInd/Platypus2-7B
 garage-bAInd/Stable-Platypus2-13B
+google/codegemma-7b-it
 google/gemma-2b
 google/gemma-2b-it
 google/gemma-7b
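
Note on the `capture_hparams` change above: `hparams = locals()` stored every local variable as-is, including objects (logger instances, config dataclasses) that experiment loggers may not be able to serialize, whereas the new helper normalizes each value first. Below is a minimal, self-contained sketch of that behavior; only `capture_hparams` itself mirrors the patched code, while the `setup()` signature and the `TrainArgs` stand-in are hypothetical examples, not part of the patch.

import inspect
from dataclasses import asdict, dataclass, is_dataclass
from pathlib import Path
from typing import Any, Dict


def capture_hparams() -> Dict[str, Any]:
    """Captures the local variables ('hyperparameters') of the calling function."""
    caller_frame = inspect.currentframe().f_back
    locals_of_caller = caller_frame.f_locals
    hparams = {}
    for name, value in locals_of_caller.items():
        if value is None or isinstance(value, (int, float, str, bool, Path)):
            hparams[name] = value  # loggable primitives are kept as-is
        elif is_dataclass(value):
            hparams[name] = asdict(value)  # dataclasses become plain dicts
        else:
            hparams[name] = str(value)  # anything else is stringified
    return hparams


@dataclass
class TrainArgs:  # hypothetical stand-in for litgpt.args.TrainArgs
    max_steps: int = 1000


def setup(out_dir: Path, learning_rate: float, train: TrainArgs) -> None:
    # Called first, so only the function arguments are in scope and captured.
    hparams = capture_hparams()
    print(hparams)
    # e.g. {'out_dir': PosixPath('out'), 'learning_rate': 0.0003, 'train': {'max_steps': 1000}}


if __name__ == "__main__":
    setup(Path("out"), 3e-4, TrainArgs())

Because the helper reads `f_locals` of its direct caller, it has to be invoked at the top of the entry point, before further locals (such as `hparams` itself) are defined, which matches where both `pretrain.py` setup functions call it in the diff.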