diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index 69ca3b8..340a598 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -71,10 +71,10 @@ def skip(*args, **kwargs): torch.cuda.empty_cache() # Important defaults - if not hasattr(model_init_kwargs, "torch_dtype"): + if "torch_dtype" not in model_init_kwargs: model_init_kwargs["torch_dtype"] = "auto" - if not hasattr(model_init_kwargs, "device_map"): + if "device_map" not in model_init_kwargs: model_init_kwargs["device_map"] = "auto" merged_kwargs = {**model_init_kwargs, **cached_file_kwargs}