feat: custom [megatron] nvidia dmc loader #39

Open · wants to merge 17 commits into base: main
3 changes: 3 additions & 0 deletions dyana/loaders/megatron/.gitignore
@@ -0,0 +1,3 @@
dyana.py
dyana-requirements.txt
dyana-requirements-gpu.txt
50 changes: 50 additions & 0 deletions dyana/loaders/megatron/Dockerfile
@@ -0,0 +1,50 @@
FROM nvcr.io/nvidia/pytorch:24.04-py3

WORKDIR /app

# Install system dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        git \
        ca-certificates \
        build-essential \
    && rm -rf /var/lib/apt/lists/*

# Create required directories for multiprocessing, IPC and shared memory
RUN mkdir -p /dev/shm /run/shm /tmp/pytorch_extensions /tmp/.pytorch_jit_cache /tmp/transformers && \
    chmod -R 777 /dev/shm /run/shm /tmp/pytorch_extensions /tmp/.pytorch_jit_cache /tmp/transformers

# Only verify PyTorch version during build (not CUDA)
RUN python3 -c "import torch; print(f'PyTorch version: {torch.__version__}')"

# Create working directory
RUN mkdir -p /app/workspace

# Copy files in correct order
COPY requirements.txt /app/workspace/
COPY *.py /app/workspace/
COPY dyana-requirements*.txt /app/workspace/

WORKDIR /app/workspace

# Install dependencies
RUN pip install --no-cache-dir -r requirements.txt

# Install Megatron-LM
RUN git clone --depth 1 --branch dmc https://github.com/NVIDIA/Megatron-LM.git /app/Megatron-LM && \
    cd /app/Megatron-LM && \
    pip install -e .

# Make both the loader sources (/app/workspace) and the Megatron-LM checkout importable
ENV PYTHONPATH=/app/workspace:/app/Megatron-LM:$PYTHONPATH

ENTRYPOINT ["python3", "-W", "ignore", "main.py"]
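# Any arguments passed to the container (e.g. --model /path/to/model.pt --size 7B)
# are appended to this ENTRYPOINT and parsed by main.py's argparse CLI.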
155 changes: 155 additions & 0 deletions dyana/loaders/megatron/main.py
@@ -0,0 +1,155 @@
import argparse
import logging
import sys
import warnings
import multiprocessing
from pathlib import Path

logging.basicConfig(level=logging.ERROR)

Check failure on line 8 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (I001) dyana/loaders/megatron/main.py:1:1: Import block is un-sorted or un-formatted
warnings.filterwarnings("ignore", category=UserWarning)
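
# NOTE: the imports below are interleaved with runtime setup (warnings filter above,
# spawn start method below), which is why Ruff reports E402 for them.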

import torch

Check failure on line 11 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:11:1: Module level import not at top of file

multiprocessing.set_start_method("spawn", force=True)

import transformer_engine.pytorch as te

Check failure on line 15 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:15:1: Module level import not at top of file
from megatron.core import parallel_state

Check failure on line 16 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:16:1: Module level import not at top of file
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_model import GPTModel  # assumed import path for GPTModel; may differ on the NVIDIA `dmc` branch

Check failure on line 17 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:17:1: Module level import not at top of file
from transformers import LlamaTokenizer

Check failure on line 18 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:18:1: Module level import not at top of file

from dyana import Profiler

Check failure on line 20 in dyana/loaders/megatron/main.py (GitHub Actions / Validate 3.9, 3.10, 3.11): Ruff (E402) dyana/loaders/megatron/main.py:20:1: Module level import not at top of file


def find_tokenizer(model_path: Path) -> Path:
Collaborator: yes! beautiful

Contributor Author: taking after a true leader/ninja! ;)

"""Find tokenizer file in model directory or alongside model file."""
patterns = [
# LLaMA specific patterns first
"llama*tokenizer*.model", # LLaMA specific naming
"tokenizer.model", # Standard LLaMA tokenizer
# Generic patterns as fallback
"*.model", # sentencepiece models
"tokenizer.*", # huggingface style
"*/tokenizer.*", # nested folder
"vocab.*", # vocabulary files
"merges.txt", # BPE merges
]

# Try both the model's directory and its parent directory
search_dirs = [model_path.parent]
if model_path.parent.parent.exists():
search_dirs.append(model_path.parent.parent)

for directory in search_dirs:
all_files = list(directory.glob("*"))
for f in sorted(all_files):
print(f" {f}", file=sys.stderr)
# If it looks like a LLaMA tokenizer file, try it first
if "tokenizer" in f.name.lower() and f.name.endswith(".model"):
return f

# If no obvious tokenizer found, try the patterns
for pattern in patterns:
matches = list(directory.glob(pattern))
if matches:
return matches[0]

raise FileNotFoundError(
f"No tokenizer found in {[str(d) for d in search_dirs]} after trying patterns: {patterns}\n"
f"Available files in {model_path.parent}: {[f.name for f in model_path.parent.glob('*')]}"
)


def load_tokenizer(args: argparse.Namespace) -> LlamaTokenizer:
    if args.tokenizer:
        tokenizer_path = Path(args.tokenizer)
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer not found at {tokenizer_path}")
    else:
        # Otherwise search for a tokenizer next to the model checkpoint
        tokenizer_path = find_tokenizer(Path(args.model))

    return LlamaTokenizer.from_pretrained(
        str(tokenizer_path.parent),
        local_files_only=True,
        tokenizer_file=str(tokenizer_path.name),
    )
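
# Example (hypothetical paths, for illustration only): for a checkpoint at
#   /volumes/llama-7b-dmc/model.pt
# with a sentencepiece file next to it at
#   /volumes/llama-7b-dmc/tokenizer.model
# find_tokenizer() returns the tokenizer.model path (its name contains "tokenizer"
# and ends with ".model"), and load_tokenizer() then builds the tokenizer via
# LlamaTokenizer.from_pretrained("/volumes/llama-7b-dmc", local_files_only=True,
#                                tokenizer_file="tokenizer.model").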


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True)
    parser.add_argument("--size", choices=["7B", "13B"], default="7B")
    parser.add_argument("--input", default="This is an example prompt.")
    parser.add_argument("--tokenizer", help="Optional explicit tokenizer path")
    args = parser.parse_args()

    model_config = {
        "7B": {"num_layers": 32, "hidden_size": 4096, "num_attention_heads": 32},
        "13B": {"num_layers": 40, "hidden_size": 5120, "num_attention_heads": 40},
    }[args.size]

    profiler = Profiler(gpu=True)

    try:
        model_path = Path(args.model)
        if not model_path.exists():
            raise FileNotFoundError(f"Model not found at {model_path}")

        tokenizer = load_tokenizer(args)
        profiler.on_stage("tokenizer_loaded")

        te.initialize()

        has_gpu = torch.cuda.is_available()
        device = torch.device("cuda" if has_gpu else "cpu")

        # Megatron's tensor/pipeline parallel state
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size=1,  # no tensor parallelism for now
            pipeline_model_parallel_size=1,  # no pipeline parallelism
        )
        profiler.on_stage("megatron_initialized")

        # Megatron transformer config
        config = TransformerConfig(
            num_layers=model_config["num_layers"],
            hidden_size=model_config["hidden_size"],
            num_attention_heads=model_config["num_attention_heads"],
            max_position_embeddings=4096,
            init_method_std=0.02,
            use_scaled_init_method=True,
            attention_softmax_in_fp32=True,
            rotary_pct=0.25,  # LLaMA uses rotary embeddings
        )

        model = GPTModel(
            config=config,
            vocab_size=tokenizer.vocab_size,
            max_sequence_length=4096,
            parallel_output=False,
            share_embeddings_and_output_weights=True,
        )
        if has_gpu:
            model = model.cuda()

        profiler.on_stage("model_created")

        # Load the DMC checkpoint directly onto the target device
        checkpoint = torch.load(str(model_path), map_location=device)
        model.load_state_dict(checkpoint)
        model.eval()
        profiler.on_stage("model_loaded")

        # Run single-step greedy inference
        input_ids = tokenizer(args.input, return_tensors="pt").to(device)
        with torch.no_grad():
            output = model(input_ids=input_ids["input_ids"])
            logits = output.logits
            next_token = torch.argmax(logits[:, -1, :], dim=-1)
            generated = torch.cat([input_ids["input_ids"], next_token.unsqueeze(-1)], dim=-1)

        text = tokenizer.decode(generated[0], skip_special_tokens=True)
        profiler.track("output", text)
        profiler.on_stage("inference_complete")

    except Exception as e:
        profiler.track_error("megatron", str(e))
19 changes: 19 additions & 0 deletions dyana/loaders/megatron/requirements.txt
@@ -0,0 +1,19 @@
--extra-index-url https://download.pytorch.org/whl/cu121
--find-links https://developer.download.nvidia.com/compute/redist

# Base dependencies from Megatron core
torch>=2.0.0
packaging>=20.0
typing_extensions>=4.0.0

# Megatron DMC dependencies
flash-attn==2.6.1
sentencepiece==0.2.0
hydra-core==1.3.2
hydra_colorlog==1.2.0
nltk
datasets
transformers>=4.38.0

# Utilities
psutil>=5.6.7
30 changes: 30 additions & 0 deletions dyana/loaders/megatron/settings.yml
@@ -0,0 +1,30 @@
description: Loads and profiles Megatron-LM DMC models for efficient inference

build_args:
  extra-requirements: EXTRA_REQUIREMENTS

args:
  - name: model
    description: Path to model checkpoint (tokenizer should be in the same directory)
    required: true
    volume: true

  - name: size
    description: Model size (7B or 13B)
    required: false

  - name: input
    description: Input text for inference
    default: "This is an example prompt."
    required: false

  - name: tokenizer
    description: Optional explicit path to tokenizer file (otherwise auto-detected)
    required: false
    volume: true

examples:
  - description: "Load a Megatron-DMC model with auto-detected tokenizer:"
    command: dyana trace --loader megatron --model /path/to/model.pt --size 7B
Collaborator: perfect! 👨🏻‍🍳

- description: "Load model with explicit tokenizer path:"
command: dyana trace --loader megatron --model /path/to/model.pt --size 7B --tokenizer /path/to/tokenizer.model