Alternative weight loading via .safetensors (#507)
Showing 6 changed files with 336 additions and 6 deletions.
314 changes: 314 additions & 0 deletions
ch05/02_alternative_weight_loading/weight-loading-hf-safetensors.ipynb
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,314 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "6d6bc54f-2b16-4b0f-be69-957eed5d112f", | ||
"metadata": {}, | ||
"source": [ | ||
"<table style=\"width:100%\">\n", | ||
"<tr>\n", | ||
"<td style=\"vertical-align:middle; text-align:left;\">\n", | ||
"<font size=\"2\">\n", | ||
"Supplementary code for the <a href=\"http://mng.bz/orYv\">Build a Large Language Model From Scratch</a> book by <a href=\"https://sebastianraschka.com\">Sebastian Raschka</a><br>\n", | ||
"<br>Code repository: <a href=\"https://github.com/rasbt/LLMs-from-scratch\">https://github.com/rasbt/LLMs-from-scratch</a>\n", | ||
"</font>\n", | ||
"</td>\n", | ||
"<td style=\"vertical-align:middle; text-align:left;\">\n", | ||
"<a href=\"http://mng.bz/orYv\"><img src=\"https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp\" width=\"100px\"></a>\n", | ||
"</td>\n", | ||
"</tr>\n", | ||
"</table>" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "72953590-5363-4398-85ce-54bde07f3d8a", | ||
"metadata": {}, | ||
"source": [ | ||
"# Bonus Code for Chapter 5" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1a4ab5ee-e7b9-45d3-a82b-a12bcfc0945a", | ||
"metadata": {}, | ||
"source": [ | ||
"## Alternative Weight Loading from Hugging Face Model Hub Via `safetensors`" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b2feea87-49f0-48b9-b925-b8f0dda4096f", | ||
"metadata": {}, | ||
"source": [ | ||
"- In the main chapter, we loaded the GPT model weights directly from OpenAI\n", | ||
"- This notebook provides alternative weight loading code to load the model weights from the [Hugging Face Model Hub](https://huggingface.co/docs/hub/en/models-the-hub) using `.safetensors` files\n", | ||
"- This is conceptually the same as loading weights of a PyTorch model from via the state-dict method described in chapter 5:\n", | ||
"\n", | ||
"```python\n", | ||
"state_dict = torch.load(\"model_state_dict.pth\")\n", | ||
"model.load_state_dict(state_dict) \n", | ||
"```\n", | ||
"\n", | ||
"- The appeal of `.safetensors` files lies in their secure design, as they only store tensor data and avoid the execution of potentially malicious code during loading\n", | ||
"- In newer versions of PyTorch (e.g., 2.0 and newer), a `weights_only=True` argument can be used with `torch.load` (e.g., `torch.load(\"model_state_dict.pth\", weights_only=True)`) to improve safety by skipping the execution of code and loading only the weights (this is now enabled by default in PyTorch 2.6 and newer)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "99b77109-5215-4d07-a618-4d10eff1a488", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# pip install safetensors" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "b0467eff-b43c-4a38-93e8-5ed87a5fc2b1", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"numpy version: 1.26.4\n", | ||
"torch version: 2.5.1\n", | ||
"safetensors version: 0.4.4\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"from importlib.metadata import version\n", | ||
"\n", | ||
"pkgs = [\"numpy\", \"torch\", \"safetensors\"]\n", | ||
"for p in pkgs:\n", | ||
" print(f\"{p} version: {version(p)}\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "d1cb0023-8a47-4b1a-9bde-54ab7eac476b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from previous_chapters import GPTModel, generate_text_simple" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "9ea9b1bc-7881-46ad-9555-27a9cf23faa7", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"BASE_CONFIG = {\n", | ||
" \"vocab_size\": 50257, # Vocabulary size\n", | ||
" \"context_length\": 1024, # Context length\n", | ||
" \"drop_rate\": 0.0, # Dropout rate\n", | ||
" \"qkv_bias\": True # Query-key-value bias\n", | ||
"}\n", | ||
"\n", | ||
"model_configs = {\n", | ||
" \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", | ||
" \"gpt2-medium (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", | ||
" \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", | ||
" \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25},\n", | ||
"}\n", | ||
"\n", | ||
"\n", | ||
"CHOOSE_MODEL = \"gpt2-small (124M)\"\n", | ||
"BASE_CONFIG.update(model_configs[CHOOSE_MODEL])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "e7b22375-6fac-4e90-9063-daa4de86c778", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import os\n", | ||
"import urllib.request\n", | ||
"from safetensors.torch import load_file\n", | ||
"\n", | ||
"URL_DIR = {\n", | ||
" \"gpt2-small (124M)\": \"gpt2\", # works ok\n", | ||
" \"gpt2-medium (355M)\": \"gpt2-medium\", # this file seems to have issues via `generate`\n", | ||
" \"gpt2-large (774M)\": \"gpt2-large\", # works ok\n", | ||
" \"gpt2-xl (1558M)\": \"gpt2-xl\" # works ok\n", | ||
"}\n", | ||
"\n", | ||
"url = f\"https://huggingface.co/openai-community/{URL_DIR[CHOOSE_MODEL]}/resolve/main/model.safetensors\"\n", | ||
"output_file = f\"model-{URL_DIR[CHOOSE_MODEL]}.safetensors\"\n", | ||
"\n", | ||
"# Download file\n", | ||
"if not os.path.exists(output_file):\n", | ||
" urllib.request.urlretrieve(url, output_file)\n", | ||
"\n", | ||
"# Load file\n", | ||
"state_dict = load_file(output_file)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "4e2a4cf4-a54e-4307-9141-fb9f288e4dfa", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def assign(left, right):\n", | ||
" if left.shape != right.shape:\n", | ||
" raise ValueError(f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\")\n", | ||
" return torch.nn.Parameter(right.detach())" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 7, | ||
"id": "75be3077-f141-44bb-af88-62580ffd224c", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"def load_weights_into_gpt(gpt, params):\n", | ||
" gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params[\"wpe.weight\"])\n", | ||
" gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params[\"wte.weight\"])\n", | ||
"\n", | ||
" for b in range(len(gpt.trf_blocks)):\n", | ||
" q_w, k_w, v_w = torch.chunk(\n", | ||
" params[f\"h.{b}.attn.c_attn.weight\"], 3, axis=-1)\n", | ||
" gpt.trf_blocks[b].att.W_query.weight = assign(\n", | ||
" gpt.trf_blocks[b].att.W_query.weight, q_w.T)\n", | ||
" gpt.trf_blocks[b].att.W_key.weight = assign(\n", | ||
" gpt.trf_blocks[b].att.W_key.weight, k_w.T)\n", | ||
" gpt.trf_blocks[b].att.W_value.weight = assign(\n", | ||
" gpt.trf_blocks[b].att.W_value.weight, v_w.T)\n", | ||
"\n", | ||
" q_b, k_b, v_b = torch.chunk(\n", | ||
" params[f\"h.{b}.attn.c_attn.bias\"], 3, axis=-1)\n", | ||
" gpt.trf_blocks[b].att.W_query.bias = assign(\n", | ||
" gpt.trf_blocks[b].att.W_query.bias, q_b)\n", | ||
" gpt.trf_blocks[b].att.W_key.bias = assign(\n", | ||
" gpt.trf_blocks[b].att.W_key.bias, k_b)\n", | ||
" gpt.trf_blocks[b].att.W_value.bias = assign(\n", | ||
" gpt.trf_blocks[b].att.W_value.bias, v_b)\n", | ||
"\n", | ||
" gpt.trf_blocks[b].att.out_proj.weight = assign(\n", | ||
" gpt.trf_blocks[b].att.out_proj.weight,\n", | ||
" params[f\"h.{b}.attn.c_proj.weight\"].T)\n", | ||
" gpt.trf_blocks[b].att.out_proj.bias = assign(\n", | ||
" gpt.trf_blocks[b].att.out_proj.bias,\n", | ||
" params[f\"h.{b}.attn.c_proj.bias\"])\n", | ||
"\n", | ||
" gpt.trf_blocks[b].ff.layers[0].weight = assign(\n", | ||
" gpt.trf_blocks[b].ff.layers[0].weight,\n", | ||
" params[f\"h.{b}.mlp.c_fc.weight\"].T)\n", | ||
" gpt.trf_blocks[b].ff.layers[0].bias = assign(\n", | ||
" gpt.trf_blocks[b].ff.layers[0].bias,\n", | ||
" params[f\"h.{b}.mlp.c_fc.bias\"])\n", | ||
" gpt.trf_blocks[b].ff.layers[2].weight = assign(\n", | ||
" gpt.trf_blocks[b].ff.layers[2].weight,\n", | ||
" params[f\"h.{b}.mlp.c_proj.weight\"].T)\n", | ||
" gpt.trf_blocks[b].ff.layers[2].bias = assign(\n", | ||
" gpt.trf_blocks[b].ff.layers[2].bias,\n", | ||
" params[f\"h.{b}.mlp.c_proj.bias\"])\n", | ||
"\n", | ||
" gpt.trf_blocks[b].norm1.scale = assign(\n", | ||
" gpt.trf_blocks[b].norm1.scale,\n", | ||
" params[f\"h.{b}.ln_1.weight\"])\n", | ||
" gpt.trf_blocks[b].norm1.shift = assign(\n", | ||
" gpt.trf_blocks[b].norm1.shift,\n", | ||
" params[f\"h.{b}.ln_1.bias\"])\n", | ||
" gpt.trf_blocks[b].norm2.scale = assign(\n", | ||
" gpt.trf_blocks[b].norm2.scale,\n", | ||
" params[f\"h.{b}.ln_2.weight\"])\n", | ||
" gpt.trf_blocks[b].norm2.shift = assign(\n", | ||
" gpt.trf_blocks[b].norm2.shift,\n", | ||
" params[f\"h.{b}.ln_2.bias\"])\n", | ||
"\n", | ||
" gpt.final_norm.scale = assign(gpt.final_norm.scale, params[\"ln_f.weight\"])\n", | ||
" gpt.final_norm.shift = assign(gpt.final_norm.shift, params[\"ln_f.bias\"])\n", | ||
" gpt.out_head.weight = assign(gpt.out_head.weight, params[\"wte.weight\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 8, | ||
"id": "cda44d37-92c0-4c19-a70a-15711513afce", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from previous_chapters import GPTModel\n", | ||
"\n", | ||
"\n", | ||
"gpt = GPTModel(BASE_CONFIG)\n", | ||
"\n", | ||
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", | ||
"load_weights_into_gpt(gpt, state_dict)\n", | ||
"gpt.to(device);" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 9, | ||
"id": "4ddd0d51-3ade-4890-9bab-d63f141d095f", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Output text:\n", | ||
" Every effort moves forward, but it's not enough.\n", | ||
"\n", | ||
"\"I'm not going to sit here and say, 'I'm not going to do this,'\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import tiktoken\n", | ||
"from previous_chapters import generate, text_to_token_ids, token_ids_to_text\n", | ||
"\n", | ||
"torch.manual_seed(123)\n", | ||
"\n", | ||
"tokenizer = tiktoken.get_encoding(\"gpt2\")\n", | ||
"\n", | ||
"token_ids = generate(\n", | ||
" model=gpt.to(device),\n", | ||
" idx=text_to_token_ids(\"Every effort moves\", tokenizer).to(device),\n", | ||
" max_new_tokens=30,\n", | ||
" context_size=BASE_CONFIG[\"context_length\"],\n", | ||
" top_k=1,\n", | ||
" temperature=1.0\n", | ||
")\n", | ||
"\n", | ||
"print(\"Output text:\\n\", token_ids_to_text(token_ids, tokenizer))" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
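
The notebook's remark that `.safetensors` files only store tensor data can be made concrete by looking at the file layout. The following is a minimal sketch, assuming the layout described in the safetensors documentation: an 8-byte little-endian unsigned integer giving the header length, then a UTF-8 JSON header mapping tensor names to `dtype`, `shape`, and `data_offsets`, then the raw byte buffer. The helper names `write_minimal_safetensors` and `read_safetensors_header` are hypothetical and not part of the library or this commit, and the hand-written demo file has not been checked against the library's own loader.

```python
import json
import os
import struct
import tempfile


def write_minimal_safetensors(path, name, values):
    """Hand-write one 1-D float32 tensor using the documented layout (illustrative only)."""
    data = struct.pack(f"<{len(values)}f", *values)  # raw little-endian float32 bytes
    header = {
        name: {
            "dtype": "F32",
            "shape": [len(values)],
            "data_offsets": [0, len(data)],  # byte range inside the data buffer
        }
    }
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    with open(path, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))  # 8-byte header length
        f.write(header_bytes)                          # JSON header
        f.write(data)                                  # tensor bytes


def read_safetensors_header(path):
    """Read only the JSON header; no tensor data is loaded and no code is executed."""
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        return json.loads(f.read(header_len).decode("utf-8"))


with tempfile.TemporaryDirectory() as tmp_dir:
    demo_path = os.path.join(tmp_dir, "demo.safetensors")
    write_minimal_safetensors(demo_path, "wte.weight", [0.5, -1.0, 2.0])
    header = read_safetensors_header(demo_path)

for tensor_name, info in header.items():
    if tensor_name == "__metadata__":  # optional free-form metadata entry
        continue
    print(tensor_name, info["dtype"], info["shape"], info["data_offsets"])
```

Pointing `read_safetensors_header` at a downloaded file such as `model-gpt2.safetensors` should list the same tensor names (`wte.weight`, `h.0.attn.c_attn.weight`, and so on) that `load_weights_into_gpt` looks up, which is a quick way to inspect a checkpoint before loading it.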
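
In `load_weights_into_gpt`, the fused `c_attn.weight` tensor is split into three equal column blocks (query, key, value), and each block is transposed. The transpose is needed because the Hugging Face GPT-2 checkpoint stores these projections with shape (in_features, out_features), whereas `torch.nn.Linear` keeps its weight as (out_features, in_features). The sketch below mimics that split-and-transpose step on a toy matrix using plain nested lists; `chunk_columns` and `transpose` are hypothetical stand-ins for `torch.chunk(..., 3, axis=-1)` and `.T`, and unlike `torch.chunk` the helper assumes the column count divides evenly.

```python
def chunk_columns(matrix, chunks):
    """Split a list-of-rows matrix into `chunks` equal column blocks."""
    n_cols = len(matrix[0])
    if n_cols % chunks != 0:
        raise ValueError(f"{n_cols} columns cannot be split into {chunks} equal blocks")
    width = n_cols // chunks
    return [[row[i * width:(i + 1) * width] for row in matrix]
            for i in range(chunks)]


def transpose(matrix):
    """Swap rows and columns, mirroring `.T` on a 2D tensor."""
    return [list(column) for column in zip(*matrix)]


# Toy fused weight with emb_dim = 2, so its shape is (2, 3 * 2)
fused_qkv = [
    [1, 2, 3, 4, 5, 6],
    [7, 8, 9, 10, 11, 12],
]

q_w, k_w, v_w = chunk_columns(fused_qkv, 3)
print("q_w:  ", q_w)
print("q_w.T:", transpose(q_w))
print("k_w:  ", k_w)
print("v_w:  ", v_w)
```

The bias vectors need no transpose because they are one-dimensional, which is why the notebook assigns `q_b`, `k_b`, and `v_b` directly.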