diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 4e9839b9..7268ca27 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -28,6 +28,7 @@
     apply_quantization_status,
 )
 from compressed_tensors.quantization.utils import iter_named_leaf_modules
+from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
 
 
@@ -224,6 +225,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
     return QuantizationConfig.parse_obj(config_dict)
 
 
+@requires_accelerate()
 @pytest.mark.parametrize(
     "ignore,should_raise_warning",
     [
diff --git a/tests/testing_utils.py b/tests/testing_utils.py
index 2e9be7cf..e446cad3 100644
--- a/tests/testing_utils.py
+++ b/tests/testing_utils.py
@@ -26,8 +26,38 @@ def compressed_tensors_config_available():
         return False
 
 
+def accelerate_available():
+    """Return True if the optional `accelerate` package can be imported."""
+    try:
+        import accelerate  # noqa: F401
+
+        return True
+
+    except ImportError:
+        return False
+
+
+# Backward-compatible alias for the original, misspelled helper name.
+accelerate_availabe = accelerate_available
+
+
+# Probe optional dependencies once at import time; the marker factories below
+# reuse these cached results instead of re-running the import checks.
+_is_compressed_tensors_config_available = compressed_tensors_config_available()
+_is_accelerate_available = accelerate_available()
+
+
 def requires_hf_quantizer():
+    """Skip marker for tests that need CompressedTensorsHfQuantizer support."""
     return pytest.mark.skipif(
-        not compressed_tensors_config_available(),
+        not _is_compressed_tensors_config_available,
         reason="requires transformers>=4.45 to support CompressedTensorsHfQuantizer",
     )
+
+
+def requires_accelerate():
+    """Skip marker for tests that need the optional `accelerate` package."""
+    return pytest.mark.skipif(
+        not _is_accelerate_available,
+        reason="requires accelerate",
+    )