Commit
lots of small improvements: add pyo3 log, add dynamic gradient clipper, removed old stuff
bastiscode committed Dec 9, 2024
1 parent b6e14f6 commit 4686635
Showing 14 changed files with 232 additions and 221 deletions.
Cargo.toml (4 changes: 2 additions & 2 deletions)
@@ -10,10 +10,10 @@ crate-type = ["cdylib", "rlib"]

[dependencies]
log = "0.4"
env_logger = "0.11"
rayon = "1.8"
indicatif = { version = "0.17", features = ["rayon"] }
pyo3 = { version = "0.22", features = ["anyhow", "extension-module", "abi3-py310", "auto-initialize"]}
pyo3-log = "0.11"
itertools = "0.13"
rand = "0.8"
rand_distr = "0.4"
@@ -29,7 +29,7 @@ anyhow = "1.0"
num_cpus = "1.14"
numpy = "0.22"
clap = { version = "4", features = ["derive"] }
tokenizers = "0.20"
tokenizers = "0.21"
lru = "0.12"

[dev-dependencies]
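
The new pyo3-log dependency forwards records emitted by Rust's `log` macros to Python's standard `logging` module; the Rust side typically calls `pyo3_log::init()` once during extension initialization, which is not part of the hunk shown here. A minimal sketch of how the Python side could then surface those messages; whether importing the package is what loads the Rust extension is an assumption:

```python
import logging

# Configure Python logging first; once the Rust side has registered pyo3-log,
# Rust log::info!/log::warn! records arrive here like ordinary Python log
# records, with the Rust module path as the logger name.
logging.basicConfig(
    level=logging.INFO,
    format="[%(asctime)s] [%(levelname)s] [%(name)s] %(message)s",
)

import text_utils  # assumed: importing the package loads the Rust extension
```
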
pyproject.toml (1 change: 0 additions & 1 deletion)
@@ -34,7 +34,6 @@ dependencies = [
]

[project.scripts]
"tu.create_continuation_index" = "text_utils.cli.create_continuation_index:main"
"tu.create_dictionary" = "text_utils.cli.create_dictionary:main"
"tu.train_bpe" = "text_utils.cli.train_bpe:main"

python/text_utils/api/processor.py (22 changes: 15 additions & 7 deletions)
@@ -1,20 +1,20 @@
import collections
import math
import sys
import os
import pprint
from typing import Iterator, Any, Callable
import sys
from typing import Any, Callable, Iterator

import torch
from torch import nn
from torch.backends import cudnn, cuda
from torch.backends import cuda, cudnn

from text_utils import (
api,
logging,
configuration,
io,
data,
io,
logging,
)
from text_utils.api.utils import Device, get_devices

@@ -218,11 +218,14 @@ def _process(
for item, output in zip(batch.items(), outputs):
if item.item_idx not in results:
results[item.item_idx] = {}
if progress_unit == "it":
pbar.update(1)

if progress_unit == "byte":
pbar.update(item.window_bytes())

results[item.item_idx][item.window_idx] = (item, output)
if progress_unit == "it":
pbar.update(1)

outputs = []
for item_idx in range(len(results)):
window_items = []
@@ -231,6 +234,7 @@ def _process(
item, output = results[item_idx][window_idx]
window_items.append(item)
window_outputs.append(output)

yield postprocessing_fn(window_items, window_outputs)

else:
@@ -246,14 +250,18 @@ def _process(
window_items.append(item)
window_outputs.append(output)
continue

yield postprocessing_fn(window_items, window_outputs)
if progress_unit == "byte":
pbar.update(sum(item.window_bytes() for item in window_items))

prev_item_idx = item.item_idx
window_items = [item]
window_outputs = [output]

if progress_unit == "it":
pbar.update(1)

# don't forget to yield final item
yield postprocessing_fn(window_items, window_outputs)

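
The reordering above follows one simple rule: the progress bar advances either once per processed item or by the byte size of each window, depending on the configured progress unit. A condensed sketch of that pattern, with tqdm standing in for the project's own progress bar helper and a trivial stand-in for the model call:

```python
from tqdm import tqdm

def process_windows(windows: list[dict], progress_unit: str = "it"):
    """Yield (window, output) pairs, advancing the bar per item or per byte."""
    total = (
        len(windows)
        if progress_unit == "it"
        else sum(w["bytes"] for w in windows)
    )
    with tqdm(total=total, unit=progress_unit) as pbar:
        for window in windows:
            output = window["text"].upper()  # stand-in for the real model output
            if progress_unit == "it":
                pbar.update(1)  # one tick per item
            else:
                pbar.update(window["bytes"])  # advance by bytes covered by this window
            yield window, output

# usage
windows = [{"text": "abc", "bytes": 3}, {"text": "defg", "bytes": 4}]
outputs = list(process_windows(windows, progress_unit="byte"))
```
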
python/text_utils/api/trainer.py (81 changes: 47 additions & 34 deletions)
@@ -1,48 +1,48 @@
import argparse
import math
import copy
import functools
import sys
import random
import os
import hashlib
import math
import os
import random
import shutil
import sys
import time
import zipfile
from typing import Any, Callable

from text_utils.api.utils import get_gradient_clipper
import torch
from torch.backends import cuda, cudnn
import yaml
from torch import distributed as dist
from torch import multiprocessing as mp
from torch import nn
from torch.optim import lr_scheduler
from torch.backends import cudnn, cuda # noqa
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
from torch.backends import cuda, cudnn # noqa
from torch.distributed.fsdp.api import (
BackwardPrefetch,
CPUOffload,
FullOptimStateDictConfig,
FullStateDictConfig,
MixedPrecision,
ShardingStrategy,
StateDictType,
FullStateDictConfig,
FullOptimStateDictConfig,
CPUOffload,
BackwardPrefetch,
)
from torch.distributed.fsdp.fully_sharded_data_parallel import (
FullyShardedDataParallel as FSDP,
)
from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.optim import lr_scheduler
from torch.utils.tensorboard.writer import SummaryWriter
import yaml

from text_utils import api, configuration, data, distributed, io, logging, tensorboard
from text_utils.modules.loss import loss_from_config
from text_utils.modules.optimizer import optimizer_from_config
from text_utils.modules.scheduler import (
lr_scheduler_from_config,
max_length_scheduler_from_config,
)
from text_utils.modules.optimizer import optimizer_from_config
from text_utils import distributed, data, configuration, io, logging, api, tensorboard


def clamp(v: float, minimum: int, maximum: int) -> int:
@@ -140,8 +140,6 @@ def __init__(
)
model = self._prepare_peft(model, peft)

sharding_policy = self._sharding_policy(model)

mixed_precision = self.cfg["train"].get("mixed_precision", None)
if mixed_precision == "fp16":
self.mixed_precision = torch.float16
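
For reference, the mixed-precision setting is just a config string mapped to a torch dtype before it is used for autocast, the grad scaler, or FSDP's MixedPrecision policy. A small sketch of that mapping; only the fp16 case is visible in the hunk above, the bf16 branch is an assumption:

```python
import torch

def mixed_precision_dtype(name: str | None) -> torch.dtype | None:
    # Map the "train.mixed_precision" config value to a torch dtype;
    # None means full-precision training.
    if name is None:
        return None
    return {"fp16": torch.float16, "bf16": torch.bfloat16}[name]
```
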
@@ -180,6 +178,7 @@ def __init__(
gradient_as_bucket_view=True,
)
else:
sharding_policy = self._sharding_policy(model)
offload_params = dist_cfg.get("offload", False)
prefetch = dist_cfg.get("prefetch", True)
strategy = ShardingStrategy[dist_cfg.get("strategy", "NO_SHARD")]
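
In the FSDP branch the sharding policy is now built lazily, only when DDP is not used. The `_sharding_policy` helper itself is outside this diff; given the `size_based_auto_wrap_policy` import above, a plausible sketch is a size-based auto-wrap policy, with the parameter threshold here being an assumed default:

```python
import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

def sharding_policy(min_num_params: int = 1_000_000):
    # Submodules with more than min_num_params parameters are wrapped into
    # their own FSDP unit; smaller ones are flattened into the parent unit.
    return functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=min_num_params,
    )

# usage sketch: FSDP(model, auto_wrap_policy=sharding_policy(), ...)
```
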
@@ -245,9 +244,12 @@ def __init__(
additional_optimizer_fn=self._additional_optimizer_fn(),
)

self.clip_gradient_norm: float | None = self.cfg["train"].get(
"clip_gradient_norm", None
)
gradient_clipping_cfg = self.cfg["train"].get("gradient_clipping", None)
if gradient_clipping_cfg:
self.gradient_clipper = get_gradient_clipper(gradient_clipping_cfg)
else:
self.gradient_clipper = None

gradient_accumulation_cfg = self.cfg["train"].get("gradient_accumulation", {})
self.gradient_accumulation_steps = gradient_accumulation_cfg.get("steps", 1)
self.gradient_accumulation_reduction = gradient_accumulation_cfg.get(
@@ -992,6 +994,9 @@ def _train_one_epoch(self):
mean_item_size_ratio = tensorboard.DistAverageTracker(
"train_item_size_ratio", self.info.device
)
mean_peak_gpu_memory = tensorboard.DistAverageTracker(
"train_peak_gpu_memory", self.info.device
)
total_batch_size = torch.zeros(1, dtype=torch.long, device=self.info.device)
min_num_batches = torch.zeros(1, dtype=torch.long, device=self.info.device)

@@ -1105,14 +1110,15 @@ def step(
if self.info.is_main_process:
mean_grad_norm.add(grad_norm)

if self.clip_gradient_norm is not None:
if self.gradient_clipper is not None:
self.gradient_clipper.add_norm(grad_norm)
clip_norm = self.gradient_clipper.get_norm()

self.grad_scaler.unscale_(self.optimizer)
if isinstance(self.model, FSDP):
self.model.clip_grad_norm_(self.clip_gradient_norm)
self.model.clip_grad_norm_(clip_norm)
else:
torch.nn.utils.clip_grad_norm_(
self.model.parameters(), self.clip_gradient_norm
)
torch.nn.utils.clip_grad_norm_(self.model.parameters(), clip_norm)

self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
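
The dynamic gradient clipper from the commit message replaces the fixed `clip_gradient_norm` value: each step feeds the observed gradient norm into the clipper via `add_norm` and asks it for the current threshold via `get_norm`. The actual `get_gradient_clipper` implementation lives in `text_utils.api.utils` and is not shown in this diff; the sketch below only illustrates that interface with a hypothetical percentile-based strategy:

```python
from collections import deque

class PercentileGradientClipper:
    """Hypothetical dynamic clipper: clip to a percentile of recently observed norms."""

    def __init__(self, window: int = 100, percentile: float = 90.0, start_norm: float = 1.0):
        self.norms: deque[float] = deque(maxlen=window)
        self.percentile = percentile
        self.start_norm = start_norm

    def add_norm(self, norm: float) -> None:
        # record the unclipped gradient norm observed this step
        self.norms.append(float(norm))

    def get_norm(self) -> float:
        # fall back to a fixed threshold until enough history has accumulated
        if len(self.norms) < 10:
            return self.start_norm
        ordered = sorted(self.norms)
        idx = min(int(len(ordered) * self.percentile / 100), len(ordered) - 1)
        return ordered[idx]
```

Whatever threshold `get_norm` returns is then applied through `FSDP.clip_grad_norm_` for sharded models or `torch.nn.utils.clip_grad_norm_` otherwise, exactly as in the hunk above.
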
@@ -1127,6 +1133,9 @@ def step(
mean_batch_size.add(batch_items)

mean_loss.add(sum(losses))
mean_peak_gpu_memory.add(
torch.cuda.max_memory_allocated(self.info.device) / 1024**3
)

if self.total_items >= self.step_at:
lr_scheduler = self.cooldown_scheduler or self.lr_scheduler
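
This tracker replaces the periodic `nvidia-smi` dump removed further down: each step records `torch.cuda.max_memory_allocated` in GiB, the values are synced and logged at every log interval, and the CUDA peak statistics are reset afterwards with `torch.cuda.reset_peak_memory_stats`. A self-contained sketch of that measure/log/reset cycle, with a plain list standing in for the distributed average tracker (requires a CUDA device):

```python
import torch

def log_peak_gpu_memory(model, optimizer, batches, device, log_interval: int = 100):
    """Sketch of the per-interval peak GPU memory measurement cycle."""
    peak_gib: list[float] = []  # stand-in for the distributed average tracker
    for step, batch in enumerate(batches, start=1):
        loss = model(batch.to(device)).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # highest allocation seen since the last reset, converted to GiB
        peak_gib.append(torch.cuda.max_memory_allocated(device) / 1024**3)

        if step % log_interval == 0:
            mean_peak = sum(peak_gib) / len(peak_gib)
            print(f"step {step}: mean peak GPU memory {mean_peak:.2f} GiB")
            peak_gib.clear()
            # start a fresh measurement window for the next logging interval
            torch.cuda.reset_peak_memory_stats(device)
```
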
@@ -1152,6 +1161,7 @@ def step(
mean_item_size.sync()
mean_item_size_ratio.sync()
mean_batch_preparation.sync()
mean_peak_gpu_memory.sync()
end = time.perf_counter()

if self.info.is_main_process:
@@ -1210,6 +1220,11 @@ def step(
)
mean_item_size_ratio.log_info(self.logger, self.total_step)

mean_peak_gpu_memory.log_tensorboard(
self.summary_writer, self.total_step
)
mean_peak_gpu_memory.log_info(self.logger, self.total_step)

items = batches[0].items()
for metric in metrics:
metric.set_values(items, first_outputs)
@@ -1229,12 +1244,6 @@ def step(
f"[step {self.total_step}] [epoch {self.epoch + 1}] {eta_msg}"
)

if self.info.is_local_main_process:
self.logger.info(
f"[step {self.total_step}] [rank {self.info.rank}] nvidia-smi:\n"
f"{api.nvidia_smi()}"
)

start = end
mean_loss.reset()
mean_grad_norm.reset()
@@ -1244,6 +1253,8 @@ def step(
mean_item_size.reset()
mean_item_size_ratio.reset()
mean_batch_preparation.reset()
mean_peak_gpu_memory.reset()
torch.cuda.reset_peak_memory_stats(self.info.device)
self.log_at += self.log_interval

if (
Expand Down Expand Up @@ -1277,6 +1288,8 @@ def step(
mean_item_size.reset()
mean_item_size_ratio.reset()
mean_batch_preparation.reset()
mean_peak_gpu_memory.reset()
torch.cuda.reset_peak_memory_stats(self.info.device)
start = time.perf_counter()

self.eval_at += self.eval_interval
(Diffs for the remaining 10 changed files are not shown here.)
