diff --git a/src/sparseml/pytorch/utils/logger.py b/src/sparseml/pytorch/utils/logger.py index 9d7cc4e3c2d..66ea8e920ac 100644 --- a/src/sparseml/pytorch/utils/logger.py +++ b/src/sparseml/pytorch/utils/logger.py @@ -534,7 +534,14 @@ def __init__( init_kwargs: Optional[Dict] = None, name: str = "wandb", enabled: bool = True, + wandb_err: Optional[Exception] = wandb_err, ): + if wandb_err: + raise ModuleNotFoundError( + "Error: Failed to import wandb. " + "Please install the wandb library in order to use it." + ) from wandb_err + super().__init__( lambda_func=self._log_lambda, name=name,