From 67c40799d2d26abf75a5a4c00647caabdf19a42f Mon Sep 17 00:00:00 2001 From: sid Date: Sun, 21 Mar 2021 20:58:40 +0100 Subject: [PATCH] cleanup + add pretrained models --- GPTNeo_example_notebook.ipynb | 301 ++++++++++-- README.md | 19 +- ...mallPileAblation_small_CC100_newinput.json | 9 - ...allPileAblation_small_CC_raw_newinput.json | 9 - ...SmallPileAblation_small_Pile_newinput.json | 9 - .../SmallPileAblation_small_owt_newinput.json | 9 - .../dataset_configs/cc100en_40G_ablation.json | 9 - .../dataset_configs/cc_raw_40G_ablation.json | 9 - .../openwebtext-documents.json | 9 - configs/dataset_configs/owt_40G_ablation.json | 9 - .../dataset_configs/pile_40G_ablation.json | 9 - configs/dataset_configs/test.json | 9 - configs/gpt3_XL_128_Pile.json | 37 -- configs/gpt3_XL_256.json | 37 -- ...gpt3_XL_256_SmallPileAblation_CC100en.json | 38 -- .../gpt3_XL_256_SmallPileAblation_CC_raw.json | 38 -- .../gpt3_XL_256_SmallPileAblation_Pile.json | 38 -- .../gpt3_XL_256_SmallPileAblation_owt.json | 38 -- configs/gpt3_XL_64_Pile.json | 37 -- configs/gpt3_local.json | 36 -- configs/gpt3_scaling_128_pile.json | 40 -- configs/gpt3_scaling_256_pile.json | 40 -- configs/gpt3_scaling_32_pile.json | 40 -- configs/gpt3_small_local_256.json | 36 -- configs/gpt3_small_moe_8.json | 40 -- data/create_tfrecords.py | 79 ++-- inputs.py | 373 ++++++++------- model_fns.py | 45 +- models/activations.py | 128 +++++ models/gpt2/gpt2.py | 446 +----------------- models/layers.py | 317 +++++++++++++ models/utils.py | 38 +- 32 files changed, 1026 insertions(+), 1305 deletions(-) delete mode 100644 configs/dataset_configs/SmallPileAblation_small_CC100_newinput.json delete mode 100644 configs/dataset_configs/SmallPileAblation_small_CC_raw_newinput.json delete mode 100644 configs/dataset_configs/SmallPileAblation_small_Pile_newinput.json delete mode 100644 configs/dataset_configs/SmallPileAblation_small_owt_newinput.json delete mode 100644 configs/dataset_configs/cc100en_40G_ablation.json delete mode 100644 configs/dataset_configs/cc_raw_40G_ablation.json delete mode 100644 configs/dataset_configs/openwebtext-documents.json delete mode 100644 configs/dataset_configs/owt_40G_ablation.json delete mode 100644 configs/dataset_configs/pile_40G_ablation.json delete mode 100644 configs/dataset_configs/test.json delete mode 100644 configs/gpt3_XL_128_Pile.json delete mode 100644 configs/gpt3_XL_256.json delete mode 100644 configs/gpt3_XL_256_SmallPileAblation_CC100en.json delete mode 100644 configs/gpt3_XL_256_SmallPileAblation_CC_raw.json delete mode 100644 configs/gpt3_XL_256_SmallPileAblation_Pile.json delete mode 100644 configs/gpt3_XL_256_SmallPileAblation_owt.json delete mode 100644 configs/gpt3_XL_64_Pile.json delete mode 100644 configs/gpt3_local.json delete mode 100644 configs/gpt3_scaling_128_pile.json delete mode 100644 configs/gpt3_scaling_256_pile.json delete mode 100644 configs/gpt3_scaling_32_pile.json delete mode 100644 configs/gpt3_small_local_256.json delete mode 100644 configs/gpt3_small_moe_8.json create mode 100644 models/activations.py create mode 100644 models/layers.py diff --git a/GPTNeo_example_notebook.ipynb b/GPTNeo_example_notebook.ipynb index eaa5932e..c0ccd1f4 100644 --- a/GPTNeo_example_notebook.ipynb +++ b/GPTNeo_example_notebook.ipynb @@ -3,7 +3,7 @@ "nbformat_minor": 0, "metadata": { "colab": { - "name": "GPTNeo_example_notebook.ipynb", + "name": "Copy of GPTNeo_example_notebook.ipynb", "provenance": [], "collapsed_sections": [] }, @@ -36,8 +36,7 @@ { "cell_type": "code", "metadata": { - "id": "K-53qkZV6Lv9", - "cellView": "form" + "id": "K-53qkZV6Lv9" }, "source": [ "#@title Setup\n", @@ -55,7 +54,7 @@ "id": "R918l14UhrBR" }, "source": [ - "First, we need to download and tokenize a dataset - you can choose from:\n", + "Whether you're training from scratch or finetuning, we first need to download and tokenize a dataset - you can choose from:\n", "\n", "* OpenWebText - an opensource clone of OpenAI's WebText dataset, the original training data of GPT2.\n", "\n", @@ -65,6 +64,8 @@ "\n", "* NIHExporter - Data relating to various projects from the national institute of health.\n", "\n", + "* Custom - if this option is chosen you will be prompted to enter the path to your own dataset. It should be a directory containing .txt or .jsonl files.\n", + "\n", "All these datasets are from EleutherAI's side project - [The Pileβ„’](https://github.com/EleutherAI/The-Pile) - an effort to gather a general purpose, diverse and open source plain text dataset large enough to train 1T+ parameter language models.\n", "\n", "Even the smallest datasets are fairly large files, so this step will likely take a while. Select a dataset in the next cell, then run the next two cells, and go grab a snack and a cup of tea 😊\n", @@ -80,42 +81,42 @@ }, "source": [ "# Select a Dataset:\n", - "dataset = 'OpenWebText' #@param [\"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]" + "dataset = 'HackerNews' #@param [\"OpenWebText\", \"YoutubeSubtitles\", \"HackerNews\", \"NIHExporter\", \"Custom\"]" ], - "execution_count": null, + "execution_count": 2, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "yiNKm4xsLZCq", - "cellView": "both" + "cellView": "form" }, "source": [ - "# Download Selected Dataset, or enter details of custom data:\n", + "# @title Download Selected Dataset, or enter details of custom data\n", "import os\n", "\n", "if dataset == 'OpenWebText':\n", - " !gdown https://drive.google.com/uc?id=1EA5V0oetDCOke7afsktL_JDQ-ETtNOvx\n", + " !wget https://the-eye.eu/public/AI/pile_preliminary_components/openwebtext2.jsonl.zst.tar -O openwebtext.tar.xz\n", " !tar xf openwebtext.tar.xz\n", " dataset_path = \"openwebtext\"\n", " dataset_name = dataset_path\n", " out_name = dataset_name + \"_tokenized\"\n", "elif dataset == 'YoutubeSubtitles':\n", " os.makedirs('data', exist_ok=True)\n", - " !wget wget https://eaidata.bmk.sh/data/yt_subs.jsonl.zst -O data/yt_subs.jsonl.zst\n", + " !wget https://the-eye.eu/public/AI/pile_preliminary_components/yt_subs.jsonl.zst -O data/yt_subs.jsonl.zst\n", " dataset_path = 'data'\n", " dataset_name = 'ytsubs'\n", " out_name = dataset_name + \"_tokenized\"\n", "elif dataset == 'HackerNews':\n", " os.makedirs('data', exist_ok=True)\n", - " !wget https://eaidata.bmk.sh/data/hn.tar.gz -O data/hn.tar.gz\n", + " !wget https://the-eye.eu/public/AI/pile_preliminary_components/hn.tar.gz -O data/hn.tar.gz\n", " dataset_path = 'data'\n", " dataset_name = 'hackernews'\n", " out_name = dataset_name + \"_tokenized\"\n", "elif dataset == \"NIHExporter\":\n", " os.makedirs('data', exist_ok=True)\n", - " !gdown https://drive.google.com/uc?id=11mO-0LuL2YeKoqqWXyHPHf3d2ODnjVPP\n", + " !wget https://the-eye.eu/public/AI/pile_preliminary_components/NIH_ExPORTER_awarded_grant_text.jsonl.zst -O data/NIH_ExPORTER_awarded_grant_text.jsonl.zst\n", " dataset_path = 'data'\n", " os.system('mv NIH_ExPORTER_awarded_grant_text.jsonl.zst ./data')\n", " dataset_name = 'nihexporter'\n", @@ -139,7 +140,7 @@ }, "source": [ "# Tokenize Data:\n", - "!python data/create_tfrecords.py --mode documents --input_dir /content/GPTNeo/$dataset_path --name $dataset_name --output_dir $out_name --write_dataset_config" + "!python data/create_tfrecords.py --mode documents --input_dir /content/GPTNeo/$dataset_path --name $dataset_name --files_per 1000 --output_dir $out_name --write_dataset_config --processes 1" ], "execution_count": null, "outputs": [] @@ -153,6 +154,15 @@ "# Train on TPU" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "M0R1owh2qvp8" + }, + "source": [ + "## Prepare Data" + ] + }, { "cell_type": "markdown", "metadata": { @@ -197,9 +207,9 @@ "cellView": "form" }, "source": [ - "path_to_cloud_bucket = 'gs://path_to_your_bucket/datasets/' #@param {type:\"string\"}" + "path_to_cloud_bucket = 'gs://your-bucket-name/datasets/' #@param {type:\"string\"}" ], - "execution_count": null, + "execution_count": 8, "outputs": [] }, { @@ -210,8 +220,10 @@ }, "source": [ "# copy the data to your bucket\n", + "if not path_to_cloud_bucket.endswith('/'):\n", + " path_to_cloud_bucket += '/'\n", "copy_loc = path_to_cloud_bucket + dataset_name\n", - "!gsutil cp -r /content/GPTNeo/$out_name $copy_loc\n", + "!gsutil -m cp -r /content/GPTNeo/$out_name $copy_loc\n", "!gsutil ls $path_to_cloud_bucket" ], "execution_count": null, @@ -227,9 +239,8 @@ "\n", "* First change the writefile path to point to your chosen dataset - e.g `%%writefile configs/dataset_configs/ytsubs.json`\n", "* Change the \"path\" field to point to your cloud bucket location - e.g `gs://neo_lmdatasets/datasets/ytsubs_*.tfrecords`\n", - "* In the second cell (the model config) change \"model_path\" to the location in your cloud bucket you want the model weights to be saved to.\n", - "* Change the first item in the \"datasets\" list to the name of your chosen dataset. (the filename minus .json in ./configs/dataset_configs)\n", - "* Once you've made the edits, then run the cells to overwrite the existing files.\n", + "* Change `dataset_name` in `%%writefile configs/dataset_configs/dataset_name.json` to the name of your chosen dataset.\n", + "* Once you've made the edits, then run the cell below to overwrite the existing files.\n", "\n", "\n" ] @@ -237,13 +248,17 @@ { "cell_type": "code", "metadata": { - "id": "MCsZP48vavCP" + "id": "MCsZP48vavCP", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "0adc5415-bb06-474e-e13a-bd7ae7f1ba00" }, "source": [ - "%%writefile configs/dataset_configs/nihexporter.json\n", + "%%writefile configs/dataset_configs/dataset_name.json\n", "\n", "{\n", - " \"path\": \"gs://your_bucket_name/datasets/nihexporter/nihexporter_*.tfrecords\",\n", + " \"path\": \"gs://your_bucket_name/datasets/dataset_name/dataset_name*.tfrecords\",\n", " \"eval_path\": \"\",\n", " \"n_vocab\": 50256,\n", " \"tokenizer_is_pretrained\": true,\n", @@ -252,8 +267,25 @@ " \"padding_id\": 50257\n", "}\n" ], - "execution_count": null, - "outputs": [] + "execution_count": 35, + "outputs": [ + { + "output_type": "stream", + "text": [ + "Overwriting configs/dataset_configs/hackernews.json\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dH0x3dI9j85P" + }, + "source": [ + "## Train from Scratch" + ] }, { "cell_type": "markdown", @@ -261,13 +293,17 @@ "id": "I6GnCgAkB7GQ" }, "source": [ - "The model below is identical to our pretrained GPT2 model [TODO: add link]. You can finetune from this model by downloading the weights [TODO: add link] and copying them to your specified \"model_path\" below. \n", + "The model below is identical to our pretrained GPT3XL model (1.3B Params). \n", "\n", - "If no previous model is found in \"model_path\", the model will start training from scratch.\n", + "If no previous model is found in \"model_path\", the model will start training from scratch. If you'd prefer to finetune from pretrained, skip to the `Finetune a Pretrained Model` section.\n", "\n", "If you want to use a smaller model, you can modify any of the config files in ../configs/ ending in _8.json, all of which are designed to train on tpu-v8s.\n", "\n", - "For a more detailed breakdown on what each item in the configuration file means - please read through our training and config guides in our [github README](https://github.com/EleutherAI/GPTNeo#training-guide). " + "For a more detailed breakdown on what each item in the configuration file means - please read through our training and config guides in our [github README](https://github.com/EleutherAI/GPTNeo#training-guide). \n", + "\n", + "You'll want to change the first item in the `datasets` list to the name of your chosen dataset. (the filename minus .json in ./configs/dataset_configs)\n", + "\n", + "You'll also want to modify the `model_path` field to point to your google cloud bucket, so checkpoints get saved to there." ] }, { @@ -279,7 +315,7 @@ "%%writefile configs/colab_XL.json\n", "\n", "{\n", - " \"n_head\": 32,\n", + " \"n_head\": 16,\n", " \"n_vocab\": 50260,\n", " \"embed_dropout\": 0,\n", " \"lr\": 0.0002,\n", @@ -289,25 +325,25 @@ " \"beta2\": 0.95,\n", " \"epsilon\": 1e-8,\n", " \"opt_name\": \"adam\",\n", - " \"weight_decay\": 0.1,\n", + " \"weight_decay\": 0,\n", " \"train_batch_size\": 256,\n", " \"attn_dropout\": 0,\n", - " \"train_steps\": 572300,\n", + " \"train_steps\": 600000,\n", " \"eval_steps\": 0,\n", " \"predict_steps\": 1,\n", " \"res_dropout\": 0,\n", - " \"eval_batch_size\": 64,\n", + " \"eval_batch_size\": 4,\n", " \"predict_batch_size\": 1,\n", " \"iterations\": 100,\n", " \"n_embd\": 2048,\n", - " \"datasets\": [[\"edit_this\", 21, \"documents_random\", 1.0]],\n", + " \"datasets\": [[\"dataset_name\", null, null, null]],\n", " \"model\": \"GPT\",\n", - " \"model_path\": \"gs://your_bucket_name/models/GPT3_XL\",\n", + " \"model_path\": \"gs://your_bucket/GPT3_XL\",\n", " \"n_ctx\": 2048,\n", " \"n_layer\": 24,\n", " \"scale_by_depth\": true,\n", " \"scale_by_in\": false,\n", - " \"attention_types\" : [[[\"global\"],24]],\n", + " \"attention_types\" : [[[\"global\", \"local\"],12]],\n", " \"mesh_shape\": \"x:4,y:2\",\n", " \"layout\": \"intermediate_expanded:x,heads:x,vocab:x,memory_length:y,embd:y\",\n", " \"activation_function\": \"gelu\",\n", @@ -339,6 +375,199 @@ "execution_count": null, "outputs": [] }, + { + "cell_type": "markdown", + "metadata": { + "id": "koKQHA5ikCvD" + }, + "source": [ + "## Finetune a Pretrained Model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0QZv4_pnkk26" + }, + "source": [ + "If you want to finetune from a pretrained model, EleutherAI has pretrained two models for release. One with [1.3B parameters](https://the-eye.eu/eleuther_staging/gptneo-release/GPT3_XL/), and another with [2.7B](https://the-eye.eu/eleuther_staging/gptneo-release/GPT3_2-7B/). \n", + "\n", + "Select an option below to download the weights locally. You will then need to upload them to your cloud bucket in order to finetune from them. If the download command isn't working, try the commented out code to download from a different source.\n", + "\n", + "The 2-7B model likely won't fit into the colab TPUs memory, and you may have to get some larger pods to finetune from it.\n", + "\n", + "Sampling from it, however, works just fine.\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "lgTG1ammqGB0", + "cellView": "form" + }, + "source": [ + "# @title Download pretrained model weights:\n", + "pretrained_model = 'GPT3_XL' #@param [\"GPT3_XL\", \"GPT3_2-7B\"]\n", + "\n", + "# !wget -m -np -c -U \"eye02\" -w 2 -R \"index.html*\" \"https://the-eye.eu/eleuther_staging/gptneo-release/$pretrained_model/\"\n", + "# path_to_local_weights = /content/GPTNeo/the-eye.eu/eleuther_staging/gptneo-release/$pretrained_model\n", + "\n", + "URL = f\"http://eaidata.bmk.sh/data/gptneo-release/{pretrained_model}/\"\n", + "FOLDER_NAME = \"GPT3_XL\"\n", + "!curl $URL | grep -i \"\" | sed -n 's/.*href=\"\\([^\"]*\\).*/\\1/p' | sed \"s|^|$URL|\" | xargs -n 1 -P 4 wget -P $pretrained_model\n", + "path_to_local_weights = $pretrained_model" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "2lzEJDnZ2yNm" + }, + "source": [ + "# upload to your bucket\n", + "bucket_base = \"gs://\" + path_to_cloud_bucket.replace('gs://', '').split('/')[0]\n", + "!gsutil -m cp -r $path_to_local_weights $bucket_base" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bnqkKBTOn0ox" + }, + "source": [ + "If everything has worked successfully you should now see your model listed in your bucket below." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "80t9MMionm2h" + }, + "source": [ + "!gsutil ls $bucket_base" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QDKL8fCSoApL" + }, + "source": [ + "Now we want to make a few modifications to the model config in order to \n", + "\n", + "1. Get training working on colab, and\n", + "2. Finetune on your chosen dataset. (\n", + "\n", + "You can change parameters below. \n", + "\n", + "* `path_to_model` should point to the model weights location in your cloud bucket, and will default to `$bucket_base/${pretrained_model}` if nothing is entered.\n", + "\n", + "* `batch_size` is your train batch size - if you're encountering memory errors, try lowering this.\n", + "\n", + "* `dataset_name` is the name of your dataset, if nothing is entered, this should default to the dataset you selected in the `Prepare Data` section.\n", + "\n", + "* `mesh_shape` specifies the way the model will be divided up across the TPU cores. We suggest leaving this alone unless you know what you're doing.\n", + "\n", + "* `train_steps` specifies how many steps you want the model to finetune for. We set this to 1000 for demonstrative purposes but you may need to increase this a little depending on your goals.\n", + "\n", + "* `steps_per_checkpoint` specifies how often you want to save model weights during training.\n", + "\n" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "Laf0slBMDCUj", + "cellView": "form" + }, + "source": [ + "# @title Modify config for colab. \n", + " \n", + "import json\n", + "from pprint import pprint\n", + "\n", + "path_to_model = \"\" #@param {type:\"string\"}\n", + "batch_size = 16 #@param {type:\"integer\"}\n", + "dset = \"\" #@param {type:\"string\"}\n", + "mesh_shape = \"x:4,y:2\" #@param {type:\"string\"}\n", + "train_steps = 1000 #@param {type:\"integer\"}\n", + "steps_per_checkpoint = 500 #@param {type:\"integer\"}\n", + "start_step = 400000 if pretrained_model == \"GPT3_2-7B\" else 362000\n", + "\n", + "if path_to_model == \"\":\n", + " path_to_model = f'{bucket_base.strip(\"/\")}/{pretrained_model}'\n", + "print(f'MODEL PATH: {path_to_model}\\n')\n", + "\n", + "if dset == \"\":\n", + " dset = dataset_name\n", + "\n", + "def pad_to_multiple_of(n, mult):\n", + " \"\"\"\n", + " pads n to a multiple of mult\n", + " \"\"\"\n", + " extra = n % mult\n", + " if extra > 0:\n", + " n = n + mult - extra\n", + " return n\n", + "\n", + "with open(f'/content/GPTNeo/the-eye.eu/eleuther_staging/gptneo-release/{pretrained_model}/config.json', 'r') as f:\n", + " data = json.load(f)\n", + " pprint(data)\n", + " mods = {\n", + " \"mesh_shape\": mesh_shape,\n", + " \"layout\": \"intermediate_expanded:x,heads:x,memory_length:y,embd:y\",\n", + " \"model_path\": path_to_model,\n", + " \"datasets\": [[dataset_name, None, None, None]],\n", + " \"train_steps\": start_step + train_steps,\n", + " \"eval_steps\": 0,\n", + " \"train_batch_size\": batch_size\n", + " }\n", + " data.update(mods)\n", + " print('\\n--->\\n')\n", + " pprint(data)\n", + " with open(f'configs/{pretrained_model}.json', 'w') as outfile:\n", + " json.dump(data, outfile, indent=2)" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EBnfILxbuVM7" + }, + "source": [ + "If everything's set up correctly, you can now run the main.py function to start training!" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "mKzkoJsFvoR3" + }, + "source": [ + "!gsutil ls gs://test-bucket-neo" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "metadata": { + "id": "0YlaHzyXuMaj" + }, + "source": [ + "!python3 main.py --model $pretrained_model --steps_per_checkpoint $steps_per_checkpoint --tpu colab" + ], + "execution_count": null, + "outputs": [] + }, { "cell_type": "markdown", "metadata": { @@ -381,10 +610,10 @@ "id": "sf_5E4fHFQIh" }, "source": [ - "!python3 main.py --model colab_XL --steps_per_checkpoint 500 --tpu colab --predict --prompt example_prompt.txt" + "!python3 main.py --model $pretrained_model --steps_per_checkpoint 500 --tpu colab --predict --prompt example_prompt.txt" ], "execution_count": null, "outputs": [] } ] -} +} \ No newline at end of file diff --git a/README.md b/README.md index 1fc5cc43..368b9574 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ πŸŽ‰ 1T or bust my dudes πŸŽ‰ -An implementation of model & data parallel [GPT2](https://openai.com/blog/better-language-models/) & [GPT3](https://arxiv.org/abs/2005.14165)-like models, with the ability to scale up to full GPT3 sizes (and possibly more!), using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library. +An implementation of model & data parallel [GPT2](https://openai.com/blog/better-language-models/) & [GPT3](https://arxiv.org/abs/2005.14165) -like models, with the ability to scale up to full GPT3 sizes (and possibly more!), using the [mesh-tensorflow](https://github.com/tensorflow/mesh) library. Training and inference supported on both TPUs and GPUs. @@ -14,8 +14,19 @@ Also included are alternative model architectures and linear attention implement * [Axial Positional embedding](https://arxiv.org/abs/1912.12180) * Masked Language Modelling -Pretrained models will be released as they are finished training. +# Pretrained Models +**21/03/2021:** + +We're proud to release two pretrained GPT-Neo models trained on The Pile, the weights and configs can be freely downloaded from [the-eye.eu](https://the-eye.eu/eleuther_staging/gptneo-release/). + +1.3B: https://the-eye.eu/eleuther_staging/gptneo-release/GPT3_XL/ + +2.7B: https://the-eye.eu/eleuther_staging/gptneo-release/GPT3_2-7B/ + +For more information on how to get these set up, see the colab notebook, or read through the rest of the readme. + +This repository will be (mostly) archived as we move focus to our GPU training repo, [GPT-Neox](https://github.com/EleutherAI/gpt-neox/) # Setup ```bash @@ -44,10 +55,6 @@ You can also choose to train GPTNeo locally on your GPUs. To do so, you can omit Google colab provides tpu-v8s for free, which should be enough to finetune our models up to GPT3XL (1.5B parameter) sizes. Click the above button to run through our example colab notebook. -# Downloading Pretrained Models - -TODO - # Generating Text Once you have a trained model, or you've downloaded one of our pre-trained models (coming soon), generating text is as simple as running the main.py script with the `--predict` flag on. You can pass a path to your prompt txt file with the `--prompt` flag, like so: diff --git a/configs/dataset_configs/SmallPileAblation_small_CC100_newinput.json b/configs/dataset_configs/SmallPileAblation_small_CC100_newinput.json deleted file mode 100644 index c6ff4d4e..00000000 --- a/configs/dataset_configs/SmallPileAblation_small_CC100_newinput.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/SmallPileAblation_small_CC100_newinput/*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val_0_186482.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/SmallPileAblation_small_CC_raw_newinput.json b/configs/dataset_configs/SmallPileAblation_small_CC_raw_newinput.json deleted file mode 100644 index e55479bb..00000000 --- a/configs/dataset_configs/SmallPileAblation_small_CC_raw_newinput.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/SmallPileAblation_small_CC_raw_newinput/*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val_0_186482.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/SmallPileAblation_small_Pile_newinput.json b/configs/dataset_configs/SmallPileAblation_small_Pile_newinput.json deleted file mode 100644 index 41dbdf84..00000000 --- a/configs/dataset_configs/SmallPileAblation_small_Pile_newinput.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/SmallPileAblation_small_Pile_newinput/*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val_0_186482.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/SmallPileAblation_small_owt_newinput.json b/configs/dataset_configs/SmallPileAblation_small_owt_newinput.json deleted file mode 100644 index 7f113670..00000000 --- a/configs/dataset_configs/SmallPileAblation_small_owt_newinput.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/SmallPileAblation_small_owt_newinput/*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val_0_186482.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/cc100en_40G_ablation.json b/configs/dataset_configs/cc100en_40G_ablation.json deleted file mode 100644 index 47e9657f..00000000 --- a/configs/dataset_configs/cc100en_40G_ablation.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/ablation_subset_cc100en/cc100en_*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/cc_raw_40G_ablation.json b/configs/dataset_configs/cc_raw_40G_ablation.json deleted file mode 100644 index 8da41a89..00000000 --- a/configs/dataset_configs/cc_raw_40G_ablation.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/cc_raw_ablation/cc_raw_*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/openwebtext-documents.json b/configs/dataset_configs/openwebtext-documents.json deleted file mode 100644 index fa5f7af6..00000000 --- a/configs/dataset_configs/openwebtext-documents.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/openwebtext-documents/openwebtext_*.tfrecords", - "eval_path": "", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/owt_40G_ablation.json b/configs/dataset_configs/owt_40G_ablation.json deleted file mode 100644 index 056a8e7b..00000000 --- a/configs/dataset_configs/owt_40G_ablation.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/owt_ablation_simple/owt_ablation_small_*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/pile_40G_ablation.json b/configs/dataset_configs/pile_40G_ablation.json deleted file mode 100644 index fdf6682a..00000000 --- a/configs/dataset_configs/pile_40G_ablation.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/pile_small_ablation/pile_small_ablation_*.tfrecords", - "eval_path": "gs://neo-datasets/pile_val.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/dataset_configs/test.json b/configs/dataset_configs/test.json deleted file mode 100644 index 7c530176..00000000 --- a/configs/dataset_configs/test.json +++ /dev/null @@ -1,9 +0,0 @@ -{ - "n_vocab": 50257, - "path": "gs://neo-datasets/new_input_test/*.tfrecords", - "eval_path": "gs://neo-datasets/new_input_test/*.tfrecords", - "tokenizer_is_pretrained": true, - "tokenizer_path": "gpt2", - "eos_id": 50256, - "padding_id": 50257 -} diff --git a/configs/gpt3_XL_128_Pile.json b/configs/gpt3_XL_128_Pile.json deleted file mode 100644 index 8f3a1ae0..00000000 --- a/configs/gpt3_XL_128_Pile.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 512, - "attn_dropout": 0, - "train_steps": 286150, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 500, - "n_embd": 2048, - "datasets": [["pile", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/GPT3_XL_Pile", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:64,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} diff --git a/configs/gpt3_XL_256.json b/configs/gpt3_XL_256.json deleted file mode 100644 index c1261c07..00000000 --- a/configs/gpt3_XL_256.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 512, - "attn_dropout": 0, - "train_steps": 286150, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 128, - "predict_batch_size": 1, - "iterations": 2500, - "n_embd": 2048, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/GPT3_XL", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:128,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048 -} - diff --git a/configs/gpt3_XL_256_SmallPileAblation_CC100en.json b/configs/gpt3_XL_256_SmallPileAblation_CC100en.json deleted file mode 100644 index 2bd50039..00000000 --- a/configs/gpt3_XL_256_SmallPileAblation_CC100en.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 25000, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 1000, - "n_embd": 2048, - "datasets": [["SmallPileAblation_small_CC100_newinput", null, null, null]], - "model_path": "gs://neo-models/GPT3_SmallPileAblation_small_CC100_newinput", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:128,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16", - "eval_tasks": ["lambada", "wikitext103"] -} diff --git a/configs/gpt3_XL_256_SmallPileAblation_CC_raw.json b/configs/gpt3_XL_256_SmallPileAblation_CC_raw.json deleted file mode 100644 index 4a882141..00000000 --- a/configs/gpt3_XL_256_SmallPileAblation_CC_raw.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 25000, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 1000, - "n_embd": 2048, - "datasets": [["SmallPileAblation_small_CC_raw_newinput", null, null, null]], - "model_path": "gs://neo-models/GPT3_SmallPileAblation_small_CC_raw_newinput", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:128,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16", - "eval_tasks": ["lambada", "wikitext103"] -} diff --git a/configs/gpt3_XL_256_SmallPileAblation_Pile.json b/configs/gpt3_XL_256_SmallPileAblation_Pile.json deleted file mode 100644 index 9769ae54..00000000 --- a/configs/gpt3_XL_256_SmallPileAblation_Pile.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 25000, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 1000, - "n_embd": 2048, - "datasets": [["SmallPileAblation_small_Pile_newinput", null, null, null]], - "model_path": "gs://neo-models/GPT3_SmallPileAblation_small_Pile_newinput", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:128,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16", - "eval_tasks": ["lambada", "wikitext103"] -} diff --git a/configs/gpt3_XL_256_SmallPileAblation_owt.json b/configs/gpt3_XL_256_SmallPileAblation_owt.json deleted file mode 100644 index d2df74dd..00000000 --- a/configs/gpt3_XL_256_SmallPileAblation_owt.json +++ /dev/null @@ -1,38 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 25000, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 1000, - "n_embd": 2048, - "datasets": [["SmallPileAblation_small_owt_newinput", null, null, null]], - "model_path": "gs://neo-models/GPT3_SmallPileAblation_small_owt_newinput", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:128,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16", - "eval_tasks": ["lambada", "wikitext103"] -} diff --git a/configs/gpt3_XL_64_Pile.json b/configs/gpt3_XL_64_Pile.json deleted file mode 100644 index 71057481..00000000 --- a/configs/gpt3_XL_64_Pile.json +++ /dev/null @@ -1,37 +0,0 @@ -{ - "n_head": 32, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0002, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.1, - "train_batch_size": 512, - "attn_dropout": 0, - "train_steps": 286150, - "eval_steps": 10, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 512, - "predict_batch_size": 1, - "iterations": 500, - "n_embd": 2048, - "datasets": [["pile", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/GPT3_XL_Pile", - "n_ctx": 2048, - "n_layer": 24, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],24]], - "mesh_shape": "x:32,y:2", - "layout": "batch:x,memory_length:y,embd:y", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} diff --git a/configs/gpt3_local.json b/configs/gpt3_local.json deleted file mode 100644 index 588ba9ac..00000000 --- a/configs/gpt3_local.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "n_vocab": 32768, - "n_head": 8, - "embed_dropout": 0, - "lr": 0.0003, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.00, - "train_batch_size": 10, - "attn_dropout": 0, - "train_steps": 572300, - "eval_steps": 0, - "predict_steps": 0, - "res_dropout": 0, - "eval_batch_size": 32, - "iterations": 500, - "n_embd": 512, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], - "model_path": "./tfmodels", - "n_ctx": 256, - "n_layer": 6, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global"],6]], - "mesh_shape": "x:1", - "layout": "batch:x", - "activation_function": "gelu", - "recompute_grad": false, - "no_weight_tie": true, - "num_microbatches": 4, - "gradient_clipping": 0.5 -} \ No newline at end of file diff --git a/configs/gpt3_scaling_128_pile.json b/configs/gpt3_scaling_128_pile.json deleted file mode 100644 index 98394f54..00000000 --- a/configs/gpt3_scaling_128_pile.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "n_head": 64, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0001, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "ada_epsilon1": 1e-30, - "ada_epsilon2": 1e-3, - "opt_name": "adam", - "weight_decay": 0.10, - "train_batch_size": 128, - "attn_dropout": 0, - "train_steps": 143075, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 128, - "predict_batch_size": 1, - "iterations": 10, - "n_embd": 16384, - "datasets": [["SmallPileAblation_small_Pile_newinput", null, null, null]], - "model_path": "gs://neo-models/gpt3_scaling_128_pile", - "n_ctx": 2048, - "n_layer": 30, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global", "local"],15]], - "mesh_shape": "x:1,y:128", - "layout": "batch:x,embd:y,memory_length:y ", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} - diff --git a/configs/gpt3_scaling_256_pile.json b/configs/gpt3_scaling_256_pile.json deleted file mode 100644 index e56e985b..00000000 --- a/configs/gpt3_scaling_256_pile.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "n_head": 64, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0001, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "ada_epsilon1": 1e-30, - "ada_epsilon2": 1e-3, - "opt_name": "adam", - "weight_decay": 0.10, - "train_batch_size": 128, - "attn_dropout": 0, - "train_steps": 143075, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 128, - "predict_batch_size": 1, - "iterations": 10, - "n_embd": 16384, - "datasets": [["SmallPileAblation_small_Pile_newinput", null, null, null]], - "model_path": "gs://neo-models/gpt3_scaling_256_pile", - "n_ctx": 2048, - "n_layer": 60, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global", "local"],30]], - "mesh_shape": "x:1,y:256", - "layout": "batch:x,embd:y,memory_length:y ", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} - diff --git a/configs/gpt3_scaling_32_pile.json b/configs/gpt3_scaling_32_pile.json deleted file mode 100644 index e27dea2f..00000000 --- a/configs/gpt3_scaling_32_pile.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "n_head": 64, - "n_vocab": 50257, - "embed_dropout": 0, - "lr": 0.0001, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "ada_epsilon1": 1e-30, - "ada_epsilon2": 1e-3, - "opt_name": "adam", - "weight_decay": 0.10, - "train_batch_size": 128, - "attn_dropout": 0, - "train_steps": 143075, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 128, - "predict_batch_size": 1, - "iterations": 10, - "n_embd": 8192, - "datasets": [["SmallPileAblation_small_Pile_newinput", null, null, null]], - "model_path": "gs://neo-models/gpt3_scaling_32_pile", - "n_ctx": 2048, - "n_layer": 30, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types" : [[["global", "local"],15]], - "mesh_shape": "x:1,y:32", - "layout": "batch:x,embd:y,memory_length:y ", - "activation_function": "gelu", - "recompute_grad": true, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 2048, - "precision": "bfloat16" -} - diff --git a/configs/gpt3_small_local_256.json b/configs/gpt3_small_local_256.json deleted file mode 100644 index b6356409..00000000 --- a/configs/gpt3_small_local_256.json +++ /dev/null @@ -1,36 +0,0 @@ -{ - "n_head": 12, - "n_vocab": 50304, - "embed_dropout": 0, - "lr": 0.0006, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.10, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 572300, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 64, - "predict_batch_size": 1, - "iterations": 2500, - "n_embd": 768, - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/GPT3_SMALL_LOCAL", - "n_ctx": 2048, - "n_layer": 12, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types": [[["local", "global"],6]], - "mesh_shape": "x:64,y:4", - "layout": "batch:x,heads:y,vocab:y", - "activation_function": "gelu", - "recompute_grad": false, - "gradient_clipping": 1.0 -} - diff --git a/configs/gpt3_small_moe_8.json b/configs/gpt3_small_moe_8.json deleted file mode 100644 index a3755d01..00000000 --- a/configs/gpt3_small_moe_8.json +++ /dev/null @@ -1,40 +0,0 @@ -{ - "n_head": 12, - "n_vocab": 50304, - "embed_dropout": 0, - "lr": 0.0006, - "lr_decay": "cosine", - "warmup_steps": 3000, - "beta1": 0.9, - "beta2": 0.95, - "epsilon": 1e-8, - "opt_name": "adam", - "weight_decay": 0.10, - "train_batch_size": 256, - "attn_dropout": 0, - "train_steps": 572300, - "eval_steps": 0, - "predict_steps": 1, - "res_dropout": 0, - "eval_batch_size": 8, - "predict_batch_size": 1, - "iterations": 1000, - "n_embd": 768, - "moe_params": { - "moe_dropout_rate": 0.0 - }, - "moe_layers": [2,4,6,8,10,12], - "datasets": [["openwebtext-documents", 25, "documents_random", 1.0]], - "model_path": "gs://neo-models/gpt3_small_moe", - "n_ctx": 2048, - "n_layer": 12, - "scale_by_depth": true, - "scale_by_in": false, - "attention_types": [[["global"],12]], - "mesh_shape": "x:4,y:2", - "layout": "batch:x,heads:y,vocab:y,intermediate_expanded:y,experts:y", - "activation_function": "gelu", - "recompute_grad": false, - "gradient_clipping": 1.0, - "tokens_per_mb_per_replica": 4096 -} \ No newline at end of file diff --git a/data/create_tfrecords.py b/data/create_tfrecords.py index ea3be2c0..dd709bb8 100644 --- a/data/create_tfrecords.py +++ b/data/create_tfrecords.py @@ -15,18 +15,18 @@ logging.getLogger("transformers").setLevel(logging.ERROR) parser = argparse.ArgumentParser() -parser.add_argument("--mode", type=str, choices=["chunks", "documents"], default="documents", - help="Whether a tfrecord example is a constant sized chunk or a full document") -parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are treated as \ - archives, all others as raw text.") +parser.add_argument("--input_dir", type=str, help="Path to where your files are located. Files ending in .zst are " + "treated as archives, all others as raw text.") parser.add_argument("--files_per", type=int, default=100000, help="Text files per tfrecord") parser.add_argument("--name", type=str, default="openwebtext", help="Name of output files will be name_i.tfrecords where i is the number of the file") parser.add_argument("--output_dir", type=str, default="./tfrecords", help="Where to put tfrecords") -parser.add_argument("--encoder_path", type=str, help="Path to encoder files, or leave unspecified to use GPT2 tokenizer") +parser.add_argument("--encoder_path", type=str, + help="Path to encoder files, or leave unspecified to use GPT2 tokenizer") parser.add_argument("--minimum_size", type=int, default=100, help="Minimum size a document has to be to be included") parser.add_argument("--ftfy", action="store_false", help="normalize with ftfy") -parser.add_argument("--separator", nargs="+", type=int, default=[50256], help="separator to place between files in chunk mode") +parser.add_argument("--separator", nargs="+", type=int, default=[50256], + help="separator to place between files in chunk mode") parser.add_argument("--chunk_size", type=int, default=2048, help="How big a chunk should be in chunk mode. " "Should equal your model's context size") parser.add_argument("--write_dataset_config", action="store_true", help="Write the dataset config file on completion") @@ -46,6 +46,7 @@ def _int64_feature(value): """ return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) + def write_to_file(writer, data): """ writes data to tfrecord file @@ -56,25 +57,29 @@ def write_to_file(writer, data): tf_example = tf.train.Example(features=tf.train.Features(feature=feature)) writer.write(tf_example.SerializeToString()) + def get_tokenizer(args): if args.encoder_path is None: return GPT2TokenizerFast.from_pretrained('gpt2') else: return Tokenizer.from_file(args.encoder_path) + def split_list(l, n): # splits list/string into n size chunks - return [l[i:i+n] for i in range(0, len(l), n)] + return [l[i:i + n] for i in range(0, len(l), n)] + def archive_to_tokens(f, encoder, args): # Generator that yields the contents of the files in an archive # if data_to_prepend is not None, prepend data_to_prepend + a EOS separator to the encoded data reader = Reader(f) for doc in reader.stream_data(threaded=False): - if args.ftfy: # fix text with ftfy if specified + if args.ftfy: # fix text with ftfy if specified doc = ftfy.fix_text(doc, normalization='NFKC') - doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token - yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks + doc = encoder.encode(doc) + args.separator # read document from lmd and append separator token + yield split_list(doc, args.chunk_size) # split into n_ctx + 1 size chunks + def write_files(files, files_per, output_dir, out_name, start_no, write_remainder=False, process_no=None): # writes a list of files to .tfrecords @@ -82,17 +87,17 @@ def write_files(files, files_per, output_dir, out_name, start_no, write_remainde return chunks = split_list(files, files_per) - if len(chunks[-1]) != files_per and not write_remainder: # pop the last file if it's length != files per + if len(chunks[-1]) != files_per and not write_remainder: # pop the last file if it's length != files per remainder = chunks.pop(-1) else: - remainder = None # assuming files = remainder from an old chunk here + remainder = None # assuming files = remainder from an old chunk here files_per = len(chunks[-1]) for files in chunks: fp = f"{output_dir}/{out_name}_{start_no}" if process_no is not None: fp += f"_{process_no}" - fp += f"_{files_per}" # add number of files in tfrecord to end of fp + fp += f"_{files_per}" # add number of files in tfrecord to end of fp fp += ".tfrecords" with tf.io.TFRecordWriter(fp) as writer: for f in files: @@ -100,12 +105,14 @@ def write_files(files, files_per, output_dir, out_name, start_no, write_remainde start_no += 1 return start_no, remainder + def get_files(input_dir, filetypes=None): # gets all files of in input_dir if filetypes == None: filetypes = ["jsonl.zst", ".txt", ".xz", ".tar.gz"] files = [list(Path(input_dir).glob(f"*{ft}")) for ft in filetypes] - return [str(item) for sublist in files for item in sublist] # flatten list of list -> list and stringify Paths + return [str(item) for sublist in files for item in sublist] # flatten list of list -> list and stringify Paths + def read_checkpoint(checkpoint_path, resume_from_checkpoint=True): # init checkpointing @@ -118,15 +125,18 @@ def read_checkpoint(checkpoint_path, resume_from_checkpoint=True): pass return 0, 0 -def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False, resume_from_checkpoint=False, display_pbar=False): + +def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_checkpoints=False, + resume_from_checkpoint=False, display_pbar=False): # iterates through files in input_dir, splitting into chunks and saving a tfrecords file every chunks. files, args, process_no = params - enc = get_tokenizer(args) # get tokenizer + enc = get_tokenizer(args) # get tokenizer # init metadata discarded_files = 0 files_processed = 0 - pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ", disable= not display_pbar) + pbar = tqdm(desc=f"Writing TFRecord Files to {args.output_dir}. Parsed 0 input files. files_written ", + disable=not display_pbar) checkpoint_path = f"{args.output_dir}/checkpoint.txt" resume_files_processed, tfrecord_count = read_checkpoint(checkpoint_path, resume_from_checkpoint) @@ -137,7 +147,7 @@ def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_c for tokenized_files in archive_to_tokens(f, enc, args): files_processed += 1 if files_processed < resume_files_processed: - continue # resume from checkpoint + continue # resume from checkpoint # if the last chunk < chunk size, but > minimum_size, take it and append it to the beginning of the next file n_tokens = len(tokenized_files[-1]) @@ -155,28 +165,35 @@ def create_tfrecords(params, write_remainder=True, write_every_n_files=1, save_c # add tokenized files > chunk size to main array tokenized_files_array.extend(tokenized_files) - if len(tokenized_files_array) >= args.files_per * write_every_n_files: # write every n files - _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name, start_no = tfrecord_count, process_no=process_no) - pbar.update(_tfrecord_count - tfrecord_count) # update progress bar - pbar.set_description(f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ") + if len(tokenized_files_array) >= args.files_per * write_every_n_files: # write every n files + _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per, + output_dir=args.output_dir, out_name=args.name, + start_no=tfrecord_count, process_no=process_no) + pbar.update(_tfrecord_count - tfrecord_count) # update progress bar + pbar.set_description( + f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ") tfrecord_count = _tfrecord_count - tokenized_files_array = remainder if remainder is not None else [] # add remaining files to next chunk + tokenized_files_array = remainder if remainder is not None else [] # add remaining files to next chunk with open(checkpoint_path, "w") as checkpoint_file: checkpoint_file.write(f"{files_processed}, {tfrecord_count}") - if len(tokenized_files_array) >= args.files_per: # also write at end - _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name, start_no=tfrecord_count, process_no=process_no) + if len(tokenized_files_array) >= args.files_per: # also write at end + _tfrecord_count, remainder = write_files(tokenized_files_array, files_per=args.files_per, + output_dir=args.output_dir, out_name=args.name, + start_no=tfrecord_count, process_no=process_no) pbar.update(_tfrecord_count - tfrecord_count) - pbar.set_description(f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ") + pbar.set_description( + f"Writing TFRecord Files to {args.output_dir}. Parsed {files_processed} input files. files_written ") tfrecord_count = _tfrecord_count with open(checkpoint_path, "w") as checkpoint_file: checkpoint_file.write(f"{files_processed}, {tfrecord_count}") else: - remainder = tokenized_files_array # add remaining to remainder + remainder = tokenized_files_array # add remaining to remainder if write_remainder: # write out the remaining files even if there's less than files_per - write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name, start_no=tfrecord_count, write_remainder=True) + write_files(remainder, files_per=args.files_per, output_dir=args.output_dir, out_name=args.name, + start_no=tfrecord_count, write_remainder=True) successful_files = files_processed - discarded_files return {"discarded": discarded_files, "processed": files_processed, "successful": successful_files} @@ -190,14 +207,14 @@ def create_tfrecords_mp(files, args): for results in pbar: pbar.update() for k, v in results.items(): - meta[k] += v # update metadata + meta[k] += v # update metadata return meta if __name__ == "__main__": - os.makedirs(args.output_dir, exist_ok=True) # make output dir if it doesn't exist + os.makedirs(args.output_dir, exist_ok=True) # make output dir if it doesn't exist files = get_files(args.input_dir) - args.chunk_size += 1 # we shift the data by 1 to the right for targets, so increment the chunk size here + args.chunk_size += 1 # we shift the data by 1 to the right for targets, so increment the chunk size here if args.processes == 0: args.processes = cpu_count() diff --git a/inputs.py b/inputs.py index 6191eec3..bef699be 100644 --- a/inputs.py +++ b/inputs.py @@ -8,6 +8,183 @@ from itertools import cycle from utils import natural_sort + +### IN USE ### + +def _get_number_of_documents(filename): + # extracts number of files from a filename formatted "_.tfrecords." + # if no pattern is matched, returns None + match = re.search("_(\d{1,}).tfrecords$", filename) + return int(match.group(1)) if match is not None else match + + +def _get_number_of_documents_by_iteration(filename): + # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename + # this could be very slow. + logging.warning( + "inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length") + count = 0 + for item in tf.io.tf_record_iterator(filename): + count += 1 + return count + + +def _get_skip_index(all_files, n_batches): + prev_cumsum = 0 + cumsum = 0 + global_n_documents = None + for count, f in cycle(enumerate(all_files)): + prev_cumsum = cumsum + if _get_number_of_documents(f) is not None: + cumsum += _get_number_of_documents(f) + elif global_n_documents is None: + global_n_documents = _get_number_of_documents_by_iteration(f) + cumsum += global_n_documents + else: + cumsum += global_n_documents + if cumsum == n_batches: + remainder = 0 + skip_idx = count + 1 + elif cumsum > n_batches: + remainder = n_batches - prev_cumsum + skip_idx = count + break + return skip_idx, remainder + + +def _parse_function(example_proto): + features = { + "text": tf.VarLenFeature(tf.int64) + } + parsed_features = tf.parse_single_example(example_proto, features) + return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0]) + + +def autoregressive_sample_text(params, x): + vals1 = x[:params["n_ctx"]] + vals2 = x[1:params["n_ctx"] + 1] + + vals1 = tf.reshape(vals1, [params["n_ctx"]]) + vals2 = tf.reshape(vals2, [params["n_ctx"]]) + vals1 = tf.cast(vals1, dtype=tf.int32) + vals2 = tf.cast(vals2, dtype=tf.int32) + return vals1, vals2 + + +def sequential_input(params, global_step=None, eval=False): + """ + Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either: + + - has the number of documents for each tfrecord file encoded in the title in the format + _.tfrecords. + + OR + + - has a fixed number of documents per tfrecord file. + + If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read. + If this isn't the case, it may result in errors, or some samples being missed. + + This means we can calculate the number of samples we've seen so far using the global step, + and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient. + + If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model + performance, as it results in less repeated data. + """ + if not eval: + assert global_step is not None + logging.warning( + "Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.") + batch_size = params['eval_batch_size' if eval else 'train_batch_size'] + + filenames = [] + for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params + path_key = 'path' if not eval else 'eval_path' + path = dataset_config[path_key] + filenames.extend( + tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs + + filenames = natural_sort(filenames) + shuffle_filenames = params.get("shuffle_input_filenames", True) + if shuffle_filenames: + seed = params.get('seed', 1) # shuffle deterministically + random.seed(seed) + random.shuffle(filenames) + + dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity + + if not eval: + # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files + skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params[ + "train_batch_size"]) # TODO: fix for > 1 epoch + dataset = dataset.skip(skip_idx) # skip to skip idx + + # read tfrecord examples and skip remainder + dataset = dataset.apply(tf.data.TFRecordDataset) + dataset = dataset.skip(remainder) + else: + # shuffle filenames if in eval mode + dataset = dataset.shuffle(len(filenames)) + dataset = dataset.apply(tf.data.TFRecordDataset) + + # parse the tokenized data from the tfrecord files and shuffle + dataset = dataset.map(_parse_function, num_parallel_calls=1) + dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1) + + # batch data and repeat to infinity + dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2) + return dataset.repeat() + + +def pred_input(params, logger, enc=None, + path_to_prompt=""): + unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ + "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ + "researchers was the fact that the unicorns spoke perfect English." + + text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read() + tokens = encode(enc, text) + + if len(tokens) > params["n_ctx"]: + logger.info("The length of your input prompt is longer than the model's context length - truncating input.") + tokens = tokens[len(tokens) - params["n_ctx"]:] + if len(tokens) < params["n_ctx"]: + tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"]) + t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]]) + dataset = tf.data.Dataset.from_tensors(t) + + def _dummy_labels(x): + return x, x + + dataset = dataset.map(_dummy_labels) + return dataset + + +def handle_pred_output(predictions, logger, enc, params, out_name="test"): + with tf.gfile.Open(f"{out_name}.txt", "w") as f: + for i, p in enumerate(predictions): + p = p["outputs"] + + # remove eos + padding ids from output + idx = np.argmax(p == params['eos_id']) + if idx > 0: + p = p[:idx] + idx = np.argmax(p == params['padding_id']) + if idx > 0: + p = p[:idx] + + text = enc.decode(p) + f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n") + f.write(text) + f.write("\n" + "=" * 80 + "\n") + + logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n") + logger.info(text) + logger.info("\n" + "=" * 80 + "\n") + + +### DEPRECATED ### + def generic_text(params, eval=False, sample_text_fn=None, **kwargs): logging.warning("DEPRECATION WARNING: generic_text will be phased out in future versions.") i = 0 if not eval else 1 @@ -18,7 +195,8 @@ def generic_text(params, eval=False, sample_text_fn=None, **kwargs): for dataset in params["datasets"]: dataset_id, stitch, datatype, weight = dataset - assert dataset_id in params['dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration' + assert dataset_id in params[ + 'dataset_configs'], f'Unknown dataset id {dataset_id} given. Please make sure your dataset ids contain that configuration' dataset_config = params['dataset_configs'][dataset_id] path_key = 'path' if not eval else 'eval_path' @@ -27,10 +205,10 @@ def generic_text(params, eval=False, sample_text_fn=None, **kwargs): datasets.append(text_dataset( tf.io.gfile.glob(path), params, - stitch = stitch, - datatype = datatype, - batch = False, - sample_text_fn = sample_text_fn + stitch=stitch, + datatype=datatype, + batch=False, + sample_text_fn=sample_text_fn )) weights.append(weight) @@ -42,9 +220,10 @@ def generic_text(params, eval=False, sample_text_fn=None, **kwargs): dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2) return dataset + def text_dataset(files, params, stitch, datatype, batch=True, sample_text_fn=None): seed = params.get('seed', None) - deterministic = seed is not None + deterministic = seed is not None num_parallel_calls = 1 if deterministic else tf.data.experimental.AUTOTUNE dataset = tf.data.Dataset.from_tensor_slices(files) @@ -95,12 +274,12 @@ def _get_x(i): # Hack-y way to stitch together multiple texts dataset = dataset.shuffle(1000 * stitch, seed=seed).batch(stitch, drop_remainder=True).map(_stitch_text, - num_parallel_calls=num_parallel_calls) + num_parallel_calls=num_parallel_calls) # Sample 1024(+1) tokens from the stitched together text is_random_documents = datatype == "documents_random" if sample_text_fn is not None: - _sample_text = partial(sample_text_fn, random_documents = is_random_documents) + _sample_text = partial(sample_text_fn, random_documents=is_random_documents) else: _sample_text = autoregressive_sample_text_random_documents if is_random_documents else autoregressive_sample_text _sample_text = partial(_sample_text, params) @@ -114,15 +293,6 @@ def _get_x(i): return dataset -def autoregressive_sample_text(params, x): - vals1 = x[:params["n_ctx"]] - vals2 = x[1:params["n_ctx"] + 1] - - vals1 = tf.reshape(vals1, [params["n_ctx"]]) - vals2 = tf.reshape(vals2, [params["n_ctx"]]) - vals1 = tf.cast(vals1, dtype=tf.int32) - vals2 = tf.cast(vals2, dtype=tf.int32) - return vals1, vals2 def autoregressive_sample_text_random_documents(params, x): seed = params.get('seed', None) @@ -131,7 +301,8 @@ def autoregressive_sample_text_random_documents(params, x): r1 = tf.range(r, r + params["n_ctx"]) r2 = tf.range(r + 1, (r + 1) + params["n_ctx"]) r1 = tf.reshape(r1, [params["n_ctx"]]) # Somehow, this makes the compiler happy - r2 = tf.reshape(r2, [params["n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input + r2 = tf.reshape(r2, [params[ + "n_ctx"]]) # TPUs want constant sized input, and these reshapes makes it recognize the shape of the input vals1 = tf.gather(x, r1) vals2 = tf.gather(x, r2) @@ -141,7 +312,8 @@ def autoregressive_sample_text_random_documents(params, x): vals2 = tf.cast(vals2, dtype=tf.int32) return vals1, vals2 -def mlm_sample_text(params, x, random_documents = False): + +def mlm_sample_text(params, x, random_documents=False): seed = params.get('seed', None) ctx_len = params["n_ctx"] assert 'mlm_mask_id' in params, 'the key `mlm_mask_id` must be set on your config to do masked language model training, specifying the id of the reserved mask token' @@ -185,12 +357,14 @@ def mlm_sample_text(params, x, random_documents = False): mask_mask &= can_mask # generate mask for actually replacing the tokens, for allowing a small number of tokens to stay the same - replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), 1 - same_token_prob) + replace_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), + 1 - same_token_prob) # randomly replace some tokens with random tokens before masking if random_token_prob > 0: - random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), random_token_prob) - random_tokens = tf.random.uniform(shape, minval = 1, maxval = num_tokens, dtype = tf.dtypes.int32, seed = seed) + random_token_mask = tf.less(tf.random.uniform(shape, minval=0., maxval=1., dtype=tf.float32, seed=seed), + random_token_prob) + random_tokens = tf.random.uniform(shape, minval=1, maxval=num_tokens, dtype=tf.dtypes.int32, seed=seed) # make sure random tokens do not include illegal token ids specified by `mlm_mask_ignore_ids` random_can_mask = tf.not_equal(random_tokens, 0) @@ -208,158 +382,3 @@ def mlm_sample_text(params, x, random_documents = False): masked_features, labels = map(lambda t: tf.reshape(t, [ctx_len]), (masked_features, labels)) return masked_features, labels - - -def pred_input(params, logger, enc=None, - path_to_prompt=""): - - unicorns = "In a shocking finding, scientists discovered a herd of unicorns living in a remote, " \ - "previously unexplored valley, in the Andes Mountains. Even more surprising to the " \ - "researchers was the fact that the unicorns spoke perfect English." - - text = unicorns if path_to_prompt == "" else open(path_to_prompt, "r").read() - tokens = encode(enc, text) - - if len(tokens) > params["n_ctx"]: - logger.info("The length of your input prompt is longer than the model's context length - truncating input.") - tokens = tokens[len(tokens) - params["n_ctx"]:] - if len(tokens) < params["n_ctx"]: - tokens = tf.pad(tokens, [[0, params["n_ctx"] - len(tokens)]], constant_values=params["padding_id"]) - t = tf.broadcast_to(tokens, [params["batch_size"], params["n_ctx"]]) - dataset = tf.data.Dataset.from_tensors(t) - - def _dummy_labels(x): - return x, x - - dataset = dataset.map(_dummy_labels) - return dataset - - -def handle_pred_output(predictions, logger, enc, params, out_name="test"): - with tf.gfile.Open(f"{out_name}.txt", "w") as f: - for i, p in enumerate(predictions): - p = p["outputs"] - - # remove eos + padding ids from output - idx = np.argmax(p == params['eos_id']) - if idx > 0: - p = p[:idx] - idx = np.argmax(p == params['padding_id']) - if idx > 0: - p = p[:idx] - - text = enc.decode(p) - f.write("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n") - f.write(text) - f.write("\n" + "=" * 80 + "\n") - - logger.info("=" * 40 + " SAMPLE " + str(i) + " " + "=" * 40 + "\n") - logger.info(text) - logger.info("\n" + "=" * 80 + "\n") - -def _get_number_of_documents(filename): - # extracts number of files from a filename formatted "_.tfrecords." - # if no pattern is matched, returns None - match = re.search("_(\d{1,}).tfrecords$", filename) - return int(match.group(1)) if match is not None else match - -def _get_number_of_documents_by_iteration(filename): - # extracts number of files from a tfrecord document in the event it doesn't have metadata in the filename - # this could be very slow. - logging.warning("inputs/sequential_input() found no metadata found in filename - iterating through first tfrecord to find global length") - count = 0 - for item in tf.io.tf_record_iterator(filename): - count += 1 - return count - -def _get_skip_index(all_files, n_batches): - prev_cumsum = 0 - cumsum = 0 - global_n_documents = None - for count, f in cycle(enumerate(all_files)): - prev_cumsum = cumsum - if _get_number_of_documents(f) is not None: - cumsum += _get_number_of_documents(f) - elif global_n_documents is None: - global_n_documents = _get_number_of_documents_by_iteration(f) - cumsum += global_n_documents - else: - cumsum += global_n_documents - if cumsum == n_batches: - remainder = 0 - skip_idx = count + 1 - elif cumsum > n_batches: - remainder = n_batches - prev_cumsum - skip_idx = count - break - return skip_idx, remainder - -def _parse_function(example_proto): - features = { - "text": tf.VarLenFeature(tf.int64) - } - parsed_features = tf.parse_single_example(example_proto, features) - return tf.sparse.to_dense(parsed_features["text"], parsed_features["text"].dense_shape[0]) - -def sequential_input(params, global_step=None, eval=False): - """ - Input fn that reads tfrecords encoded with a fixed chunk size (== n_ctx + 1), and that either: - - - has the number of documents for each tfrecord file encoded in the title in the format - _.tfrecords. - - OR - - - has a fixed number of documents per tfrecord file. - - If the glob pattern above isn't matched, we assume that each document has the same number of samples as the first tfrecord read. - If this isn't the case, it may result in errors, or some samples being missed. - - This means we can calculate the number of samples we've seen so far using the global step, - and can use dataset.skip() to iterate through the list of filenames, as opposed to the whole dataset, which is incredibly inefficient. - - If training is starting and stopping often, as with TPU pre-emption, reading the whole dataset sequentially appears to improve model - performance, as it results in less repeated data. - """ - if not eval: - assert global_step is not None - logging.warning("Changing batch size with sequential_input() will result in some data being skipped or repeated. Please ensure your batch size stays constant throughout training.") - batch_size = params['eval_batch_size' if eval else 'train_batch_size'] - - filenames = [] - for dataset_config in params['dataset_configs'].values(): # iterate through each dataset and read params - path_key = 'path' if not eval else 'eval_path' - path = dataset_config[path_key] - filenames.extend(tf.io.gfile.glob(path)) # then glob all files that fit the pattern specified in dataset_configs - - filenames = natural_sort(filenames) - shuffle_filenames = params.get("shuffle_input_filenames", True) - if shuffle_filenames: - seed = params.get('seed', 1) # shuffle deterministically - random.seed(seed) - random.shuffle(filenames) - - dataset = tf.data.Dataset.from_tensor_slices(filenames).repeat() # repeat filenames to infinity - - if not eval: - # skip forward first in the filenames list, then skip the remaining amount in the parsed tfrecords files - skip_idx, remainder = _get_skip_index(filenames, n_batches=global_step * params["train_batch_size"]) # TODO: fix for > 1 epoch - dataset = dataset.skip(skip_idx) # skip to skip idx - - # read tfrecord examples and skip remainder - dataset = dataset.apply(tf.data.TFRecordDataset) - dataset = dataset.skip(remainder) - else: - # shuffle filenames if in eval mode - dataset = dataset.shuffle(len(filenames)) - dataset = dataset.apply(tf.data.TFRecordDataset) - - - # parse the tokenized data from the tfrecord files and shuffle - dataset = dataset.map(_parse_function, num_parallel_calls=1) - dataset = dataset.map(partial(autoregressive_sample_text, params), num_parallel_calls=1) - - # batch data and repeat to infinity - dataset = dataset.batch(batch_size, drop_remainder=True).prefetch(params["iterations"] * 2) - return dataset.repeat() - diff --git a/model_fns.py b/model_fns.py index 541816d2..ed2c3ad7 100644 --- a/model_fns.py +++ b/model_fns.py @@ -1,16 +1,17 @@ import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf from tensorflow.python.tpu import tpu_estimator -import mesh_tensorflow.auto_mtf import mesh_tensorflow.transformer as mtf_transformer from optimizers import get_optimizer -from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params, get_batch_size, auto_layout, auto_layout_and_mesh_shape) +from utils import (create_host_call, get_graph_info, remove_batch_from_layout, simd_mesh_setup, add_mode_to_params, + get_batch_size, auto_layout, auto_layout_and_mesh_shape) from models.utils import biasmask_attn_weights from tensorflow.python.ops import resources from sample import sample_autoregressive from models.gpt2 import gpt2 import math + def model_fn(features, labels, mode, params): # Get global step global_step = tf.train.get_global_step() @@ -18,10 +19,8 @@ def model_fn(features, labels, mode, params): # Construct mtf graph + mesh from params graph = mtf.Graph() mesh_shape = mtf.convert_to_shape(params["mesh_shape"]) - if mode == tf.estimator.ModeKeys.PREDICT: - params["layout"] = remove_batch_from_layout(params["layout"]) layout_rules = mtf.convert_to_layout_rules(params["layout"]) - + # Mesh setup if params["use_tpu"]: var_placer, mesh_impl = simd_mesh_setup(params, mesh_shape, layout_rules) @@ -34,7 +33,8 @@ def model_fn(features, labels, mode, params): # Trainable variable precision # Store to checkpoints in master type, train in slice type, compute in activation type if params["precision"] == "bfloat16": - variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, activation_dtype=tf.bfloat16) + variable_dtype = mtf.VariableDType(master_dtype=tf.bfloat16, slice_dtype=tf.float32, + activation_dtype=tf.bfloat16) else: variable_dtype = mtf.VariableDType(master_dtype=tf.float32, slice_dtype=tf.float32, activation_dtype=tf.float32) @@ -98,12 +98,14 @@ def model_fn(features, labels, mode, params): if not export: mtf_samples = sample_autoregressive( inputs, other_features=other_features, params=params, variable_dtype=variable_dtype, - remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], sampling_use_entmax=params['sampling_use_entmax']) + remove_partial_sequences=params["remove_partial_sequences"], stop_at_token=params["eos_id"], + sampling_use_entmax=params['sampling_use_entmax']) else: with mtf.utils.outside_all_rewrites(): with tf.variable_scope('gpt2'): - mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None) + mtf_samples, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, + variable_dtype=variable_dtype, context=None) mtf_samples = mtf.anonymize(mtf_samples) inputs = mtf.anonymize(inputs) @@ -113,7 +115,7 @@ def model_fn(features, labels, mode, params): predictions = { "inputs": inputs, "outputs": outputs} - + def scaffold_fn(): return tf.train.Scaffold( local_init_op=tf.group( @@ -122,7 +124,7 @@ def scaffold_fn(): name="mtf_local_init_op"), ready_op=tf.concat( [tf.report_uninitialized_variables(), - resources.report_uninitialized_resources()], + resources.report_uninitialized_resources()], axis=0, name="mtf_ready_op")) @@ -139,22 +141,24 @@ def scaffold_fn(): # Gets number of microbatches per batch for serialized training # if param tokens_per_mb_per_replica = None, this defaults to 1 and no microbatching is performed num_microbatches = int(mtf_transformer.utils.serialize_num_microbatches(batch_dim=batch_dim, - sequence_length=sequence_length_dict, - mesh_shape=mesh_shape, - layout_rules=layout_rules, - tokens_per_microbatch_per_replica=params["tokens_per_mb_per_replica"])) + sequence_length=sequence_length_dict, + mesh_shape=mesh_shape, + layout_rules=layout_rules, + tokens_per_microbatch_per_replica= + params["tokens_per_mb_per_replica"])) else: num_microbatches = 1 params["num_microbatches"] = num_microbatches # Add num microbatches to params - + if num_microbatches > 1: - + # For serialize_training_step we need to modify the model to output results in a dict def serialized_fn(mtf_features): if params["model"] == "GPT": with tf.variable_scope('gpt2'): - logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype) + logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, + variable_dtype=variable_dtype) return {"logits": logits, "loss": loss, "loss_batch": loss_batch} else: raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]") @@ -169,7 +173,8 @@ def serialized_fn(mtf_features): if params["model"] == "GPT": with mtf.utils.outside_all_rewrites(): with tf.variable_scope('gpt2'): - logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, variable_dtype=variable_dtype, context=None) + logits, loss, loss_batch = gpt2.model(mtf_features, other_features, params, mesh, + variable_dtype=variable_dtype, context=None) else: raise Exception(f"'{params['model']}' is not a valid model - please select from [GPT]") @@ -184,7 +189,8 @@ def serialized_fn(mtf_features): if params["num_microbatches"] > 1: # If we are splitting the batch into microbatches, var grads are created in the serialize_training_step fn # So we pass them in here - _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype, inp_var_grads=var_grads) + _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype, + inp_var_grads=var_grads) else: # Otherwise, they are created in the get_optimizer fn, so we leave inp_var_grads blank _, update_ops, var_grads = get_optimizer(mesh, loss, params, variable_dtype=variable_dtype) @@ -297,4 +303,3 @@ def _lambada_metric_fn(labels, tf_max_logits, tf_loss_batch): evaluation_hooks=[restore_hook], loss=tf_loss, eval_metrics=eval_metrics) - diff --git a/models/activations.py b/models/activations.py new file mode 100644 index 00000000..9b199f87 --- /dev/null +++ b/models/activations.py @@ -0,0 +1,128 @@ +import mesh_tensorflow as mtf +import tensorflow.compat.v1 as tf +import random + + +def get_activation_fn(params): + activation_fn = params.get("activation_fn", "gelu") + + def _arcsinh(x): + return mtf.log(x + mtf.sqrt(1 + x ** 2)) + + def _var(x, init): + return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2 ** 32):x}", [], + initializer=tf.constant_initializer(init), dtype=x.dtype) + + def _pos_var(x, val): + return mtf.softplus(_var(x, 0)) + val + + if activation_fn == "gelu": # https://arxiv.org/abs/1606.08415 + return mtf.gelu + elif activation_fn == "relu": + return mtf.relu + elif activation_fn == "sigmoid": + return mtf.sigmoid + elif activation_fn == "tanh": + return mtf.tanh + elif activation_fn == "selu": # https://arxiv.org/abs/1706.02515 + return mtf.selu + elif activation_fn == "elu": # https://arxiv.org/abs/1511.07289 + return mtf.elu + elif activation_fn == "lrelu001": + return lambda x: mtf.leaky_relu(x, alpha=0.01) + elif activation_fn == "lrelu020": + return lambda x: mtf.leaky_relu(x, alpha=0.20) + + elif activation_fn == "abs": + return mtf.abs + elif activation_fn == "id": + return lambda x: x + elif activation_fn == "sin": + return mtf.sin + elif activation_fn == "cos": + return mtf.cos + elif activation_fn == "sign": + return mtf.sign + elif activation_fn == "triangle_relax": + return lambda x: mtf.sin(x) - mtf.sin(3 * x) / 9 + mtf.sin(5 * x) / 25 - mtf.sin(7 * x) / 49 + elif activation_fn == "square_relax": + return lambda x: mtf.cos(x) - mtf.cos(3 * x) / 3 + mtf.cos(5 * x) / 5 - mtf.cos(7 * x) / 7 + elif activation_fn == "spike": + return lambda x: 1 / (1 + x ** 2) + elif activation_fn == "spike2": + return lambda x: mtf.exp(-x ** 2) + + elif activation_fn == "tanhshrink": + return lambda x: x - tanh(x) + elif activation_fn == "softsign": + return lambda x: x / (mtf.abs(x) + 1) + elif activation_fn == "softmax": + return lambda x: mtf.softmax(x, x.shape[-1]) + elif activation_fn == "logsoftmax": + return lambda x: mtf.log_softmax(x, x.shape[-1]) + elif activation_fn == "bipolarsigmoid": + return lambda x: mtf.sigmoid(x) * 2 - 1 + elif activation_fn == "rrelu": # https://arxiv.org/abs/1505.00853 + def _rrelu_fn(x): + negative_scale = random.random() + return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale) + + return _rrelu_fn + elif activation_fn == "elish": # https://arxiv.org/abs/1808.00783v1 + def _elish_fn(x): + cond = mtf.cast(mtf.greater(x, 0), x.dtype) + exp = mtf.exp(x) + return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1) + + return _elish_fn + + elif activation_fn == "silu": # https://arxiv.org/abs/1710.05941 + return mtf.swish + + elif activation_fn == "arcsinh": + return _arcsinh + + + # parametric + elif activation_fn == "aria": # https://arxiv.org/abs/1805.08878 + return lambda x: x * (_var(x, 0) + _var(x, 1) / ( + _pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))) + elif activation_fn == "prelu": # https://arxiv.org/abs/1502.01852 + return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)) + elif activation_fn == "parcsinh": + return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)) + elif activation_fn == "psoftplus": + return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0) + elif activation_fn == "proottanh": + return lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x) + + # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671 + elif activation_fn == "maxsig": + return lambda x: mtf.maximum(x, mtf.sigmoid(x)) + elif activation_fn == "cosid": + return lambda x: mtf.cos(x) - x + elif activation_fn == "minsin": + return lambda x: mtf.minimum(x, mtf.sin(x)) + elif activation_fn == "maxtanh": + return lambda x: mtf.maximum(x, mtf.tanh(x)) + + elif activation_fn == "softplus": + return mtf.softplus + elif activation_fn == "mish": # https://arxiv.org/abs/1908.08681 + return lambda x: x * mtf.tanh(mtf.softplus(x)) + elif activation_fn == "tanhexp": # https://arxiv.org/abs/2003.09855 + return lambda x: x * mtf.tanh(mtf.exp(x)) + elif activation_fn == "lisht": # https://arxiv.org/abs/1901.05894 + return lambda x: x * mtf.tanh(x) + elif activation_fn == "seagull": # https://arxiv.org/abs/2011.11713 + return lambda x: mtf.log(1 + x ** 2) + elif activation_fn == "snake": # https://arxiv.org/abs/2006.08195 + return lambda x: x + mtf.sin(x) ** 2 + + elif activation_fn == "roottanh": # made up + return lambda x: (x ** 2 + 1) ** (1 / 3) * mtf.tanh(x) + elif activation_fn == "softplusmone": # made up + return lambda x: mtf.softplus(x) - 1 + + else: + raise ValueError('unknown activation function "activation_fn" in config') diff --git a/models/gpt2/gpt2.py b/models/gpt2/gpt2.py index 3324a399..59779f46 100644 --- a/models/gpt2/gpt2.py +++ b/models/gpt2/gpt2.py @@ -1,411 +1,13 @@ """GPT-like model in Mesh-Tensorflow""" -import mesh_tensorflow as mtf import tensorflow.compat.v1 as tf -import math import mesh_tensorflow.transformer as mtf_transformer -import random -from models.utils import parse_inputs, entmax_cross_entropy_with_logits - -# -------------------------------------------------------------------------------- -# LAYERS: - -sentinel = object() - - -def exists(x): - return x is not None - - -def identity(x, *args, **kwargs): - return x - - -def is_incremental_inference(context): - return exists(context) and context.mode == "incremental" - - -def norm(x, axis, epsilon=1e-8): - x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u") - s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s") - return x * mtf.rsqrt(s + epsilon) - - -def rezero(x, scope, dtype): - with tf.variable_scope(scope): - g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype) - return x * g - - -def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None): - if axis is sentinel: - axis = x.shape[-1] - - with tf.variable_scope(scope): - g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - - x = norm(x, axis, epsilon) - x = x * g - return x - - -def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None): - """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" - if axis is sentinel: - axis = x.shape[-1] - - with tf.variable_scope(scope): - n_state = x.shape[-1] - - g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - - x = norm(x, axis, epsilon) - x = x * g + b - return x - - -def linear_attention(q, k, v): - batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) - q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") - k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in") - - dim_in = k.shape[-1] - - q = mtf.softmax(q, dim_in) - k = mtf.softmax(k, seq_dim) - context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out]) - attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) - return attn - - -def causal_linear_attention(q, k, v, epsilon=1e-6): - batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) - q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") - k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in") - - dim_in = k.shape[-1] - - q = mtf.softmax(q, dim_in) - k = mtf.exp(k) - - cumulative_k = mtf.cumsum(k, seq_dim) - context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out]) - cumulative_context = mtf.cumsum(context, seq_dim) - - cumulative_context /= (cumulative_k + epsilon) - attn = mtf.einsum([q, cumulative_context], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) - return attn - - -def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False): - # nf = number of features - if params["scale_by_depth"] and scale: - # Scale by sqrt(num_layers), only happens at the final projection before a res block output - w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"])) - if params["scale_by_in"]: # Scale by sqrt(num_input_features) - w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size)) # Dimension is a namedtuple of (name, size) - # Not in the variable_scope because mtf already has a variable_scope in it - with tf.variable_scope("conv1d_main"): - c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True, - kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev), - variable_dtype=variable_dtype, - ) - return c - - -def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh): - """memory / key values from all attention paper""" - - dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv) - emb_dim = k.shape[-1] - mem_std = 1 / math.sqrt(emb_dim.size) - - mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]), - initializer=tf.random_normal_initializer(stddev=mem_std), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype, - ) - mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]), - initializer=tf.random_normal_initializer(stddev=mem_std), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - - mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]), - (mem_k, mem_v)) - mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"), - (mem_k, mem_v)) - - k = mtf.concat([mem_k, k], "sequence") - v = mtf.concat([mem_v, v], "sequence") - return k, v - - -def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None): - # x :: [batch, seq, n_embd] - x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh - - # n_state is the same as config["n_embd"], which is also the same as dim_embd. - assert n_state.size % params["n_head"] == 0 - - dim_heads = mtf.Dimension("heads", params["n_head"]) - - num_mem_kv = params.get("num_mem_kv", 0) - use_num_mem_kv = num_mem_kv > 0 - - with tf.variable_scope(scope): - # Compute attention inputs - dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"]) - mtfparams = mtf.transformer.attention.attention_params_simple( - x.mesh, - io_dim=dim_embd, - kv_dim=dim_kv, - heads_dim=dim_heads, - variable_dtype=variable_dtype - ) - q = mtfparams.compute_q(x) - k = mtfparams.compute_k(x) - v = mtfparams.compute_v(x) - - if is_incremental_inference(context): - one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype) - inv_one_hot = 1.0 - one_hot - old_k, old_v = context.get_states(2) - k = old_k * inv_one_hot + k * one_hot - v = old_v * inv_one_hot + v * one_hot - - if exists(context): - context.record_new_states([k, v]) - - with tf.variable_scope("attention"): - if attention_type == "local": - # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. - radius = params.get("local_attention_radius", 256) - - if is_incremental_inference(context): - q *= one_hot - - a = mtf_transformer.attention.local_attention_1d( - q, k, v, - length_dim=k.shape[1], - key_dim=dim_kv, - value_dim=dim_kv, - radius=radius, - length_dim_num_splits=1, - fully_autoregressive=params["causal"], - attention_kwargs={}, - ) - - if is_incremental_inference(context): - a = mtf.gather(a, context.position - 1, dim_seq) - - elif attention_type == "global": - - # TODO: pass in fake context - # Broadcast mask bias across batch and heads - if exists(bias): - if not is_incremental_inference(context): - broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]]) - else: - # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position - bias = mtf.gather(bias, context.position - 1, dim_seq) - broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]]) - - # memory key / values, from all-attention paper - if use_num_mem_kv: - k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh) - - k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim) - v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim) - - attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0 - - a = mtf_transformer.attention.attention( - q, k, v, - memory_length_dim=memory_length_dim, - key_dim=dim_kv, - value_dim=dim_kv, - bias=broadcasted_bias, - dropout_rate=attn_dropout_rate - ) - - elif attention_type == "linear": - linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention - a = linear_attn_fn(q, k, v) - - else: - raise NotImplementedError("Unknown attention type {}!".format(attention_type)) - - with tf.variable_scope("compute_output"): - a = mtfparams.compute_output(a, x_shape) - - with tf.variable_scope("compute_output_bias"): - b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - a += b - - if params["mode"] == "train" and params["res_dropout"] > 0: - a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout") - return a - - -def get_activation_fn(params): - activation_fn = params.get("activation_fn", "gelu") - - - def _arcsinh(x): - return mtf.log(x + mtf.sqrt(1 + x ** 2)) - def _var(x, init): - return mtf.get_variable(x.mesh, f"activation-{random.randint(0, 2**32):x}", [], initializer=tf.constant_initializer(init), dtype=x.dtype) - def _pos_var(x, val): - return mtf.softplus(_var(x, 0)) + val - - if activation_fn == "gelu": # https://arxiv.org/abs/1606.08415 - return mtf.gelu - elif activation_fn == "relu": - return mtf.relu - elif activation_fn == "sigmoid": - return mtf.sigmoid - elif activation_fn == "tanh": - return mtf.tanh - elif activation_fn == "selu": # https://arxiv.org/abs/1706.02515 - return mtf.selu - elif activation_fn == "elu": # https://arxiv.org/abs/1511.07289 - return mtf.elu - elif activation_fn == "lrelu001": - return lambda x: mtf.leaky_relu(x, alpha=0.01) - elif activation_fn == "lrelu020": - return lambda x: mtf.leaky_relu(x, alpha=0.20) - - elif activation_fn == "abs": - return mtf.abs - elif activation_fn == "id": - return lambda x: x - elif activation_fn == "sin": - return mtf.sin - elif activation_fn == "cos": - return mtf.cos - elif activation_fn == "sign": - return mtf.sign - elif activation_fn == "triangle_relax": - return lambda x: mtf.sin(x)-mtf.sin(3*x)/9+mtf.sin(5*x)/25-mtf.sin(7*x)/49 - elif activation_fn == "square_relax": - return lambda x: mtf.cos(x)-mtf.cos(3*x)/3+mtf.cos(5*x)/5-mtf.cos(7*x)/7 - elif activation_fn == "spike": - return lambda x: 1/(1+x**2) - elif activation_fn == "spike2": - return lambda x: mtf.exp(-x**2) - - elif activation_fn == "tanhshrink": - return lambda x: x - tanh(x) - elif activation_fn == "softsign": - return lambda x: x / (mtf.abs(x) + 1) - elif activation_fn == "softmax": - return lambda x: mtf.softmax(x, x.shape[-1]) - elif activation_fn == "logsoftmax": - return lambda x: mtf.log_softmax(x, x.shape[-1]) - elif activation_fn == "bipolarsigmoid": - return lambda x: mtf.sigmoid(x) * 2 - 1 - elif activation_fn == "rrelu": # https://arxiv.org/abs/1505.00853 - def _rrelu_fn(x): - negative_scale = random.random() - return (negative_scale * mtf.abs(x) + x) / (1 + negative_scale) - return _rrelu_fn - elif activation_fn == "elish": # https://arxiv.org/abs/1808.00783v1 - def _elish_fn(x): - cond = mtf.cast(mtf.greater(x, 0), x.dtype) - exp = mtf.exp(x) - return cond * x / (1 + exp) + (1 - cond) * (exp - 1) / (1 / exp + 1) - return _elish_fn - - elif activation_fn == "silu": # https://arxiv.org/abs/1710.05941 - return mtf.swish - - elif activation_fn == "arcsinh": - return _arcsinh - - - # parametric - elif activation_fn == "aria": # https://arxiv.org/abs/1805.08878 - return lambda x: x * (_var(x, 0) + _var(x, 1) / (_pos_var(x, 0) + _var(x, 1) * mtf.exp(_var(x, -1) * x) ** (1 / _pos_var(x, 1)))) - elif activation_fn == "prelu": # https://arxiv.org/abs/1502.01852 - return lambda x: mtf.leaky_relu(x, alpha=_var(x, 0.2)) - elif activation_fn == "parcsinh": - return lambda x: _var(x, 1) * _arcsinh(x * _pos_var(x, 1)) - elif activation_fn == "psoftplus": - return lambda x: _var(x, 1) * mtf.softplus(x * _var(x, 1)) + _var(x, 0) - elif activation_fn == "proottanh": - return lambda x: (x ** _pos_var(x, 2) + _pos_var(x, 1)) ** (1 / _pos_var(x, 3)) * mtf.tanh(x) - - # https://arxiv.org/abs/1710.05941, https://arxiv.org/abs/1901.02671 - elif activation_fn == "maxsig": - return lambda x: mtf.maximum(x, mtf.sigmoid(x)) - elif activation_fn == "cosid": - return lambda x: mtf.cos(x) - x - elif activation_fn == "minsin": - return lambda x: mtf.minimum(x, mtf.sin(x)) - elif activation_fn == "maxtanh": - return lambda x: mtf.maximum(x, mtf.tanh(x)) - - elif activation_fn == "softplus": - return mtf.softplus - elif activation_fn == "mish": # https://arxiv.org/abs/1908.08681 - return lambda x: x * mtf.tanh(mtf.softplus(x)) - elif activation_fn == "tanhexp": # https://arxiv.org/abs/2003.09855 - return lambda x: x * mtf.tanh(mtf.exp(x)) - elif activation_fn == "lisht": # https://arxiv.org/abs/1901.05894 - return lambda x: x * mtf.tanh(x) - elif activation_fn == "seagull": # https://arxiv.org/abs/2011.11713 - return lambda x: mtf.log(1 + x ** 2) - elif activation_fn == "snake": # https://arxiv.org/abs/2006.08195 - return lambda x: x + mtf.sin(x) ** 2 - - elif activation_fn == "roottanh": # made up - return lambda x: (x ** 2 + 1) ** (1/3) * mtf.tanh(x) - elif activation_fn == "softplusmone": # made up - return lambda x: mtf.softplus(x) - 1 - - else: - raise ValueError('unknown activation function "activation_fn" in config') - -def mlp(x, scope, n_state, *, variable_dtype, params): - activation_fn = get_activation_fn(params) - with tf.variable_scope(scope): - nx = x.shape[-1] - h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)) - h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True) - if params["mode"] == "train" and params["res_dropout"] > 0: - h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout") - return h2 - - -def mlp_glu(x, scope, n_state, *, variable_dtype, params): - activation_fn = get_activation_fn(params) - with tf.variable_scope(scope): - nx = x.shape[-1] - h = linear(x, "c_fc", n_state, params=params) - - h, gate = mtf.split(h, h.shape[-1], 2) - h *= activation_fn(gate) +from models.utils import parse_inputs, entmax_cross_entropy_with_logits +from models.layers import * - h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True) - if params["mode"] == "train" and params["res_dropout"] > 0: - h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout") - return h2 +# -------------------------------------------------------------------------------- +# TRANSFORMER BLOCK: def block(params, scope, layer_num, bias, sequence_dim, memory_length_dim, variable_dtype, context=None): use_mlp_glu = params["mlp_glu"] == True @@ -428,7 +30,7 @@ def fn(x): pre_residual_fn = rezero if use_rezero else identity attention_type = params["attention_types"][layer_num] - + if macaron_attention: mult = 0.5 mlp_fn = mlp_glu if use_mlp_glu else mlp @@ -436,7 +38,7 @@ def fn(x): # Define intermediate layer of mlp - to split dim_intermediate_expanded = mtf.Dimension("intermediate_expanded", intermediate_size) m = mlp_fn(x, "mlp_macaron", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params) - + x = x + (m * mult) else: mult = 1 @@ -469,7 +71,8 @@ def fn(x): train=moe_train, mesh_shape=params["mesh_shape"], layout=params["layout"], - activation=params.get("moe_activation", "relu"), + activation=params.get("moe_activation", + "relu"), variable_dtype=variable_dtype, num_microbatches=params["num_microbatches"]) m = mtf.dropout(m, rate=params["res_dropout"], name="moe_dropout") @@ -484,41 +87,14 @@ def fn(x): m = mlp_fn(res_x, "mlp", dim_intermediate_expanded, variable_dtype=variable_dtype, params=params) aux_loss = mtf.zeros(x.mesh, mtf.Shape([]), dtype=variable_dtype.slice_dtype) - x = x + pre_residual_fn((m*mult), "norm_rezero_2", variable_dtype) + x = x + pre_residual_fn((m * mult), "norm_rezero_2", variable_dtype) return x, aux_loss return fn -def axial_positional_emb(embd_dim, mesh, params, variable_dtype): - # Use axial position encoding - axial_dim_1, axial_dim_2 = params["axial_pos_emb"] - - axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2) - dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))] - - axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]), - initializer=tf.random_normal_initializer(stddev=0.01), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - - axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]), - initializer=tf.random_normal_initializer(stddev=0.01), - master_dtype=variable_dtype.master_dtype, - slice_dtype=variable_dtype.slice_dtype, - activation_dtype=variable_dtype.activation_dtype) - - axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]), - (axial_wpe_1, axial_wpe_2)) - wpe = (axial_wpe_1 + axial_wpe_2) / 2 - - wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) - - return wpe - # -------------------------------------------------------------------------------- -# MODEL: +# GPT2 MODEL: def model(mtf_features, other_features, params, mesh, variable_dtype, context=None): """A GPT style model implemented in mesh tensorflow.""" @@ -597,7 +173,7 @@ def model(mtf_features, other_features, params, mesh, variable_dtype, context=No if params["mode"] in ["train", "eval"]: labels = mtf_features["labels"] - z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy + z_loss = params.get("z_loss", 1e-4) # an auxiliary loss used to stabilize mtf xentropy # Go to full precision for the logits logits = mtf.cast(logits, tf.float32) diff --git a/models/layers.py b/models/layers.py new file mode 100644 index 00000000..916b025c --- /dev/null +++ b/models/layers.py @@ -0,0 +1,317 @@ +import mesh_tensorflow as mtf +import tensorflow.compat.v1 as tf +import math +import mesh_tensorflow.transformer as mtf_transformer + +from models.activations import get_activation_fn + + +# -------------------------------------------------------------------------------- +# LAYERS: + +sentinel = object() + + +def exists(x): + return x is not None + + +def identity(x, *args, **kwargs): + return x + + +def is_incremental_inference(context): + return exists(context) and context.mode == "incremental" + + +def norm(x, axis, epsilon=1e-8): + x -= mtf.reduce_mean(x, reduced_dim=axis, name="norm_reduce_mean_u") + s = mtf.reduce_mean(mtf.square(x), reduced_dim=axis, name="norm_reduce_mean_s") + return x * mtf.rsqrt(s + epsilon) + + +def rezero(x, scope, dtype): + with tf.variable_scope(scope): + g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(0), dtype=dtype) + return x * g + + +def scale_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None): + if axis is sentinel: + axis = x.shape[-1] + + with tf.variable_scope(scope): + g = mtf.get_variable(x.mesh, "g", [], initializer=tf.constant_initializer(1), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + + x = norm(x, axis, epsilon) + x = x * g + return x + + +def layer_norm(x, scope, *, variable_dtype, axis=sentinel, epsilon=1e-5, params=None): + """Normalize to mean = 0, std = 1, then do a diagonal affine transform.""" + if axis is sentinel: + axis = x.shape[-1] + + with tf.variable_scope(scope): + n_state = x.shape[-1] + + g = mtf.get_variable(x.mesh, "g", [n_state], initializer=tf.constant_initializer(1), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + b = mtf.get_variable(x.mesh, "b", [n_state], initializer=tf.constant_initializer(0), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + + x = norm(x, axis, epsilon) + x = x * g + b + return x + + +def linear_attention(q, k, v): + batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) + q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") + k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in") + + dim_in = k.shape[-1] + + q = mtf.softmax(q, dim_in) + k = mtf.softmax(k, seq_dim) + + context = mtf.einsum([k, v], output_shape=[batch_dim, head_dim, dim_in, dim_out]) + attn = mtf.einsum([q, context], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) + return attn + + +def causal_linear_attention(q, k, v, epsilon=1e-6): + batch_dim, seq_dim, head_dim, dim_out = (v.shape[0], v.shape[1], v.shape[2], v.shape[3]) + q = mtf.rename_dimension(q, "features_per_head", "features_per_head_in") + k = mtf.rename_dimension(k, "features_per_head", "features_per_head_in") + + dim_in = k.shape[-1] + + q = mtf.softmax(q, dim_in) + k = mtf.exp(k) + + cumulative_k = mtf.cumsum(k, seq_dim) + context = mtf.einsum([k, v], output_shape=[batch_dim, seq_dim, head_dim, dim_in, dim_out]) + cumulative_context = mtf.cumsum(context, seq_dim) + + cumulative_context /= (cumulative_k + epsilon) + attn = mtf.einsum([q, cumulative_context], output_shape=[batch_dim, seq_dim, head_dim, dim_out]) + return attn + + +def linear(x, scope, nf, *, w_init_stdev=0.02, variable_dtype, params=None, scale=False): + # nf = number of features + if params["scale_by_depth"] and scale: + # Scale by sqrt(num_layers), only happens at the final projection before a res block output + w_init_stdev = w_init_stdev * (1. / math.sqrt(params["n_layer"])) + if params["scale_by_in"]: # Scale by sqrt(num_input_features) + w_init_stdev = w_init_stdev * (1. / math.sqrt(x.shape[-1].size)) # Dimension is a namedtuple of (name, size) + # Not in the variable_scope because mtf already has a variable_scope in it + with tf.variable_scope("conv1d_main"): + c = mtf.layers.dense(x, new_dims=[nf], reduced_dims=[x.shape[-1]], name=scope, use_bias=True, + kernel_initializer=tf.random_normal_initializer(stddev=w_init_stdev), + variable_dtype=variable_dtype, + ) + return c + + +def memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh): + """memory / key values from all attention paper""" + + dim_mem_kv = mtf.Dimension("mem_kv_sequence", num_mem_kv) + emb_dim = k.shape[-1] + mem_std = 1 / math.sqrt(emb_dim.size) + + mem_k = mtf.get_variable(mesh, "mem_k", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]), + initializer=tf.random_normal_initializer(stddev=mem_std), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype, + ) + mem_v = mtf.get_variable(mesh, "mem_v", mtf.Shape([dim_mem_kv, dim_heads, emb_dim]), + initializer=tf.random_normal_initializer(stddev=mem_std), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + + mem_k, mem_v = map(lambda t: mtf.broadcast(t, [dim_batch, dim_mem_kv, dim_heads, emb_dim]), + (mem_k, mem_v)) + mem_k, mem_v = map(lambda t: mtf.rename_dimension(t, "mem_kv_sequence", "sequence"), + (mem_k, mem_v)) + + k = mtf.concat([mem_k, k], "sequence") + v = mtf.concat([mem_v, v], "sequence") + return k, v + + +def attn(x, scope, n_state, *, attention_type, params, bias, dim_seq, memory_length_dim, variable_dtype, context=None): + # x :: [batch, seq, n_embd] + x_shape, dim_batch, *_, dim_embd, mesh = x.shape, *x.shape, x.mesh + + # n_state is the same as config["n_embd"], which is also the same as dim_embd. + assert n_state.size % params["n_head"] == 0 + + dim_heads = mtf.Dimension("heads", params["n_head"]) + + num_mem_kv = params.get("num_mem_kv", 0) + use_num_mem_kv = num_mem_kv > 0 + + with tf.variable_scope(scope): + # Compute attention inputs + dim_kv = mtf.Dimension("features_per_head", params["n_embd"] // params["n_head"]) + mtfparams = mtf.transformer.attention.attention_params_simple( + x.mesh, + io_dim=dim_embd, + kv_dim=dim_kv, + heads_dim=dim_heads, + variable_dtype=variable_dtype + ) + q = mtfparams.compute_q(x) + k = mtfparams.compute_k(x) + v = mtfparams.compute_v(x) + + if is_incremental_inference(context): + one_hot = mtf.one_hot(context.position - 1, dim_seq, dtype=variable_dtype.master_dtype) + inv_one_hot = 1.0 - one_hot + old_k, old_v = context.get_states(2) + k = old_k * inv_one_hot + k * one_hot + v = old_v * inv_one_hot + v * one_hot + + if exists(context): + context.record_new_states([k, v]) + + with tf.variable_scope("attention"): + if attention_type == "local": + # `local_attention_1d` has built in autoregressive masking, so we don't need mask_attn_weights. + radius = params.get("local_attention_radius", 256) + + if is_incremental_inference(context): + q *= one_hot + + a = mtf_transformer.attention.local_attention_1d( + q, k, v, + length_dim=k.shape[1], + key_dim=dim_kv, + value_dim=dim_kv, + radius=radius, + length_dim_num_splits=1, + fully_autoregressive=params["causal"], + attention_kwargs={}, + ) + + if is_incremental_inference(context): + a = mtf.gather(a, context.position - 1, dim_seq) + + elif attention_type == "global": + + # TODO: pass in fake context + # Broadcast mask bias across batch and heads + if exists(bias): + if not is_incremental_inference(context): + broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-2], bias.shape[-1]]) + else: + # In the incremental case, a custom mask needs to be built that masks out all key/values that are greater than the current position + bias = mtf.gather(bias, context.position - 1, dim_seq) + broadcasted_bias = mtf.broadcast(bias, [dim_batch, dim_heads, bias.shape[-1]]) + + # memory key / values, from all-attention paper + if use_num_mem_kv: + k, v = memory_key_values(k, v, num_mem_kv, dim_batch, dim_heads, variable_dtype, mesh) + + k = mtf.replace_dimensions(k, k.shape[1], memory_length_dim) + v = mtf.replace_dimensions(v, v.shape[1], memory_length_dim) + + attn_dropout_rate = params["attn_dropout"] if params["mode"] == "train" else 0 + + a = mtf_transformer.attention.attention( + q, k, v, + memory_length_dim=memory_length_dim, + key_dim=dim_kv, + value_dim=dim_kv, + bias=broadcasted_bias, + dropout_rate=attn_dropout_rate + ) + + elif attention_type == "linear": + linear_attn_fn = causal_linear_attention if params["causal"] else linear_attention + a = linear_attn_fn(q, k, v) + + else: + raise NotImplementedError("Unknown attention type {}!".format(attention_type)) + + with tf.variable_scope("compute_output"): + a = mtfparams.compute_output(a, x_shape) + + with tf.variable_scope("compute_output_bias"): + b = mtf.get_variable(x.mesh, "o_b", [dim_embd], initializer=tf.constant_initializer(0), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + a += b + + if params["mode"] == "train" and params["res_dropout"] > 0: + a = mtf.dropout(a, rate=params["res_dropout"], name="res_dropout") + return a + + +def mlp(x, scope, n_state, *, variable_dtype, params): + activation_fn = get_activation_fn(params) + with tf.variable_scope(scope): + nx = x.shape[-1] + h = activation_fn(linear(x, "c_fc", n_state, variable_dtype=variable_dtype, params=params)) + h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True) + if params["mode"] == "train" and params["res_dropout"] > 0: + h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout") + return h2 + + +def mlp_glu(x, scope, n_state, *, variable_dtype, params): + activation_fn = get_activation_fn(params) + with tf.variable_scope(scope): + nx = x.shape[-1] + h = linear(x, "c_fc", n_state, params=params) + + h, gate = mtf.split(h, h.shape[-1], 2) + h *= activation_fn(gate) + + h2 = linear(h, "c_proj", nx, variable_dtype=variable_dtype, params=params, scale=True) + if params["mode"] == "train" and params["res_dropout"] > 0: + h2 = mtf.dropout(h2, rate=params["res_dropout"], name="mlp_dropout") + return h2 + + +def axial_positional_emb(embd_dim, mesh, params, variable_dtype): + # Use axial position encoding + axial_dim_1, axial_dim_2 = params["axial_pos_emb"] + + axial_dim = mtf.Dimension("axial_dim", axial_dim_1 * axial_dim_2) + dim_axials = [mtf.Dimension(f"axial_dim_{i}", t) for i, t in enumerate((axial_dim_1, axial_dim_2))] + + axial_wpe_1 = mtf.get_variable(mesh, "axial_wpe_1", mtf.Shape([dim_axials[0], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + + axial_wpe_2 = mtf.get_variable(mesh, "axial_wpe_2", mtf.Shape([dim_axials[1], embd_dim]), + initializer=tf.random_normal_initializer(stddev=0.01), + master_dtype=variable_dtype.master_dtype, + slice_dtype=variable_dtype.slice_dtype, + activation_dtype=variable_dtype.activation_dtype) + + axial_wpe_1, axial_wpe_2 = map(lambda t: mtf.broadcast(t, [dim_axials[0], dim_axials[1], embd_dim]), + (axial_wpe_1, axial_wpe_2)) + wpe = (axial_wpe_1 + axial_wpe_2) / 2 + + wpe = mtf.reshape(wpe, [axial_dim, embd_dim]) + + return wpe + diff --git a/models/utils.py b/models/utils.py index 03733de1..2832ea85 100644 --- a/models/utils.py +++ b/models/utils.py @@ -2,7 +2,9 @@ import mesh_tensorflow as mtf from functools import partial -def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha = 1.3, dim = None, n_iter = 50): + +def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, output_grads, alpha=1.3, dim=None, + n_iter=50): x, = explicit_inputs y, = outputs dY, = output_grads @@ -10,12 +12,13 @@ def entmax_backward(explicit_inputs, all_inputs, forward_operations, outputs, ou gppr = mtf.where(mtf.greater(y, 0), mtf.pow(y, (2 - alpha)), mtf.zeros_like(y)) dX = dY * gppr - q = mtf.reduce_sum(dX, reduced_dim = dim) / mtf.reduce_sum(gppr, reduced_dim = dim) + q = mtf.reduce_sum(dX, reduced_dim=dim) / mtf.reduce_sum(gppr, reduced_dim=dim) dX = dX - q * gppr return dX, -def entmax_forward(x, alpha = 1.3, dim = None, n_iter = 50): + +def entmax_forward(x, alpha=1.3, dim=None, n_iter=50): assert alpha > 1 and alpha < 2, 'alpha must be between 1 and 2' _gp = lambda x, alpha: x ** (alpha - 1) @@ -27,12 +30,12 @@ def entmax_forward(x, alpha = 1.3, dim = None, n_iter = 50): x = x * (alpha - 1) - max_val = mtf.reduce_max(x, reduced_dim = dim) + max_val = mtf.reduce_max(x, reduced_dim=dim) tau_lo = max_val - _gp(1, alpha) tau_hi = max_val - _gp(1 / d, alpha) - f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim = dim) - 1 + f_lo = mtf.reduce_sum(_p(x - tau_lo, alpha), reduced_dim=dim) - 1 dm = tau_hi - tau_lo @@ -40,16 +43,17 @@ def entmax_forward(x, alpha = 1.3, dim = None, n_iter = 50): dm = dm / 2 tau_m = tau_lo + dm p_m = _p(x - tau_m, alpha) - f_m = mtf.reduce_sum(p_m, reduced_dim = dim) - 1 + f_m = mtf.reduce_sum(p_m, reduced_dim=dim) - 1 mask = mtf.greater_equal((f_m * f_lo), 0) tau_lo = mtf.where(mask, tau_m, tau_lo) - p_m = p_m / mtf.reduce_sum(p_m, reduced_dim = dim) + p_m = p_m / mtf.reduce_sum(p_m, reduced_dim=dim) return p_m -def entmax(x, alpha = 1.3, dim = None, n_iter = 50): - kwargs = dict(alpha = alpha, dim = dim, n_iter = n_iter) + +def entmax(x, alpha=1.3, dim=None, n_iter=50): + kwargs = dict(alpha=alpha, dim=dim, n_iter=n_iter) return mtf.custom_gradient( partial(entmax_forward, **kwargs), @@ -57,6 +61,7 @@ def entmax(x, alpha = 1.3, dim = None, n_iter = 50): [x] ) + def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0): if targets.dtype.is_integer: # hard targets @@ -69,26 +74,28 @@ def entmax_cross_entropy_with_logits(logits, targets, vocab_dim, z_loss=0.0): elif set(targets.shape.dims) != set(logits.shape.dims): raise ValueError( "softmax_cross_entropy_with_logits with soft targets " - "dims in targets=%s should be dims in logits=%s"% (targets, logits)) + "dims in targets=%s should be dims in logits=%s" % (targets, logits)) if vocab_dim not in logits.shape.dims: raise ValueError("vocab_dim must be in logits.shape.dims") - log_entmax = mtf.log(entmax(logits, dim = vocab_dim)) + log_entmax = mtf.log(entmax(logits, dim=vocab_dim)) loss = mtf.negative( - mtf.reduce_sum(log_entmax * targets, reduced_dim = vocab_dim)) + mtf.reduce_sum(log_entmax * targets, reduced_dim=vocab_dim)) return loss -def sample_categorical(x, dim = None): + +def sample_categorical(x, dim=None): dim = x.shape[-1] if dim is None else dim cdf = mtf.cumsum(x, dim) - rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval = 0, maxval = 1) + rand_uniform = mtf.random_uniform(x.mesh, x.shape - dim, minval=0, maxval=1) mask = mtf.cast(mtf.greater(cdf, rand_uniform), tf.int32) return mtf.argmax(mask, dim) + def biasmask_attn_weights(mesh, nd, ns, variable_dtype): # The old mask_attn_weights applied directly to the QK; # this returns a bias that the attention code from mtf adds to the attention matrix. @@ -102,6 +109,7 @@ def biasmask_attn_weights(mesh, nd, ns, variable_dtype): dtype = variable_dtype.activation_dtype return mtf.cast(mtf.less(i, j), dtype) * -1e10 + def parse_inputs(mtf_features, other_features): # Parse inputs and labels from the mtf_features / other_features input dicts # All dimensions are defined inside model_fn for efficiency @@ -113,4 +121,4 @@ def parse_inputs(mtf_features, other_features): vocab_dim = other_features["vocab_dim"] embed_sequence_dim = other_features["embed_sequence_dim"] - return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim \ No newline at end of file + return x, batch_dim, sequence_dim, embd_dim, vocab_dim, embed_sequence_dim