Add Ascend NPU as a backend for single device recipes
fix lint
xinyanhe authored and Nicorgi committed Jan 14, 2025
1 parent 213f386 commit 6c1951c
Showing 9 changed files with 35 additions and 22 deletions.
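The changes follow a single pattern: direct torch.cuda calls are routed through torchtune's device-namespace helper, and CUDA-only guards become "not CPU" guards, so the single-device recipes below can run on Ascend NPUs as well as CUDA GPUs. Assuming device selection in torchtune already recognizes npu, picking the new backend should then only need a device override (the config name here is a placeholder):

tune run full_finetune_single_device --config <your_single_device_config> device=npu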
8 changes: 5 additions & 3 deletions recipes/dev/generate_v2.py
@@ -111,9 +111,11 @@ def log_metrics(self, total_time: int, tokens_per_second: float) -> None:
         self._logger.info(
             f"Bandwidth achieved: {model_size * tokens_per_second / 1e9:.02f} GB/s"
         )
-        self._logger.info(
-            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
-        )
+        if self._device.type != "cpu":
+            torch_device = utils.get_torch_device_namespace()
+            self._logger.info(
+                f"Max memory allocated: {torch_device.max_memory_allocated() / 1e9:.02f} GB"
+            )
 
     @torch.inference_mode()
     def generate(self, cfg: DictConfig):
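The memory-reporting change above (and the matching ones in the recipes below) goes through utils.get_torch_device_namespace() instead of hard-coding torch.cuda. A minimal sketch of what such a helper resolves, written as an assumption rather than the actual torchtune implementation:

import torch

def get_torch_device_namespace(device_type: str = "cuda"):
    # Return the torch submodule (torch.cuda, torch.npu, ...) that exposes the
    # per-device memory API: max_memory_allocated(), reset_peak_memory_stats(), etc.
    # torch.npu only exists once the Ascend torch_npu extension has been imported.
    return getattr(torch, device_type, torch.cuda)

Because torch.npu (provided by torch_npu) mirrors the torch.cuda memory interface, torch_device.max_memory_allocated() / 1e9 reports peak memory in GB on either backend, which is why the recipes only skip the log line when the device is "cpu".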
8 changes: 5 additions & 3 deletions recipes/eleuther_eval.py
@@ -547,9 +547,11 @@ def evaluate(self) -> None:
 
         # Log metrics
         self.logger.info(f"Eval completed in {t1:.02f} seconds.")
-        self.logger.info(
-            f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB"
-        )
+        if self.device.type != "cpu":
+            torch_device = utils.get_torch_device_namespace()
+            self.logger.info(
+                f"Max memory allocated: {torch_device.max_memory_allocated() / 1e9:.02f} GB"
+            )
         formatted_output = make_table(output)
         self.logger.info(f"\n\n{formatted_output}\n")
 
4 changes: 2 additions & 2 deletions recipes/full_finetune_single_device.py
@@ -131,9 +131,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
-        if self._log_peak_memory_stats and self._device.type != "cuda":
+        if self._log_peak_memory_stats and self._device.type == "cpu":
             log.info(
-                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+                "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False."
             )
             self._log_peak_memory_stats = False
 
4 changes: 3 additions & 1 deletion recipes/generate.py
@@ -187,7 +187,9 @@ def generate(self, cfg: DictConfig):
             f"Time for inference: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec"
         )
         logger.info(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+        if self._device.type != "cpu":
+            torch_device = utils.get_torch_device_namespace()
+            logger.info(f"Memory used: {torch_device.max_memory_allocated() / 1e9:.02f} GB")
 
 
 @config.parse
14 changes: 9 additions & 5 deletions recipes/knowledge_distillation_single_device.py
@@ -120,9 +120,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
-        if self._log_peak_memory_stats and self._device.type != "cuda":
+        if self._log_peak_memory_stats and self._device.type == "cpu":
             log.info(
-                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+                "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False."
             )
             self._log_peak_memory_stats = False
 
@@ -223,6 +223,10 @@ def setup(self, cfg: DictConfig) -> None:
         self._metric_logger.log_config(cfg)
 
         self._compile = cfg.compile
+        if cfg.device == "npu" and cfg.compile:
+            raise ValueError(
+                "NPU does not support model compilation. Please set `compile: False` in the config."
+            )
         checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
         teacher_checkpoint_dict = self.load_teacher_checkpoint(
             cfg_checkpointer=cfg.teacher_checkpointer
@@ -447,7 +451,7 @@ def _setup_model(
 
         log.info(f"Student model is initialized with precision {self._dtype}.")
 
-        if self._device.type == "cuda":
+        if self._device.type != "cpu":
             log.info("Memory stats initializing student model:")
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
@@ -476,7 +480,7 @@ def _setup_teacher_model(
         )
         log.info(f"Teacher model is initialized with precision {self._dtype}.")
 
-        if self._device.type == "cuda":
+        if self._device.type != "cpu":
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(
                 memory_stats, message="Memory stats after teacher model init:"
@@ -753,7 +757,7 @@ def train(self) -> None:
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
                     if (
-                        self._device.type == "cuda"
+                        self._device.type != "cpu"
                         and self._log_peak_memory_stats
                     ):
                         log_dict.update(
6 changes: 3 additions & 3 deletions recipes/lora_dpo_single_device.py
@@ -98,9 +98,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
-        if self._log_peak_memory_stats and self._device.type != "cuda":
+        if self._log_peak_memory_stats and self._device.type == "cpu":
             log.info(
-                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+                "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False."
            )
             self._log_peak_memory_stats = False
 
@@ -327,7 +327,7 @@ def _setup_model(
         # Compile model, if enabled.
         if compile_model:
             training.compile_model(model)
-        if self._device == torch.device("cuda"):
+        if self._device.type != "cpu":
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)
         return model
6 changes: 3 additions & 3 deletions recipes/lora_finetune_single_device.py
@@ -136,9 +136,9 @@ def __init__(self, cfg: DictConfig) -> None:
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
         self._log_peak_memory_stats = cfg.get("log_peak_memory_stats", False)
 
-        if self._log_peak_memory_stats and self._device.type != "cuda":
+        if self._log_peak_memory_stats and self._device.type == "cpu":
             log.info(
-                "log_peak_memory_stats was set to True, however, training does not use cuda. Setting log_peak_memory_stats=False."
+                "log_peak_memory_stats was set to True, however, training uses cpu. Setting log_peak_memory_stats=False."
             )
             self._log_peak_memory_stats = False
 
@@ -735,7 +735,7 @@ def train(self) -> None:
                         "tokens_per_second_per_gpu": num_tokens / time_per_step,
                     }
                     if (
-                        self._device.type == "cuda"
+                        self._device.type != "cpu"
                         and self._log_peak_memory_stats
                     ):
                         log_dict.update(
4 changes: 3 additions & 1 deletion recipes/quantize.py
@@ -92,7 +92,9 @@ def quantize(self, cfg: DictConfig):
         self._model = self._quantizer.quantize(self._model)
         t = time.perf_counter() - t0
         logger.info(f"Time for quantization: {t:.02f} sec")
-        logger.info(f"Memory used: {torch.cuda.max_memory_allocated() / 1e9:.02f} GB")
+        if self._device.type != "cpu":
+            torch_device = utils.get_torch_device_namespace()
+            logger.info(f"Memory used: {torch_device.max_memory_allocated() / 1e9:.02f} GB")
 
     def save_checkpoint(self, cfg: DictConfig):
         ckpt_dict = self._model.state_dict()
3 changes: 2 additions & 1 deletion torchtune/training/precision.py
@@ -33,7 +33,8 @@ def _set_float32_precision(precision: str = "high") -> None:
     Args:
         precision (str): The setting to determine which datatypes to use for matrix multiplication and convolution operations.
     """
-    if not torch.cuda.is_available():  # Not relevant for non-CUDA devices
+    # Not relevant for non-CUDA or non-NPU devices
+    if not torch.cuda.is_available() or not is_npu_available:
         return
     # set precision for matrix multiplications
     torch.set_float32_matmul_precision(precision)
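The new guard references is_npu_available, which the rest of torchtune presumably defines when it probes for the Ascend backend. A plausible definition, shown purely as an assumption about how that probe might work, not as the library's actual code:

import torch

try:
    # Importing torch_npu registers the torch.npu backend for Ascend devices.
    import torch_npu  # noqa: F401

    is_npu_available = torch.npu.is_available()
except ImportError:
    is_npu_available = False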
