From 9f335665c8970e3d29deb18317742fdc4c41ae9c Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Thu, 16 May 2019 21:14:28 -0700
Subject: [PATCH 01/10] Merge nshepperd's nucleus sampling implementation

---
 gpt_2_simple/gpt_2.py      | 14 +++++-------
 gpt_2_simple/src/sample.py | 45 +++++++++++++++++++++++++++++---------
 2 files changed, 40 insertions(+), 19 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index 456c6d4..e607737 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -311,14 +311,13 @@ def generate(sess,
              destination_path=None,
              sample_delim='=' * 20 + '\n',
              prefix=None,
-             model_name='117M',
              seed=None,
              nsamples=1,
              batch_size=1,
              length=1023,
              temperature=0.7,
              top_k=0,
-             run_name='run1',
+             top_p=0.0,
              include_prefix=True):
     """Generates text from a model loaded into memory.

@@ -353,7 +352,7 @@ def generate(sess,
             start_token=enc.encoder['<|endoftext|>'] if not prefix else None,
             context=context if prefix else None,
             batch_size=batch_size,
-            temperature=temperature, top_k=top_k
+            temperature=temperature, top_k=top_k, top_p=top_p
         )[:, 1:]

     if destination_path:
@@ -404,14 +403,13 @@ def generate_to_file(sess,
                      destination_path='gpt_2_gen_texts.txt',
                      sample_delim='=' * 20 + '\n',
                      prefix=None,
-                     model_name='117M',
                      seed=None,
                      nsamples=1,
                      batch_size=1,
                      length=1023,
                      temperature=0.7,
                      top_k=0,
-                     run_name='run1',
+                     top_p=0.0,
                      include_prefix=True):
     """Generates the texts to a file.

@@ -426,14 +424,13 @@ def generate_to_file(sess,
              destination_path,
              sample_delim,
              prefix,
-             model_name,
              seed,
              nsamples,
              batch_size,
              length,
              temperature,
              top_k,
-             run_name,
+             top_p,
              include_prefix)

@@ -661,6 +658,5 @@ def cmd_generate(nfiles, nsamples, folder,
                  prefix=prefix,
                  truncate=truncate,
                  include_prefix=include_prefix,
-                 sample_delim=sample_delim,
-                 run_name=run_name
+                 sample_delim=sample_delim
                  )
diff --git a/gpt_2_simple/src/sample.py b/gpt_2_simple/src/sample.py
index 0a2f835..1f48dd4 100755
--- a/gpt_2_simple/src/sample.py
+++ b/gpt_2_simple/src/sample.py
@@ -2,6 +2,7 @@

 from gpt_2_simple.src import model

+
 def top_k_logits(logits, k):
     if k == 0:
         # no truncation
@@ -16,13 +17,30 @@ def _top_k():
         logits,
     )
     return tf.cond(
-       tf.equal(k, 0),
-       lambda: logits,
-       lambda: _top_k(),
+        tf.equal(k, 0),
+        lambda: logits,
+        lambda: _top_k(),
     )


-def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
+def top_p_logits(logits, p):
+    with tf.variable_scope('top_p_logits'):
+        logits_sort = tf.sort(logits, direction='DESCENDING')
+        probs_sort = tf.nn.softmax(logits_sort)
+        probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
+        logits_masked = tf.where(probs_sums < p, logits_sort, tf.ones_like(
+            logits_sort)*1000)  # [batchsize, vocab]
+        min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)  # [batchsize, 1]
+        return tf.where(
+            logits < min_logits,
+            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
+            logits,
+        )
+
+
+def sample_sequence(*, hparams, length, start_token=None,
+                    batch_size=None, context=None, temperature=1,
+                    top_k=0, top_p=0.0):
     if start_token is None:
         assert context is not None, 'Specify exactly one of start_token and context!'
     else:
@@ -30,11 +48,13 @@ def sample_sequence(*, hparams, length, start_token=None, batch_size=None, conte
         context = tf.fill([batch_size, 1], start_token)

     def step(hparams, tokens, past=None):
-        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
+        lm_output = model.model(hparams=hparams, X=tokens,
+                                past=past, reuse=tf.AUTO_REUSE)

         logits = lm_output['logits'][:, :, :hparams.n_vocab]
         presents = lm_output['present']
-        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
+        presents.set_shape(model.past_shape(
+            hparams=hparams, batch_size=batch_size))
         return {
             'logits': logits,
             'presents': presents,
@@ -48,9 +68,13 @@ def step(hparams, tokens, past=None):

     def body(past, prev, output):
         next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
-        logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
-        logits = top_k_logits(logits, k=top_k)
-        samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
+        logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
+        if top_p > 0.0:
+            logits = top_p_logits(logits, p=top_p)
+        else:
+            logits = top_k_logits(logits, k=top_k)
+        samples = tf.multinomial(
+            logits, num_samples=1, output_dtype=tf.int32)
         return [
             tf.concat([past, next_outputs['presents']], axis=-2),
             tf.squeeze(samples, axis=[1]),
@@ -69,7 +93,8 @@ def cond(*args):
             context,
         ],
         shape_invariants=[
-            tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
+            tf.TensorShape(model.past_shape(
+                hparams=hparams, batch_size=batch_size)),
             tf.TensorShape([batch_size]),
             tf.TensorShape([batch_size, None]),
         ],
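For readers unfamiliar with nucleus sampling: the `top_p_logits` op added above keeps the smallest set of most-probable tokens whose cumulative probability mass stays under `p`, and pushes every other logit to -1e10 so it can never be sampled. A minimal NumPy sketch of the same filtering logic (the function name and test values here are illustrative, not part of the patch):

```python
import numpy as np

def top_p_filter(logits, p=0.8):
    # Sort logits descending and compute each token's probability
    sort_idx = np.argsort(logits)[::-1]
    exp = np.exp(logits[sort_idx] - np.max(logits))
    probs = exp / exp.sum()
    # Exclusive cumulative sum, mirroring tf.cumsum(..., exclusive=True)
    cum_before = np.cumsum(probs) - probs
    # Keep tokens whose preceding mass is still below p; mask the rest
    keep = sort_idx[cum_before < p]
    filtered = np.full_like(logits, -1e10)
    filtered[keep] = logits[keep]
    return filtered

print(top_p_filter(np.array([2.0, 1.0, 0.5, 0.1])))
# -> [ 2.0  1.0  0.5 -1e10]: the low-probability tail is masked out
```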
From faf9ce4b8ceecef4fda2879acabd708671a9b0cb Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Thu, 16 May 2019 21:18:32 -0700
Subject: [PATCH 02/10] Add top_k and top_p CLI options

---
 gpt_2_simple/gpt_2.py | 16 +++++++++++++---
 1 file changed, 13 insertions(+), 3 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index e607737..295e26d 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -569,6 +569,12 @@ def cmd():
     parser.add_argument(
         '--temperature', help="[generate] Temperature of the generated texts",
         nargs='?', default=0.7, type=float)
+    parser.add_argument(
+        '--top_k', help="[generate] Sample only from top k tokens",
+        nargs='?', default=0, type=int)
+    parser.add_argument(
+        '--top_p', help="[generate] Sample from top p prob (overrides top_k if nonzero)",
+        nargs='?', default=0.0, type=float)
     parser.add_argument(
         '--batch_size', help="[generate] Batch size for generation (increase for GPUs)",
         nargs='?', default=1, type=int)
@@ -608,7 +614,8 @@ def cmd():
                      temperature=args.temperature, batch_size=args.batch_size,
                      prefix=args.prefix, truncate=args.truncate,
                      include_prefix=args.include_prefix,
-                     sample_delim=args.sample_delim, run_name=args.run_name)
+                     sample_delim=args.sample_delim, run_name=args.run_name,
+                     top_k=args.top_k, top_p=args.top_p)


 def cmd_finetune(dataset, run_name, model_name, steps,
@@ -630,7 +637,8 @@ def cmd_finetune(dataset, run_name, model_name, steps,

 def cmd_generate(nfiles, nsamples, folder,
                  length, temperature, batch_size,
                  prefix, truncate, include_prefix,
-                 sample_delim, run_name):
+                 sample_delim, run_name,
+                 top_k, top_p):
     """Wrapper script for generating text via the CLI.

     The files are generated into a folder, which can be downloaded
     recursively by downloading the entire folder.
@@ -658,5 +666,7 @@ def cmd_generate(nfiles, nsamples, folder,
                  prefix=prefix,
                  truncate=truncate,
                  include_prefix=include_prefix,
-                 sample_delim=sample_delim
+                 sample_delim=sample_delim,
+                 top_k=top_k,
+                 top_p=top_p
                  )
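The same knobs are available from Python. A hedged usage sketch, assuming a finetuned `run1` checkpoint already exists and that `start_tf_sess()`/`load_gpt2()` work as elsewhere in the package:

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)  # loads the checkpoint under checkpoint/run1

# Per the CLI help text above, a nonzero top_p overrides top_k.
gpt2.generate(sess, nsamples=5, temperature=0.7, top_p=0.9)
gpt2.generate(sess, nsamples=5, temperature=0.7, top_k=40)  # top-k instead
```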
From 05625f6ba3095e04a13198b5f950a76a749bc0d4 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sat, 18 May 2019 12:31:00 -0700
Subject: [PATCH 03/10] Overwrite (#20), standalone FT, remove model_load

---
 README.md             |  1 +
 gpt_2_simple/gpt_2.py | 24 ++++++++++++++++--------
 2 files changed, 17 insertions(+), 8 deletions(-)

diff --git a/README.md b/README.md
index 99d0bec..65486e2 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,7 @@ The method GPT-2 uses to generate text is slightly different than those like oth
 * If you pass a single-column `.csv` file to `finetune()`, it will automatically parse the CSV into a format ideal for training with GPT-2 (including prepending `<|startoftext|>` and suffixing `<|endoftext|>` to every text document, so the `truncate` tricks above are helpful when generating output). This is necessary to handle both quotes and newlines in each text document correctly.
 * GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` up to 20 on Colaboratory's K80)!
 * Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. For the 117M model, if you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80/T4 for only 3x the price, making it price-comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
+* If you have a partially-trained GPT-2 model and want to continue finetuning it, you can pass `overwrite=True` to `finetune()`, which will continue training and remove the previous iteration of the model without creating a duplicate copy. This can be especially useful for transfer learning (e.g. heavily finetune GPT-2 on one dataset, then finetune on another dataset to get a "merging" of both datasets).

 ## Planned Work
diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index 295e26d..18f0586 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -86,7 +86,7 @@ def finetune(sess,
              max_checkpoints=1,
              use_memory_saving_gradients=False,
              only_train_transformer_layers=False,
-             model_load=False):
+             overwrite=False):
     """Finetunes the model on the given dataset.

     Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/train.py.
@@ -105,10 +105,15 @@ def maketree(path):
             pass
     maketree(checkpoint_path)

-    if not model_load:
-        for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
-            shutil.copyfile(os.path.join('models', model_name, file),
-                            os.path.join(checkpoint_path, file))
+    files = [f for f in os.listdir(checkpoint_path) if os.path.isfile(f)]
+    for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
+        if file not in files:
+            try:
+                shutil.copyfile(os.path.join('models', model_name, file),
+                                os.path.join(checkpoint_path, file))
+            except FileNotFoundError as fnf_error:
+                print("You need to download the GPT-2 model first via download_gpt2()")
+                raise fnf_error

     enc = encoder.get_encoder(checkpoint_path)
     hparams = model.default_hparams()
@@ -181,9 +186,6 @@ def maketree(path):
         print('Loading checkpoint', ckpt)
         saver.restore(sess, ckpt)

-    if model_load:
-        return
-
     print('Loading dataset...')
     chunks = load_dataset(enc, dataset, combine)
     data_sampler = Sampler(chunks)
@@ -236,6 +238,12 @@ def generate_samples():
     def sample_batch():
         return [data_sampler.sample(1024) for _ in range(batch_size)]

+    if overwrite:
+        save()
+        for file in files:
+            if file.startswith('model'):
+                os.remove(os.path.join(checkpoint_path, file))
+
     avg_loss = (0.0, 0.0)
     start_time = time.time()
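A sketch of continuing an existing finetuning run with the new `overwrite` flag (the dataset filename is hypothetical):

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
# Resume from the latest run1 checkpoint and replace it in place,
# rather than accumulating duplicate copies of the model files.
gpt2.finetune(sess,
              dataset='my_corpus.txt',  # hypothetical input file
              model_name='117M',
              steps=200,
              restore_from='latest',
              overwrite=True)
```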
From e5b3c60820d694bfa24a2b9157c40e6e120a1517 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sat, 18 May 2019 12:54:41 -0700
Subject: [PATCH 04/10] Switch checkpoint FNs to use run_name

---
 gpt_2_simple/gpt_2.py | 6 ++++--
 setup.py              | 2 +-
 2 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index 18f0586..d21d3bc 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -461,10 +461,11 @@ def get_tarfile_name(checkpoint_folder):
     return tarfile_name


-def copy_checkpoint_to_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')):
+def copy_checkpoint_to_gdrive(run_name='run1'):
     """Copies the checkpoint folder to a mounted Google Drive."""
     is_mounted()

+    checkpoint_folder = os.path.join('checkpoint', run_name)
     file_path = get_tarfile_name(checkpoint_folder)

     # Reference: https://stackoverflow.com/a/17081026
@@ -474,10 +475,11 @@ def copy_checkpoint_to_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1
     shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)


-def copy_checkpoint_from_gdrive(checkpoint_folder=os.path.join('checkpoint', 'run1')):
+def copy_checkpoint_from_gdrive(run_name='run1'):
     """Copies the checkpoint folder from a mounted Google Drive."""
     is_mounted()

+    checkpoint_folder = os.path.join('checkpoint', run_name)
     file_path = get_tarfile_name(checkpoint_folder)

     shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)
diff --git a/setup.py b/setup.py
index a497f95..ccdcde4 100644
--- a/setup.py
+++ b/setup.py
@@ -47,7 +47,7 @@
 setup(
     name='gpt_2_simple',
     packages=['gpt_2_simple'],  # this must be the same as the name above
-    version='0.4.2',
+    version='0.5',
     description="Python package to easily retrain OpenAI's GPT-2 " \
                 "text-generating model on new texts.",
     long_description=long_description,

From 6796946c924e38ec5b85e9a9319be51f34b29fb2 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 15:54:26 -0700
Subject: [PATCH 05/10] encode_dataset()

---
 README.md             |  1 +
 gpt_2_simple/gpt_2.py | 17 +++++++++++++++++
 2 files changed, 18 insertions(+)

diff --git a/README.md b/README.md
index 65486e2..985d81d 100644
--- a/README.md
+++ b/README.md
@@ -101,6 +101,7 @@ The method GPT-2 uses to generate text is slightly different than those like oth
 * GPT-2 allows you to generate texts in parallel by setting a `batch_size` that is divisible into `nsamples`, resulting in much faster generation. Works very well with a GPU (can set `batch_size` up to 20 on Colaboratory's K80)!
 * Due to GPT-2's architecture, it scales up nicely with more powerful GPUs. For the 117M model, if you want to train for longer periods of time, GCP's P100 GPU is about 3x faster than a K80/T4 for only 3x the price, making it price-comparable (the V100 is about 1.5x faster than the P100 but about 2x the price). The P100 uses 100% of the GPU even with `batch_size=1`, and about 88% of the V100 GPU.
 * If you have a partially-trained GPT-2 model and want to continue finetuning it, you can pass `overwrite=True` to `finetune()`, which will continue training and remove the previous iteration of the model without creating a duplicate copy. This can be especially useful for transfer learning (e.g. heavily finetune GPT-2 on one dataset, then finetune on another dataset to get a "merging" of both datasets).
+* If your input text dataset is massive (>100 MB), you may want to preencode and compress the dataset using `gpt2.encode_dataset(file_path)`. The output is a compressed `.npz` file which will load much faster into the GPU for finetuning.

 ## Planned Work
diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index d21d3bc..c771d8b 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -529,6 +529,23 @@ def encode_csv(csv_path, out_path='csv_encoded.txt', header=True,
             w.write(start_token + row[0] + end_token + "\n")


+def encode_dataset(file_path, out_path='text_encoded.npz',
+                   model_name="117M",
+                   combine=50000):
+    """Preencodes a text document into chunks and compresses it,
+    saving time when loading the dataset for finetuning.
+
+    Adapted from https://github.com/nshepperd/gpt-2/blob/finetuning/encode.py
+    """
+
+    model_path = os.path.join('models', model_name)
+    enc = encoder.get_encoder(model_path)
+    print('Reading files')
+    chunks = load_dataset(enc, file_path, combine)
+    print('Writing', out_path)
+    np.savez_compressed(out_path, *chunks)
+
+
 def cmd():
     """Function called when invoking from the terminal."""
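A usage sketch for the new `encode_dataset()` helper (corpus filenames are hypothetical); the compressed `.npz` can then be passed to `finetune()` in place of the raw text file:

```python
import gpt_2_simple as gpt2

# One-time preprocessing: tokenize and compress a large corpus.
gpt2.encode_dataset('huge_corpus.txt', out_path='corpus_encoded.npz',
                    model_name='117M', combine=50000)

# The .npz loads much faster than re-encoding the raw text on every run.
sess = gpt2.start_tf_sess()
gpt2.finetune(sess, dataset='corpus_encoded.npz', model_name='117M', steps=1000)
```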
From 1da0cb941ddf9085aa03f0e6742382c586af0bd5 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 17:32:54 -0700
Subject: [PATCH 06/10] Fix miscellaneous bugs during testing

---
 gpt_2_simple/gpt_2.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index c771d8b..452904a 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -105,7 +105,7 @@ def maketree(path):
             pass
     maketree(checkpoint_path)

-    files = [f for f in os.listdir(checkpoint_path) if os.path.isfile(f)]
+    files = [f for f in os.listdir(checkpoint_path)]
     for file in ['hparams.json', 'encoder.json', 'vocab.bpe']:
         if file not in files:
             try:
@@ -238,11 +238,11 @@ def generate_samples():
     def sample_batch():
         return [data_sampler.sample(1024) for _ in range(batch_size)]

-    if overwrite:
-        save()
+    if overwrite and restore_from == 'latest':
         for file in files:
-            if file.startswith('model'):
+            if file.startswith('model') or file.startswith('events'):
                 os.remove(os.path.join(checkpoint_path, file))
+        save()

     avg_loss = (0.0, 0.0)
     start_time = time.time()
@@ -314,6 +314,7 @@ def load_gpt2(sess,


 def generate(sess,
+             run_name='run1',
              return_as_list=False,
              truncate=None,
              destination_path=None,
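With `run_name` now threaded through `generate()`, sampling against a specific run might look like the following sketch (run names are illustrative, and it assumes this version's `load_gpt2()` accepts a `run_name` argument):

```python
import gpt_2_simple as gpt2

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run2')  # assumes checkpoint/run2 exists

# run_name locates the encoder/hparams for the loaded checkpoint
text = gpt2.generate(sess, run_name='run2', return_as_list=True,
                     length=200, top_p=0.9)[0]
print(text)
```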
From bf319b64b647953564b6801e412f61fc54e5cf81 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 17:45:17 -0700
Subject: [PATCH 07/10] overwrite CLI argument

---
 gpt_2_simple/gpt_2.py | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index 452904a..6cc275e 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -582,6 +582,9 @@ def cmd():
     parser.add_argument(
         '--print_every', help="[finetune] After how many steps to print progress",
         nargs='?', default=10, type=int)
+    parser.add_argument(
+        '--overwrite', help="[finetune] Overwrite existing model when continuing training",
+        nargs='?', default=False, type=lambda x: (str(x).lower() == 'true'))
     parser.add_argument(
         '--nfiles', help="[generate] How many files to generate.",
         nargs='?', default=1, type=int)
@@ -635,7 +638,8 @@ def cmd():
                      steps=args.steps, restore_from=args.restore_from,
                      sample_every=args.sample_every,
                      save_every=args.save_every,
-                     print_every=args.print_every)
+                     print_every=args.print_every,
+                     overwrite=args.overwrite)
     if args.mode == "generate":
         cmd_generate(nfiles=args.nfiles, nsamples=args.nsamples,
                      folder=args.folder, length=args.length,
@@ -648,7 +652,7 @@ def cmd():

 def cmd_finetune(dataset, run_name, model_name, steps,
                  restore_from, sample_every,
-                 save_every, print_every):
+                 save_every, print_every, overwrite):
     """Wrapper script for finetuning the model via the CLI."""

     if not is_gpt2_downloaded(model_name=model_name):
@@ -659,7 +663,8 @@ def cmd_finetune(dataset, run_name, model_name, steps,
              model_name=model_name,
              steps=steps, restore_from=restore_from,
              sample_every=sample_every, save_every=save_every,
-             print_every=print_every)
+             print_every=print_every,
+             overwrite=overwrite)
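The `type=lambda x: (str(x).lower() == 'true')` idiom above exists because argparse's `type=bool` would treat any non-empty string as truthy, so `--overwrite False` would silently enable overwriting. A self-contained demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
# bool('False') == True, so a lambda parses the literal string instead.
parser.add_argument('--overwrite', nargs='?', default=False,
                    type=lambda x: str(x).lower() == 'true')

print(parser.parse_args(['--overwrite', 'False']).overwrite)  # False
print(parser.parse_args(['--overwrite', 'True']).overwrite)   # True
print(parser.parse_args([]).overwrite)                        # False (default)
```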
From e4bb5a00633831b452315e9bf5663842fbadc6b7 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 17:57:44 -0700
Subject: [PATCH 08/10] Add copy_folder arg fallback

---
 gpt_2_simple/gpt_2.py | 30 +++++++++++++++++++-----------
 1 file changed, 19 insertions(+), 11 deletions(-)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index 6cc275e..ec13544 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -462,31 +462,39 @@ def get_tarfile_name(checkpoint_folder):
     return tarfile_name


-def copy_checkpoint_to_gdrive(run_name='run1'):
+def copy_checkpoint_to_gdrive(run_name='run1', copy_folder=False):
     """Copies the checkpoint folder to a mounted Google Drive."""
     is_mounted()

     checkpoint_folder = os.path.join('checkpoint', run_name)
-    file_path = get_tarfile_name(checkpoint_folder)

-    # Reference: https://stackoverflow.com/a/17081026
-    with tarfile.open(file_path, 'w') as tar:
-        tar.add(checkpoint_folder)
+    if copy_folder:
+        shutil.copytree(checkpoint_folder, "/content/drive/My Drive/" + checkpoint_folder)
+    else:
+        file_path = get_tarfile_name(checkpoint_folder)

-    shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)
+        # Reference: https://stackoverflow.com/a/17081026
+        with tarfile.open(file_path, 'w') as tar:
+            tar.add(checkpoint_folder)
+
+        shutil.copyfile(file_path, "/content/drive/My Drive/" + file_path)


-def copy_checkpoint_from_gdrive(run_name='run1'):
+def copy_checkpoint_from_gdrive(run_name='run1', copy_folder=False):
     """Copies the checkpoint folder from a mounted Google Drive."""
     is_mounted()

     checkpoint_folder = os.path.join('checkpoint', run_name)
-    file_path = get_tarfile_name(checkpoint_folder)

-    shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)
+    if copy_folder:
+        shutil.copytree("/content/drive/My Drive/" + checkpoint_folder, checkpoint_folder)
+    else:
+        file_path = get_tarfile_name(checkpoint_folder)
+
+        shutil.copyfile("/content/drive/My Drive/" + file_path, file_path)

-    with tarfile.open(file_path, 'r') as tar:
-        tar.extractall()
+        with tarfile.open(file_path, 'r') as tar:
+            tar.extractall()


 def copy_file_to_gdrive(file_path):

From 085fc93412799c88cdfd091142be8dba46e10293 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 20:06:44 -0700
Subject: [PATCH 09/10] fix generate_to_file

---
 gpt_2_simple/gpt_2.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/gpt_2_simple/gpt_2.py b/gpt_2_simple/gpt_2.py
index ec13544..6de2f11 100644
--- a/gpt_2_simple/gpt_2.py
+++ b/gpt_2_simple/gpt_2.py
@@ -408,6 +408,7 @@ def generate(sess,


 def generate_to_file(sess,
+                     run_name='run1',
                      truncate=None,
                      destination_path='gpt_2_gen_texts.txt',
                      sample_delim='=' * 20 + '\n',
@@ -428,6 +429,7 @@ def generate_to_file(sess,
     """

     generate(sess,
+             run_name,
              False,
              truncate,
              destination_path,

From de2b362fd680d6eb9d546de5cd1d3a19e78d2df2 Mon Sep 17 00:00:00 2001
From: Max Woolf
Date: Sun, 19 May 2019 20:42:20 -0700
Subject: [PATCH 10/10] Rename installed package

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index 985d81d..cae583b 100644
--- a/README.md
+++ b/README.md
@@ -19,7 +19,7 @@ You can use gpt-2-simple to retrain a model using a GPU **for free** in [this Co
 gpt-2-simple can be installed [via PyPI](https://pypi.org/project/gpt_2_simple/):

 ```shell
-pip3 install gpt_2_simple
+pip3 install gpt-2-simple
 ```

 You will also need to install the corresponding TensorFlow for your system (e.g. `tensorflow` or `tensorflow-gpu`)
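Finally, a sketch of the `copy_folder` fallback from patch 08, for Colaboratory sessions where tarring a large checkpoint is undesirable (it assumes Google Drive is already mounted in the Colab VM):

```python
import gpt_2_simple as gpt2

# After finetuning in Colaboratory (with Drive mounted), mirror the raw
# checkpoint folder to Drive instead of packaging it as a tarball first.
gpt2.copy_checkpoint_to_gdrive(run_name='run1', copy_folder=True)

# In a later session, restore it the same way before loading the model.
gpt2.copy_checkpoint_from_gdrive(run_name='run1', copy_folder=True)
```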