From 3798229d85e251eaf8611cd9ed1bf63101c1e58b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Dec 2024 16:42:41 -0500 Subject: [PATCH] handle torch_compile set to auto (#2172) [skip ci] * handle torch_compile set to auto * update docs [skip ci] * add tests --- docs/config.qmd | 3 +- src/axolotl/utils/config/__init__.py | 4 +- .../config/models/input/v0_4_1/__init__.py | 21 +++++++++- tests/patched/test_validation.py | 40 +++++++++++++++++++ 4 files changed, 64 insertions(+), 4 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index ba23384f0c..d52170959d 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -337,7 +337,8 @@ comet_experiment_config: # Dictionary for additional configuration settings, see output_dir: ./completed-model # Whether to use torch.compile and which backend to use -torch_compile: # bool +# setting to `auto` will enable torch compile when torch>=2.5.1 +torch_compile: # Optional[Union[Literal["auto"], bool]] torch_compile_backend: # Optional[str] # Training hyperparameters diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 30ba53ad25..c23359f34d 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -245,8 +245,8 @@ def validate_config( ) = merge_input_args() if capabilities or env_capabilities: - if (capabilities and not env_capabilities) or ( - env_capabilities and not capabilities + if (capabilities and env_capabilities is None) or ( + env_capabilities and capabilities is None ): raise ValueError( "Both capabilities and env_capabilities must be provided or not provided." diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index d05de2330d..69baf9af2b 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -741,7 +741,7 @@ class Config: special_tokens: Optional[SpecialTokensConfig] = None tokens: Optional[List[str]] = None - torch_compile: Optional[bool] = None + torch_compile: Optional[Union[Literal["auto"], bool]] = None torch_compile_backend: Optional[str] = None torch_compile_mode: Optional[ Literal["default", "reduce-overhead", "max-autotune"] @@ -1582,3 +1582,22 @@ def check_adopt_torch_version(cls, data): "ADOPT optimizer is incompatible with torch version < 2.5.1" ) return data + + @model_validator(mode="before") + @classmethod + def check_torch_compile_auto(cls, data): + if data.get("torch_compile") == "auto": + env_capabilities = data.get("env_capabilities", {}) + if env_capabilities.get("torch_version"): + if version.parse( + env_capabilities.get("torch_version") + ) >= version.parse("2.5.1"): + LOG.info( + "torch.compile is available, setting torch_compile to True" + ) + data["torch_compile"] = True + else: + data["torch_compile"] = False + else: + data["torch_compile"] = False + return data diff --git a/tests/patched/test_validation.py b/tests/patched/test_validation.py index 2e6fbab101..3d1b74789d 100644 --- a/tests/patched/test_validation.py +++ b/tests/patched/test_validation.py @@ -1196,6 +1196,46 @@ def test_torch_version_adopt_req(self, minimal_cfg): ) +class TestTorchCompileValidation(BaseValidation): + """ + test suite for when torch_compile is set to 'auto' + """ + + def test_torch_compile_auto(self, minimal_cfg): + cfg = ( + DictDefault( + { + "torch_compile": "auto", + } + ) + | minimal_cfg + ) + + env_capabilities = {"torch_version": "2.5.1"} + capabilities = {"bf16": True} + updated_cfg = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + + assert updated_cfg.torch_compile is True + + env_capabilities = {"torch_version": "2.4.1"} + capabilities = {"bf16": True} + updated_cfg = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + + assert updated_cfg.torch_compile is False + + env_capabilities = {} + capabilities = {"bf16": True} + updated_cfg = validate_config( + cfg, capabilities=capabilities, env_capabilities=env_capabilities + ) + + assert updated_cfg.torch_compile is False + + class TestValidationCheckModelConfig(BaseValidation): """ Test the validation for the config when the model config is available