From 7dd43190621292e1bde0a33b89c29bb7167cdbc7 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 28 Apr 2020 15:03:23 +0000 Subject: [PATCH 01/14] add working prediction scripts --- .../text_summarization_bartt5.py | 59 +++++++++++++++++++ 1 file changed, 59 insertions(+) create mode 100644 examples/text_summarization/text_summarization_bartt5.py diff --git a/examples/text_summarization/text_summarization_bartt5.py b/examples/text_summarization/text_summarization_bartt5.py new file mode 100644 index 000000000..112f211b5 --- /dev/null +++ b/examples/text_summarization/text_summarization_bartt5.py @@ -0,0 +1,59 @@ +from pathlib import Path +from transformers import BartForConditionalGeneration, BartTokenizer +from tqdm import tqdm +from transformers import T5ForConditionalGeneration, T5Tokenizer + + +model_class = { + "bart-large-cnn": BartForConditionalGeneration, + "t5-large":T5ForConditionalGeneration +} +tokenizer_class = { + "bart-large-cnn": BartTokenizer, + "t5-large": T5Tokenizer +} +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def generate_summaries( + examples: list, out_file: str, model_name: str, batch_size: int = 8, device: str = "cuda" +): + fout = Path(out_file).open("w") + model = model_class[model_name].from_pretrained(model_name).to(device) + tokenizer = tokenizer_class[model_name].from_pretrained(model_name) # bart-large + + max_length = 140 + min_length = 55 + + if model_name.startswith("t5"): + # update config with summarization specific params + task_specific_params = model.config.task_specific_params + if task_specific_params is not None: + model.config.update(task_specific_params.get("summarization", {})) + + for batch in tqdm(list(chunks(examples, batch_size))): + if model_name.startswith("t5"): + batch = [model.config.prefix + text for text in batch] + dct = tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) + summaries = model.generate( + input_ids=dct["input_ids"].to(device), + attention_mask=dct["attention_mask"].to(device), + #num_beams=4, + #length_penalty=2.0, + #max_length=max_length + 2, # +2 from original because we start at step=1 and stop before max_length + #min_length=min_length + 1, # +1 from original because we start at step=1 + #no_repeat_ngram_size=3, + #early_stopping=True, + #decoder_start_token_id=model.config.eos_token_id, + ) + dec = [tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + for hypothesis in dec: + fout.write(hypothesis + "\n") + fout.flush() + +examples = [" " + x.rstrip() for x in open("./cnn_test.txt").readlines()] +generate_summaries(examples, "./cnn_generated.txt", "bart-large-cnn", batch_size=4, device="cuda") +#generate_summaries(examples, "./cnn_generated-t5.txt", "t5-large", batch_size=4, device="cuda") \ No newline at end of file From c28ca5b7cf24b9d899948f357beb5e11cdca2283 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Thu, 30 Apr 2020 04:30:16 +0000 Subject: [PATCH 02/14] data preprocessing --- ...ive_summarization_cnndm_transformers.ipynb | 1040 +++++++++++++++++ utils_nlp/dataset/cnndm.py | 16 +- .../abstractive_summarization_bartt5.py | 604 ++++++++++ utils_nlp/models/transformers/datasets.py | 2 +- .../transformers/extractive_summarization.py | 2 +- 5 files changed, 1661 insertions(+), 3 deletions(-) create mode 100644 examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb create mode 100644 utils_nlp/models/transformers/abstractive_summarization_bartt5.py diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb new file mode 100644 index 000000000..dbdbdc551 --- /dev/null +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -0,0 +1,1040 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Copyright (c) Microsoft Corporation. All rights reserved.\n", + "\n", + "Licensed under the MIT License." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Abstractive Summarization on CNN/DM Dataset using Transformers\n", + "\n", + "\n", + "### Summary\n", + "\n", + "This notebook demonstrates how to fine tune Transformers for extractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.\n", + "\n", + "\n", + "\n", + "\n", + "### Before You Start\n", + "\n", + "The running time shown in this notebook is on a Standard_NC24s_v3 Azure Ubuntu Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n", + "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n", + "\n", + "Using only 1 NVIDIA Tesla V100 GPUs, 16GB GPU memory configuration,\n", + "- for data preprocessing, it takes around 1 minutes to preprocess the data for quick run. Otherwise it takes ~20 minutes to finish the data preprocessing. This time estimation assumes that the chosen transformer model is \"distilbert-base-uncased\" and the sentence selection method is \"greedy\", which is the default. The preprocessing time can be significantly longer if the sentence selection method is \"combination\", which can achieve better model performance.\n", + "\n", + "- for model fine tuning, it takes around 2 minutes for quick run. Otherwise, it takes around ~3 hours to finish. This estimation assumes the chosen encoder method is \"transformer\". The model fine tuning time can be shorter if other encoder method is chosen, which may result in worse model performance. \n", + "\n", + "### Additional Notes\n", + "\n", + "* **ROUGE Evalation**: To run rouge evaluation, please refer to the section of compute_rouge_perl in [summarization_evaluation.ipynb](./summarization_evaluation.ipynb) for setup.\n", + "\n", + "* **Distributed Training**:\n", + "Please note that the jupyter notebook only allows to use pytorch [DataParallel](https://pytorch.org/docs/master/nn.html#dataparallel). Faster speed and larger batch size can be achieved with pytorch [DistributedDataParallel](https://pytorch.org/docs/master/notes/ddp.html)(DDP). Script [extractive_summarization_cnndm_distributed_train.py](./extractive_summarization_cnndm_distributed_train.py) shows an example of how to use DDP.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%load_ext autoreload\n", + "\n", + "%autoreload 2" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", + "QUICK_RUN = True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Configuration\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "import os\n", + "import shutil\n", + "import sys\n", + "from tempfile import TemporaryDirectory\n", + "import torch\n", + "\n", + "nlp_path = os.path.abspath(\"../../\")\n", + "if nlp_path not in sys.path:\n", + " sys.path.insert(0, nlp_path)\n", + "\n", + "from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n", + "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n", + "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import SummarizationProcessor\n", + "\n", + "from utils_nlp.models.transformers.datasets import SummarizationDataset\n", + "import nltk\n", + "from nltk import tokenize\n", + "\n", + "import pandas as pd\n", + "import scrapbook as sb\n", + "import pprint" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "### Configuration: choose the transformer model to be used" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For extractive summarization, the following pretrained models are supported. " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'ExtractiveSummarizer' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"model_name\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mExtractiveSummarizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlist_supported_models\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'ExtractiveSummarizer' is not defined" + ] + } + ], + "source": [ + "#pd.DataFrame({\"model_name\": ExtractiveSummarizer.list_supported_models()})" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# Transformer model being used\n", + "MODEL_NAME = \"bart-large-cnn\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# notebook parameters\n", + "# the cache data path during find tuning\n", + "from transformers import BartForConditionalGeneration, BartTokenizer\n", + "from tqdm import tqdm\n", + "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", + "\n", + "model_class = {\n", + " \"bart-large-cnn\": BartForConditionalGeneration,\n", + " \"t5-large\":T5ForConditionalGeneration\n", + "}\n", + "tokenizer_class = {\n", + " \"bart-large-cnn\": BartTokenizer,\n", + " \"t5-large\": T5Tokenizer\n", + "}\n", + "CACHE_DIR = TemporaryDirectory().name\n", + "tokenizer = tokenizer_class[MODEL_NAME].from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) # b\n", + "\n", + "processor = SummarizationProcessor(tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Data Preprocessing\n", + "\n", + "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. \n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### [Option 1] Preprocess data (Please skil this part if you choose to use preprocessed data)\n", + "The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "# the data path used to save the downloaded data file\n", + "DATA_PATH = TemporaryDirectory().name\n", + "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", + "TOP_N = 1000\n", + "if not QUICK_RUN:\n", + " TOP_N = -1" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 489k/489k [00:06<00:00, 70.8kKB/s] \n" + ] + } + ], + "source": [ + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt'])" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_dataset[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1000" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'src': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", + " 'src_txt': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", + " 'tgt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", + " 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"}" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_dataset[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Preprocess the data." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "\n", + "abs_sum_train = processor.preprocess(train_dataset)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'src': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", + " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", + " 'tgt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", + " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", + " 'source_ids': tensor([ 0, 4474, 128, ..., 1, 1, 1]),\n", + " 'src_mask': tensor([1, 1, 1, ..., 0, 0, 0]),\n", + " 'target_ids': tensor([ 0, 28696, 90, 15698, 10072, 4812, 8039, 11, 475, 40879,\n", + " 32, 15740, 15, 5, 45518, 9885, 1929, 12801, 49703, 90,\n", + " 15698, 28696, 90, 15698, 1679, 11235, 2987, 2084, 1594, 397,\n", + " 161, 144, 32, 89, 25, 10, 898, 9, 45518, 1877,\n", + " 868, 14383, 17130, 12801, 49703, 90, 15698, 28696, 90, 15698,\n", + " 150, 740, 15688, 10182, 2122, 2])}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "abs_sum_train[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Exception in thread Thread-24:\n", + "Traceback (most recent call last):\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\", line 864, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 463, in _handle_results\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 251, in recv\n", + " return _ForkingPickler.loads(buf.getbuffer())\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/multiprocessing/reductions.py\", line 294, in rebuild_storage_fd\n", + " fd = df.detach()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/resource_sharer.py\", line 58, in detach\n", + " return reduction.recv_handle(conn)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/reduction.py\", line 182, in recv_handle\n", + " return recvfds(s, 1)[0]\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/reduction.py\", line 161, in recvfds\n", + " len(ancdata))\n", + "RuntimeError: received 0 items of ancdata\n", + "\n", + "Process ForkPoolWorker-71:\n", + "Process ForkPoolWorker-72:\n", + "Process ForkPoolWorker-66:\n", + "Process ForkPoolWorker-62:\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mparallel_preprocess\u001b[0;34m(input_data, preprocess, num_pool)\u001b[0m\n\u001b[1;32m 115\u001b[0m results = p.map(\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m )\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mmap\u001b[0;34m(self, func, iterable, chunksize)\u001b[0m\n\u001b[1;32m 265\u001b[0m '''\n\u001b[0;32m--> 266\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmapstar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 637\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 638\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mready\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 635\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_event\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 295\u001b[0;31m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocessor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mpreprocess\u001b[0;34m(self, input_data_list)\u001b[0m\n\u001b[1;32m 146\u001b[0m )\n\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_preprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mparallel_preprocess\u001b[0;34m(input_data, preprocess, num_pool)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mPool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m results = p.map(\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36m__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_tb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 611\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mterminate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 612\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mterminate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_worker_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_terminate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/util.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wr, _finalizer_registry, sub_debug, getpid)\u001b[0m\n\u001b[1;32m 184\u001b[0m sub_debug('finalizer calling %s with args %s and kwargs %s',\n\u001b[1;32m 185\u001b[0m self._callback, self._args, self._kwargs)\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_weakref\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36m_terminate_pool\u001b[0;34m(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_help_stuff_finish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minqueue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtask_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 573\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mresult_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_alive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcache\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 574\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 575\u001b[0m \u001b[0mresult_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Process ForkPoolWorker-69:\n", + "Process ForkPoolWorker-65:\n", + "Process ForkPoolWorker-70:\n", + "Process ForkPoolWorker-64:\n", + "Process ForkPoolWorker-67:\n", + "Process ForkPoolWorker-68:\n", + "Process ForkPoolWorker-63:\n", + "Process ForkPoolWorker-61:\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + "Traceback (most recent call last):\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + "Traceback (most recent call last):\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + "Traceback (most recent call last):\n", + "Traceback (most recent call last):\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", + " self.run()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", + " self._target(*self._args, **self._kwargs)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", + " with self._rlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", + " put((job, i, result))\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", + " with self._rlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", + " task = get()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", + " with self._wlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 347, in put\n", + " self._writer.send_bytes(obj)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", + " with self._rlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", + " with self._rlock:\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 335, in get\n", + " res = self._reader.recv_bytes()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + "KeyboardInterrupt\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n", + " self._send_bytes(m[offset:offset + size])\n", + "KeyboardInterrupt\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", + " return self._semlock.__enter__()\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 216, in recv_bytes\n", + " buf = self._recv_bytes(maxlength)\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 398, in _send_bytes\n", + " self._send(buf)\n", + "KeyboardInterrupt\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 407, in _recv_bytes\n", + " buf = self._recv(4)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n", + " n = write(self._handle, buf)\n", + " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 379, in _recv\n", + " chunk = read(handle, remaining)\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n", + "KeyboardInterrupt\n" + ] + } + ], + "source": [ + "abs_sum_test = processor.preprocess(test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "\"\"\"\n", + "# save and load preprocessed data\n", + "save_path = os.path.join(DATA_PATH, \"processed\")\n", + "torch.save(ext_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", + "torch.save(ext_sum_test, os.path.join(save_path, \"test_full.pt\"))\n", + "\n", + "\"\"\"\n", + "# ext_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", + "# ext_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(ext_sum_train)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(ext_sum_test)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inspect Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ext_sum_train[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "ext_sum_train[0].keys()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### [Option 2] Reuse Preprocessed data from [BERTSUM Repo](https://github.com/nlpyang/BertSum)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "parameters", + ":w" + ] + }, + "outputs": [], + "source": [ + "# the data path used to downloaded the preprocessed data from BERTSUM Repo.\n", + "# if you have downloaded the dataset, change the code to use that path where the dataset is.\n", + "PROCESSED_DATA_PATH = TemporaryDirectory().name\n", + "os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)\n", + "#data_path = \"./temp_data5/\"\n", + "#PROCESSED_DATA_PATH = data_path" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if USE_PREPROCSSED_DATA:\n", + " download_path = CNNDMBertSumProcessedData.download(local_path=PROCESSED_DATA_PATH)\n", + " ext_sum_train, ext_sum_test = ExtSumProcessedData().splits(root=download_path, train_iterable=True)\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model training\n", + "To start model training, we need to create a instance of ExtractiveSummarizer.\n", + "#### Choose the transformer model.\n", + "Currently ExtractiveSummarizer support two models:\n", + "- distilbert-base-uncase, \n", + "- bert-base-uncase\n", + "\n", + "Potentionally, roberta-based model and xlnet can be supported but needs to be tested.\n", + "#### Choose the encoder algorithm.\n", + "There are four options:\n", + "- baseline: it used a smaller transformer model to replace the bert model and with transformer summarization layer\n", + "- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer\n", + "- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer\n", + "- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "BATCH_SIZE = 5 # batch size, unit is the number of samples\n", + "MAX_POS_LENGTH = 512\n", + "if USE_PREPROCSSED_DATA: #if bertsum published data is used\n", + " BATCH_SIZE = 3000 # batch size, unit is the number of tokens\n", + " MAX_POS_LENGTH = 512\n", + " \n", + "\n", + "\n", + "# GPU used for training\n", + "NUM_GPUS = torch.cuda.device_count()\n", + "\n", + "# Encoder name. Options are: 1. baseline, classifier, transformer, rnn.\n", + "ENCODER = \"transformer\"\n", + "\n", + "# Learning rate\n", + "LEARNING_RATE=2e-3\n", + "\n", + "# How often the statistics reports show up in training, unit is step.\n", + "REPORT_EVERY=100\n", + "\n", + "# total number of steps for training\n", + "MAX_STEPS=1e2\n", + "# number of steps for warm up\n", + "WARMUP_STEPS=5e2\n", + " \n", + "if not QUICK_RUN:\n", + " MAX_STEPS=5e4\n", + " WARMUP_STEPS=5e3\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "#\"\"\"\n", + "\n", + "summarizer.fit(\n", + " ext_sum_train,\n", + " num_gpus=NUM_GPUS,\n", + " batch_size=BATCH_SIZE,\n", + " gradient_accumulation_steps=2,\n", + " max_steps=MAX_STEPS,\n", + " learning_rate=LEARNING_RATE,\n", + " warmup_steps=WARMUP_STEPS,\n", + " verbose=True,\n", + " report_every=REPORT_EVERY,\n", + " clip_grad_norm=False,\n", + " use_preprocessed_data=USE_PREPROCSSED_DATA\n", + " )\n", + "\n", + "#\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "summarizer.save_model(\n", + " os.path.join(\n", + " CACHE_DIR,\n", + " \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\n", + " MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\n", + " ),\n", + " )\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# for loading a previous saved model\n", + "\"\"\"\n", + "import torch\n", + "model_path = os.path.join(\n", + " CACHE_DIR,\n", + " \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\n", + " MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\n", + " ))\n", + "summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)\n", + "summarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\n", + "\"\"\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model Evaluation\n", + "\n", + "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ext_sum_test[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if \"segs\" in ext_sum_test[0]: # preprocessed_data\n", + " source = [i['src_txt'] for i in ext_sum_test]\n", + " target = [\"\\n\".join(i['tgt_txt'].split(\"\")) for i in ext_sum_test]\n", + "else:\n", + " source = []\n", + " temp_target = []\n", + " for i in ext_sum_test:\n", + " source.append(i[\"src_txt\"]) \n", + " temp_target.append(\" \".join(j) for j in i['tgt']) \n", + " target = [''.join(i) for i in list(temp_target)]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "%%time\n", + "sentence_separator = \"\\n\"\n", + "prediction = summarizer.predict(ext_sum_test, num_gpus=NUM_GPUS, batch_size=256, sentence_separator=sentence_separator)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "len(prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "rouge_scores = compute_rouge_python(cand=prediction, ref=target)\n", + "pprint.pprint(rouge_scores)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "target[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source[0]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# for testing\n", + "sb.glue(\"rouge_2_f_score\", rouge_scores['rouge-2']['f'])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prediction on a single input sample" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "source = \"\"\"\n", + "But under the new rule, set to be announced in the next 48 hours, Border Patrol agents would immediately return anyone to Mexico — without any detainment and without any due process — who attempts to cross the southwestern border between the legal ports of entry. The person would not be held for any length of time in an American facility.\n", + "\n", + "Although they advised that details could change before the announcement, administration officials said the measure was needed to avert what they fear could be a systemwide outbreak of the coronavirus inside detention facilities along the border. Such an outbreak could spread quickly through the immigrant population and could infect large numbers of Border Patrol agents, leaving the southwestern border defenses weakened, the officials argued.\n", + "The Trump administration plans to immediately turn back all asylum seekers and other foreigners attempting to enter the United States from Mexico illegally, saying the nation cannot risk allowing the coronavirus to spread through detention facilities and Border Patrol agents, four administration officials said.\n", + "The administration officials said the ports of entry would remain open to American citizens, green-card holders and foreigners with proper documentation. Some foreigners would be blocked, including Europeans currently subject to earlier travel restrictions imposed by the administration. The points of entry will also be open to commercial traffic.\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "test_dataset = SummarizationDataset(\n", + " None,\n", + " source=[source],\n", + " source_preprocessing=[tokenize.sent_tokenize],\n", + " word_tokenize=nltk.word_tokenize,\n", + ")\n", + "processor = ExtSumProcessor(model_name=MODEL_NAME, cache_dir=CACHE_DIR)\n", + "preprocessed_dataset = processor.preprocess(test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preprocessed_dataset[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction = summarizer.predict(preprocessed_dataset, num_gpus=0, batch_size=1, sentence_separator=\"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "prediction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Clean up temporary folders" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if os.path.exists(DATA_PATH):\n", + " shutil.rmtree(DATA_PATH, ignore_errors=True)\n", + "if os.path.exists(CACHE_DIR):\n", + " shutil.rmtree(CACHE_DIR, ignore_errors=True)\n", + "if USE_PREPROCSSED_DATA:\n", + " if os.path.exists(PROCESSED_DATA_PATH):\n", + " shutil.rmtree(PROCESSED_DATA_PATH, ignore_errors=True)" + ] + } + ], + "metadata": { + "celltoolbar": "Tags", + "kernelspec": { + "display_name": "nlp_gpu", + "language": "python", + "name": "nlp_gpu" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utils_nlp/dataset/cnndm.py b/utils_nlp/dataset/cnndm.py index ce20982ef..f1deb5062 100644 --- a/utils_nlp/dataset/cnndm.py +++ b/utils_nlp/dataset/cnndm.py @@ -71,7 +71,7 @@ def CNNDMSummarizationDataset(*args, **kwargs): URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"] def _setup_datasets( - url, top_n=-1, local_cache_path=".data", prepare_extractive=True + url, top_n=-1, local_cache_path=".data", raw=False, prepare_extractive=True ): FILE_NAME = "cnndm.tar.gz" maybe_download(url, FILE_NAME, local_cache_path) @@ -86,6 +86,20 @@ def _setup_datasets( test_source_file = fname if fname.endswith("test.txt.tgt.tagged"): test_target_file = fname + if raw: + return ( + SummarizationDataset( + train_source_file, + target_file=train_target_file, + top_n=top_n + ), + SummarizationDataset( + test_source_file, + target_file=test_target_file, + top_n=top_n + ), + + ) if prepare_extractive: diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py new file mode 100644 index 000000000..bbdf3389c --- /dev/null +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -0,0 +1,604 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +# This script reuses some code from https://github.com/nlpyang/Presumm +# This script reuses some code from https://github.com/huggingface/transformers/ +# Add to noticefile + +from collections import namedtuple +import functools +import logging +from multiprocessing import Pool, cpu_count +import os +import pickle +from tqdm import tqdm + +import torch +from torch.utils.data import ( + DataLoader, + SequentialSampler, + RandomSampler, +) +import torch.multiprocessing +torch.multiprocessing.set_sharing_strategy('file_system') + +from torch.utils.data.distributed import DistributedSampler +from transformers import BertModel + +from utils_nlp.common.pytorch_utils import ( + compute_training_steps, + get_device, + get_amp, + move_model_to_device, + parallelize_model, +) +from utils_nlp.eval import compute_rouge_python +from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer + +#from transformers.modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP + +from transformers import ( + AutoConfig, + AutoModel, + AutoModelWithLMHead, + AutoTokenizer, +) + +MODEL_MODES = { + "language-modeling": AutoModelWithLMHead, +} + +logger = logging.getLogger(__name__) + + +import os + +import torch +from torch.utils.data import Dataset + +from transformers.tokenization_utils import trim_batch + + +def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"): + examples = [] + with open(data_path, "r") as f: + for text in f.readlines(): + tokenized = tokenizer.batch_encode_plus( + [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, + ) + examples.append(tokenized) + return examples + +def encode_example(example, tokenizer=None, max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): + #examples = [] + #with open(data_path, "r") as f: + # for text in f.readlines(): + #for text in text_lines: + ## add to the dataset + tokenized_source = tokenizer.batch_encode_plus( + [example['src']], max_length=max_source_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, + ) + + source_ids = tokenized_source["input_ids"].squeeze() + src_mask = tokenized_source["attention_mask"].squeeze() + example["source_ids"] = source_ids + example["src_mask"] = src_mask + if 'tgt' in example: + tokenized_target = tokenizer.batch_encode_plus( + [example['tgt']], max_length=max_target_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, + ) + target_ids = tokenized_target["input_ids"].squeeze() + example["target_ids"] = target_ids + return example + +def parallel_preprocess(input_data, preprocess, num_pool=-1): + """ + Process data in parallel using multiple GPUs. + + Args: + input_data (list): List if input strings to process. + preprocess (function): function to apply on the input data. + word_tokenize (func, optional): A tokenization function used to tokenize + the results from preprocess_pipeline. + num_pool (int, optional): Number of CPUs to use. Defaults to -1 and all + available CPUs are used. + + Returns: + list: list of processed text strings. + + """ + if num_pool == -1: + num_pool = cpu_count() + + num_pool = min(num_pool, len(input_data)) + + result = None + with Pool(num_pool) as p: + results = p.map( + preprocess, input_data, chunksize=max(1, int(len(input_data) / num_pool)), + ) + + p.close() + #p.join() + + return results + + +class SummarizationProcessor: + def __init__( + self, + tokenizer, + #with_target=False, + max_source_length=1024, + max_target_length=56, + ): + #super().__init__() + self.tokenizer = tokenizer + #self.source = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) + self.with_target = False + self.max_source_length = max_source_length + self.max_target_length = max_target_length + #if with_target: + # self.with_target = True + # self.target = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) + + def preprocess(self, input_data_list): + preprocess = functools.partial( + encode_example, tokenizer=self.tokenizer, max_source_length=self.max_source_length, max_target_length=self.max_target_length + ) + + return parallel_preprocess(input_data_list, preprocess, num_pool=-1) + + @staticmethod + def trim_seq2seq_batch(batch, pad_token_id): + y = trim_batch(batch["target_ids"], pad_token_id) + source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) + return source_ids, source_mask, y + + def collate_fn(self, batch, with_target=False): + input_ids = torch.stack([x["source_ids"] for x in batch]) + masks = torch.stack([x["source_mask"] for x in batch]) + pad_token_id = self.tokenizer.pad_token_id + source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) + if with_target: + target_ids = torch.stack([x["target_ids"] for x in batch]) + y = trim_batch(target_ids, pad_token_id) + return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y} + else: + return {"source_ids": source_ids, "source_mask": source_mask} + + +class AbstractiveSummarizer(Transformer): + """class which performs abstractive summarization fine tuning and + prediction based on BertSumAbs model """ + + def __init__( + self, + model_name="bert-base-uncased", + cache_dir=".", + max_source_length=1024, + max_target_length=240 + ): + """Initialize an object of BertSumAbs. + + Args: + model_name (str, optional:) Name of the pretrained model which is used + to initialize the encoder of the BertSumAbs model. + check MODEL_CLASS for supported models. Defaults to "bert-base-uncased". + cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".". + max_pos_length (int, optional): maximum postional embedding length for the + input. Defaults to 768. + """ + + super().__init__( + model_class=AutoModelWithLMHead, + model_name=model_name, + num_labels=0, + cache_dir=cache_dir, + ) + """ + if model_name not in self.list_supported_models(): + raise ValueError( + "Model name {} is not supported by BertSumAbs. " + "Call 'BertSumAbs.list_supported_models()' to get all supported model " + "names.".format(value) + ) + """ + self.config = AutoConfig.from_pretrained( + model_name, + #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, + #**({"num_labels": num_labels} if num_labels is not None else {}), + cache_dir=cache_dir, + #**config_kwargs, + ) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + #self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, + cache_dir=cache_dir, + ) + + self.model = MODEL_MODES[mode].from_pretrained( + self.model_name, + #from_tf=bool(".ckpt" in self.hparams.model_name_or_path), + config=self.config, + cache_dir=cache_dir, + ) + + self.model_class = AutoModelWithLMHead #MODEL_CLASS[model_name] + self.cache_dir = cache_dir + self.max_source_length = max_source_length + self.max_target_length = max_target_length + + @staticmethod + def list_supported_models(): + return list(MODEL_CLASS.keys()) + + def fit( + self, + train_dataset, + num_gpus=None, + gpu_ids=None, + batch_size=4, + local_rank=-1, + max_steps=5e4, + warmup_steps=20000, + learning_rate=0.002, + weight_decay=0.01, + adam_epsilon=1e-8, + max_grad_norm=1.0, + gradient_accumulation_steps=1, + report_every=10, + save_every=1000, + verbose=True, + seed=None, + fp16=False, + fp16_opt_level="O2", + world_size=1, + rank=0, + validation_function=None, + checkpoint=None, + **kwargs, + ): + """ + Fine-tune pre-trained transofmer models for extractive summarization. + + Args: + train_dataset (SummarizationDataset): Training dataset. + num_gpus (int, optional): The number of GPUs to use. If None, all + available GPUs will be used. If set to 0 or GPUs are not available, + CPU device will be used. Defaults to None. + gpu_ids (list): List of GPU IDs to be used. + If set to None, the first num_gpus GPUs will be used. + Defaults to None. + batch_size (int, optional): Maximum number of tokens in each batch. + local_rank (int, optional): Local_rank for distributed training on GPUs. + Local rank means the ranking of the current GPU device on the current + node. Defaults to -1, which means non-distributed training. + max_steps (int, optional): Maximum number of training steps. Defaults to 5e5. + warmup_steps_bert (int, optional): Number of steps taken to increase + learning rate from 0 to `learning_rate` for tuning the BERT encoder. + Defaults to 2e4. + warmup_steps_dec (int, optional): Number of steps taken to increase + learning rate from 0 to `learning_rate` for tuning the decoder. + Defaults to 1e4. + learning_rate_bert (float, optional): Learning rate of the optimizer + for the encoder. Defaults to 0.002. + learning_rate_dec (float, optional): Learning rate of the optimizer + for the decoder. Defaults to 0.2. + optimization_method (string, optional): Optimization method used in fine + tuning. Defaults to "adam". + max_grad_norm (float, optional): Maximum gradient norm for gradient clipping. + Defaults to 0. + beta1 (float, optional): The exponential decay rate for the first moment + estimates. Defaults to 0.9. + beta2 (float, optional): The exponential decay rate for the second-moment + estimates. This value should be set close to 1.0 on problems with + a sparse gradient. Defaults to 0.99. + decay_method (string, optional): learning rate decrease method. + Default to 'noam'. + gradient_accumulation_steps (int, optional): Number of batches to accumulate + gradients on between each model parameter update. Defaults to 1. + report_every (int, optional): The interval by steps to print out the + training log. Defaults to 10. + save_every (int, optional): The interval by steps to save the finetuned + model. Defaults to 100. + verbose (bool, optional): Whether to print out the training log. + Defaults to True. + seed (int, optional): Random seed used to improve reproducibility. + Defaults to None. + fp16 (bool, optional): Whether to use mixed precision training. + Defaults to False. + fp16_opt_level (str, optional): optimization level, refer to + https://nvidia.github.io/apex/amp.html#opt-levels for details. + Value choices are: "O0", "O1", "O2", "O3". Defaults to "O2". + world_size (int, optional): Total number of GPUs that will be used. + Defaults to 1. + rank (int, optional): Global rank of the current GPU in distributed + training. It's calculated with the rank of the current node in the + cluster/world and the `local_rank` of the device in the current node. + See an example in :file: `examples/text_summarization/ + abstractive_summarization_bertsum_cnndm_distributed_train.py`. + Defaults to 0. + validation_function (function, optional): function used in fitting to + validate the performance. Default to None. + checkpoint (str, optional): file path for a checkpoint based on which the + training continues. Default to None. + """ + + # move model to devices + print("device is {}".format(device)) + if checkpoint: + # checkpoint should have "model", "optimizer", "amp" + checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") + + # init optimizer + device, num_gpus, amp = self.prepare_model_and_optimizer( + num_gpus=num_gpus, + gpu_ids=gpu_ids, + local_rank=local_rank, + fp16=fp16, + fp16_opt_level=fp16_opt_level, + weight_decay=weight_decay, + learning_rate=learning_rate, + adam_epsilon=adam_epsilon, + checkpoint_state_dict=checkpoint_state_dict, + ) + + global_step = 0 + if "global_step" in checkpoint_state_dict and checkpoint_state_dict["global_step"]: + global_step = checkpoint_state_dict["global_step"] / world_size + print("global_step is {}".format(global_step)) + + self.scheduler = Transformer.get_default_scheduler( + optimizer=self.optimizer, + warmup_steps=warmup_steps, + num_training_steps=max_steps, + ) + if global_step > 0: + self.scheduler.load_state_dict(checkpoint_state_dict["lr_scheduler"]) + + if local_rank == -1: + sampler = RandomSampler(train_dataset) + else: + sampler = DistributedSampler( + train_dataset, num_replicas=world_size, rank=rank + ) + + def collate_fn(data): + return self.processor.collate( + data, block_size=self.max_pos_length, device=device + ) + + train_dataloader = DataLoader( + train_dataset, + sampler=sampler, + batch_size=batch_size, + collate_fn=collate_fn, + ) + + # compute the max number of training steps + max_steps = compute_training_steps( + train_dataloader, + max_steps=max_steps, + gradient_accumulation_steps=gradient_accumulation_steps, + ) + + super().fine_tune( + train_dataloader=train_dataloader, + get_inputs=xxxx.get_inputs, + device=device, + num_gpus=num_gpus, + max_steps=max_steps, + global_step=global_step, + max_grad_norm=max_grad_norm, + gradient_accumulation_steps=gradient_accumulation_steps, + verbose=verbose, + seed=seed, + report_every=report_every, + save_every=save_every, + clip_grad_norm=False, + optimizer=optimizers, + scheduler=None, + fp16=fp16, + amp=self.amp, + validation_function=validation_function, + ) + + # release GPU memories + self.model.cpu() + torch.cuda.empty_cache() + + self.save_model(max_steps) + + def predict( + self, + test_dataset, + num_gpus=None, + gpu_ids=None, + local_rank=-1, + batch_size=16, + alpha=0.6, + beam_size=5, + min_length=15, + max_length=150, + fp16=False, + verbose=True, + ): + """ + Predict the summarization for the input data iterator. + + Args: + test_dataset (SummarizationDataset): Dataset for which the summary + to be predicted. + num_gpus (int, optional): The number of GPUs used in prediction. + Defaults to 1. + gpu_ids (list): List of GPU IDs to be used. + If set to None, the first num_gpus GPUs will be used. + Defaults to None. + local_rank (int, optional): Local rank of the device in distributed + inferencing. Defaults to -1, which means non-distributed inferencing. + batch_size (int, optional): The number of test examples in each batch. + Defaults to 16. + alpha (float, optional): Length penalty. Defaults to 0.6. + beam_size (int, optional): Beam size of beam search. Defaults to 5. + min_length (int, optional): Minimum number of tokens in the output sequence. + Defaults to 15. + max_length (int, optional): Maximum number of tokens in output + sequence. Defaults to 150. + fp16 (bool, optional): Whether to use half-precision model for prediction. + Defaults to False. + verbose (bool, optional): Whether to print out the training log. + Defaults to True. + + Returns: + List of strings which are the summaries + + """ + device, num_gpus = get_device( + num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank + ) + + # move model to devices + def this_model_move_callback(model, device): + model = move_model_to_device(model, device) + return parallelize_model( + model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank + ) + + if fp16: + self.model = self.model.half() + + self.model = move_model_to_device(self.model, device) + self.model.eval() + + predictor = build_predictor( + self.processor.tokenizer, + self.processor.symbols, + self.model, + alpha=alpha, + beam_size=beam_size, + min_length=min_length, + max_length=max_length, + ) + predictor = this_model_move_callback(predictor, device) + self.model = parallelize_model( + self.model, + device, + num_gpus=num_gpus, + gpu_ids=gpu_ids, + local_rank=local_rank, + ) + + test_sampler = SequentialSampler(test_dataset) + + def collate_fn(data): + return self.processor.collate( + data, self.max_pos_length, device, train_mode=False + ) + + test_dataloader = DataLoader( + test_dataset, + sampler=test_sampler, + batch_size=batch_size, + collate_fn=collate_fn, + ) + print("dataset length is {}".format(len(test_dataset))) + + def format_summary(translation): + """ Transforms the output of the `from_batch` function + into nicely formatted summaries. + """ + raw_summary = translation + summary = ( + raw_summary.replace("[unused0]", "") + .replace("[unused3]", "") + .replace("[CLS]", "") + .replace("[SEP]", "") + .replace("[PAD]", "") + .replace("[unused1]", "") + .replace(r" +", " ") + .replace(" [unused2] ", ".") + .replace("[unused2]", "") + .strip() + ) + + return summary + + def generate_summary_from_tokenid(preds, pred_score): + batch_size = preds.size()[0] # batch.batch_size + translations = [] + for b in range(batch_size): + if len(preds[b]) < 1: + pred_sents = "" + else: + pred_sents = self.processor.tokenizer.convert_ids_to_tokens( + [int(n) for n in preds[b] if int(n) != 0] + ) + pred_sents = " ".join(pred_sents).replace(" ##", "") + translations.append(pred_sents) + return translations + + generated_summaries = [] + + for batch in tqdm( + test_dataloader, desc="Generating summary", disable=not verbose + ): + input = self.processor.get_inputs(batch, device, "bert", train_mode=False) + translations, scores = predictor(**input) + + translations_text = generate_summary_from_tokenid(translations, scores) + summaries = [format_summary(t) for t in translations_text] + generated_summaries.extend(summaries) + + # release GPU memories + self.model.cpu() + torch.cuda.empty_cache() + + return generated_summaries + + def save_model(self, global_step=None, full_name=None): + """ + save the trained model. + + Args: + global_step (int, optional): The number of steps that the model has been + finetuned for. Defaults to None. + full_name (str, optional): File name to save the model's `state_dict()`. + If it's None, the model is going to be saved under "fine_tuned" folder + of the cached directory of the object. Defaults to None. + """ + model_to_save = ( + self.model.module if hasattr(self.model, "module") else self.model + ) # Take care of distributed/parallel training + + if full_name is None: + output_model_dir = os.path.join(self.cache_dir, "fine_tuned") + os.makedirs(self.cache_dir, exist_ok=True) + os.makedirs(output_model_dir, exist_ok=True) + full_name = os.path.join(output_model_dir, "bertsumabs.pt") + else: + path, filename = os.path.split(full_name) + print(path) + os.makedirs(path, exist_ok=True) + + checkpoint = { + "optimizers": [self.optim_bert.state_dict(), self.optim_dec.state_dict()], + "model": model_to_save.state_dict(), + "amp": self.amp.state_dict() if self.amp else None, + "global_step": global_step, + "max_pos_length": self.max_pos_length, + } + + logger.info("Saving model checkpoint to %s", full_name) + try: + print("saving through pytorch to {}".format(full_name)) + torch.save(checkpoint, full_name) + except OSError: + try: + print("saving as pickle") + pickle.dump(checkpoint, open(full_name, "wb")) + except Exception: + raise + except Exception: + raise diff --git a/utils_nlp/models/transformers/datasets.py b/utils_nlp/models/transformers/datasets.py index 0c659f190..e21e7d95b 100644 --- a/utils_nlp/models/transformers/datasets.py +++ b/utils_nlp/models/transformers/datasets.py @@ -519,7 +519,7 @@ def parallel_preprocess( word_tokenize=word_tokenize, ), input_data, - chunksize=min(1, int(len(input_data) / num_pool)), + chunksize=max(1, int(len(input_data) / num_pool)), ) p.close() p.join() diff --git a/utils_nlp/models/transformers/extractive_summarization.py b/utils_nlp/models/transformers/extractive_summarization.py index 2753685df..e7cb7d84f 100644 --- a/utils_nlp/models/transformers/extractive_summarization.py +++ b/utils_nlp/models/transformers/extractive_summarization.py @@ -302,7 +302,7 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1): p = Pool(num_pool) results = p.map( - preprocess, input_data, chunksize=min(1, int(len(input_data) / num_pool)), + preprocess, input_data, chunksize=max(1, int(len(input_data) / num_pool)), ) p.close() p.join() From b9f3451911fb0ce9985c6189d57279eb1367ffd9 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Thu, 30 Apr 2020 20:19:52 +0000 Subject: [PATCH 03/14] prediction function complete --- ...ive_summarization_cnndm_transformers.ipynb | 966 +++++++++--------- .../abstractive_summarization_bartt5.py | 157 ++- 2 files changed, 552 insertions(+), 571 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index dbdbdc551..914dab416 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -45,17 +45,6 @@ { "cell_type": "code", "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext autoreload\n", - "\n", - "%autoreload 2" - ] - }, - { - "cell_type": "code", - "execution_count": 2, "metadata": { "tags": [ "parameters" @@ -63,6 +52,9 @@ }, "outputs": [], "source": [ + "%load_ext autoreload\n", + "\n", + "%autoreload 2\n", "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", "QUICK_RUN = True\n" ] @@ -76,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "scrolled": true }, @@ -94,7 +86,7 @@ "\n", "from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n", "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n", - "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import SummarizationProcessor\n", + "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import AbstractiveSummarizer, SummarizationProcessor\n", "\n", "from utils_nlp.models.transformers.datasets import SummarizationDataset\n", "import nltk\n", @@ -122,28 +114,16 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'ExtractiveSummarizer' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataFrame\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m\"model_name\"\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mExtractiveSummarizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlist_supported_models\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'ExtractiveSummarizer' is not defined" - ] - } - ], + "outputs": [], "source": [ "#pd.DataFrame({\"model_name\": ExtractiveSummarizer.list_supported_models()})" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 8, "metadata": { "tags": [ "parameters" @@ -152,33 +132,63 @@ "outputs": [], "source": [ "# Transformer model being used\n", - "MODEL_NAME = \"bart-large-cnn\"" + "MODEL_NAME = \"t5-large\"" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "56813bb1ba0e43d598bd19918bef2a80", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b487ea9a9776423faa51b7d3990f8cb4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1200.0, style=ProgressStyle(description…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# notebook parameters\n", "# the cache data path during find tuning\n", - "from transformers import BartForConditionalGeneration, BartTokenizer\n", - "from tqdm import tqdm\n", - "from transformers import T5ForConditionalGeneration, T5Tokenizer\n", - "\n", - "model_class = {\n", - " \"bart-large-cnn\": BartForConditionalGeneration,\n", - " \"t5-large\":T5ForConditionalGeneration\n", - "}\n", - "tokenizer_class = {\n", - " \"bart-large-cnn\": BartTokenizer,\n", - " \"t5-large\": T5Tokenizer\n", - "}\n", "CACHE_DIR = TemporaryDirectory().name\n", - "tokenizer = tokenizer_class[MODEL_NAME].from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR) # b\n", - "\n", - "processor = SummarizationProcessor(tokenizer)" + " \n", + "processor = SummarizationProcessor(MODEL_NAME,cache_dir=CACHE_DIR ) #tokenizer, config.prefix)" ] }, { @@ -187,20 +197,12 @@ "source": [ "### Data Preprocessing\n", "\n", - "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. \n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### [Option 1] Preprocess data (Please skil this part if you choose to use preprocessed data)\n", - "The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/." + "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 12, "metadata": { "tags": [ "parameters" @@ -218,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": { "scrolled": false }, @@ -227,7 +229,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 489k/489k [00:06<00:00, 70.8kKB/s] \n" + "100%|██████████| 489k/489k [00:08<00:00, 56.3kKB/s] \n" ] } ], @@ -235,69 +237,6 @@ "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" ] }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt'])" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_dataset[0].keys()" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "1000" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(test_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'src': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", - " 'src_txt': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", - " 'tgt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", - " 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"}" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_dataset[0]" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -307,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 14, "metadata": { "scrolled": false }, @@ -315,285 +254,27 @@ "source": [ "\n", "abs_sum_train = processor.preprocess(train_dataset)\n", - "\n" + "abs_sum_test = processor.preprocess(test_dataset)\n" ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": {}, + "execution_count": 9, + "metadata": { + "scrolled": true + }, "outputs": [ { "data": { "text/plain": [ - "{'src': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", - " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", - " 'tgt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'source_ids': tensor([ 0, 4474, 128, ..., 1, 1, 1]),\n", - " 'src_mask': tensor([1, 1, 1, ..., 0, 0, 0]),\n", - " 'target_ids': tensor([ 0, 28696, 90, 15698, 10072, 4812, 8039, 11, 475, 40879,\n", - " 32, 15740, 15, 5, 45518, 9885, 1929, 12801, 49703, 90,\n", - " 15698, 28696, 90, 15698, 1679, 11235, 2987, 2084, 1594, 397,\n", - " 161, 144, 32, 89, 25, 10, 898, 9, 45518, 1877,\n", - " 868, 14383, 17130, 12801, 49703, 90, 15698, 28696, 90, 15698,\n", - " 150, 740, 15688, 10182, 2122, 2])}" + "'\\n# save and load preprocessed data\\nsave_path = os.path.join(DATA_PATH, \"processed\")\\ntorch.save(ext_sum_train, os.path.join(save_path, \"train_full.pt\"))\\ntorch.save(ext_sum_test, os.path.join(save_path, \"test_full.pt\"))\\n\\n'" ] }, - "execution_count": 11, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], - "source": [ - "abs_sum_train[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Exception in thread Thread-24:\n", - "Traceback (most recent call last):\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\", line 916, in _bootstrap_inner\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\", line 864, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 463, in _handle_results\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 251, in recv\n", - " return _ForkingPickler.loads(buf.getbuffer())\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/multiprocessing/reductions.py\", line 294, in rebuild_storage_fd\n", - " fd = df.detach()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/resource_sharer.py\", line 58, in detach\n", - " return reduction.recv_handle(conn)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/reduction.py\", line 182, in recv_handle\n", - " return recvfds(s, 1)[0]\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/reduction.py\", line 161, in recvfds\n", - " len(ancdata))\n", - "RuntimeError: received 0 items of ancdata\n", - "\n", - "Process ForkPoolWorker-71:\n", - "Process ForkPoolWorker-72:\n", - "Process ForkPoolWorker-66:\n", - "Process ForkPoolWorker-62:\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mparallel_preprocess\u001b[0;34m(input_data, preprocess, num_pool)\u001b[0m\n\u001b[1;32m 115\u001b[0m results = p.map(\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m )\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mmap\u001b[0;34m(self, func, iterable, chunksize)\u001b[0m\n\u001b[1;32m 265\u001b[0m '''\n\u001b[0;32m--> 266\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_map_async\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfunc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmapstar\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 267\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mget\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 637\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mget\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 638\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 639\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mready\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 635\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_event\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 295\u001b[0;31m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: ", - "\nDuring handling of the above exception, another exception occurred:\n", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_test\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocessor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpreprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtest_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mpreprocess\u001b[0;34m(self, input_data_list)\u001b[0m\n\u001b[1;32m 146\u001b[0m )\n\u001b[1;32m 147\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 148\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_preprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data_list\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 149\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0mstaticmethod\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mparallel_preprocess\u001b[0;34m(input_data, preprocess, num_pool)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mPool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m results = p.map(\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mpreprocess\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput_data\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mchunksize\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mnum_pool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m )\n\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36m__exit__\u001b[0;34m(self, exc_type, exc_val, exc_tb)\u001b[0m\n\u001b[1;32m 609\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 610\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__exit__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_type\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_val\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexc_tb\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 611\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mterminate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 612\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 613\u001b[0m \u001b[0;31m#\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36mterminate\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 539\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_worker_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 541\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_terminate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 543\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/util.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, wr, _finalizer_registry, sub_debug, getpid)\u001b[0m\n\u001b[1;32m 184\u001b[0m sub_debug('finalizer calling %s with args %s and kwargs %s',\n\u001b[1;32m 185\u001b[0m self._callback, self._args, self._kwargs)\n\u001b[0;32m--> 186\u001b[0;31m \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 187\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_weakref\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_callback\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_args\u001b[0m \u001b[0;34m=\u001b[0m\u001b[0;31m \u001b[0m\u001b[0;31m\\\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 188\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_kwargs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_key\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\u001b[0m in \u001b[0;36m_terminate_pool\u001b[0;34m(cls, taskqueue, inqueue, outqueue, pool, worker_handler, task_handler, result_handler, cache)\u001b[0m\n\u001b[1;32m 571\u001b[0m \u001b[0mcls\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_help_stuff_finish\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minqueue\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtask_handler\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpool\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 572\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 573\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mresult_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mis_alive\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mor\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcache\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 574\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 575\u001b[0m \u001b[0mresult_handler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_state\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mTERMINATE\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Process ForkPoolWorker-69:\n", - "Process ForkPoolWorker-65:\n", - "Process ForkPoolWorker-70:\n", - "Process ForkPoolWorker-64:\n", - "Process ForkPoolWorker-67:\n", - "Process ForkPoolWorker-68:\n", - "Process ForkPoolWorker-63:\n", - "Process ForkPoolWorker-61:\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - "Traceback (most recent call last):\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - "Traceback (most recent call last):\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - "Traceback (most recent call last):\n", - "Traceback (most recent call last):\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 258, in _bootstrap\n", - " self.run()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/process.py\", line 93, in run\n", - " self._target(*self._args, **self._kwargs)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", - " with self._rlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 125, in worker\n", - " put((job, i, result))\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", - " with self._rlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/pool.py\", line 108, in worker\n", - " task = get()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 346, in put\n", - " with self._wlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 347, in put\n", - " self._writer.send_bytes(obj)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", - " with self._rlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 334, in get\n", - " with self._rlock:\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/queues.py\", line 335, in get\n", - " res = self._reader.recv_bytes()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - "KeyboardInterrupt\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 200, in send_bytes\n", - " self._send_bytes(m[offset:offset + size])\n", - "KeyboardInterrupt\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/synchronize.py\", line 95, in __enter__\n", - " return self._semlock.__enter__()\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 216, in recv_bytes\n", - " buf = self._recv_bytes(maxlength)\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 398, in _send_bytes\n", - " self._send(buf)\n", - "KeyboardInterrupt\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 407, in _recv_bytes\n", - " buf = self._recv(4)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 368, in _send\n", - " n = write(self._handle, buf)\n", - " File \"/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/multiprocessing/connection.py\", line 379, in _recv\n", - " chunk = read(handle, remaining)\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n", - "KeyboardInterrupt\n" - ] - } - ], - "source": [ - "abs_sum_test = processor.preprocess(test_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": true - }, - "outputs": [], "source": [ "\"\"\"\n", "# save and load preprocessed data\n", @@ -608,20 +289,21 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "len(ext_sum_train)" - ] - }, - { - "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1000\n", + "1000\n" + ] + } + ], "source": [ - "len(ext_sum_test)" + "print(len(abs_sum_train))\n", + "print(len(abs_sum_test))" ] }, { @@ -633,60 +315,55 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "ext_sum_train[0]" + "abs_sum_train[0].keys()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 12, "metadata": { "scrolled": false }, - "outputs": [], - "source": [ - "ext_sum_train[0].keys()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "##### [Option 2] Reuse Preprocessed data from [BERTSUM Repo](https://github.com/nlpyang/BertSum)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [ - "parameters", - ":w" - ] - }, - "outputs": [], - "source": [ - "# the data path used to downloaded the preprocessed data from BERTSUM Repo.\n", - "# if you have downloaded the dataset, change the code to use that path where the dataset is.\n", - "PROCESSED_DATA_PATH = TemporaryDirectory().name\n", - "os.makedirs(PROCESSED_DATA_PATH, exist_ok=True)\n", - "#data_path = \"./temp_data5/\"\n", - "#PROCESSED_DATA_PATH = data_path" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'src': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", + " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", + " 'tgt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", + " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", + " 'source_ids': tensor([6005, 3, 31, ..., 307, 18, 1987]),\n", + " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", + " 'target_ids': tensor([ 3, 2, 17, 3155, 19367, 3, 1092, 16, 11171, 16,\n", + " 1337, 3690, 33, 629, 26, 30, 8, 3, 2, 11821,\n", + " 1501, 3, 31, 31, 3, 2, 87, 17, 3155, 3,\n", + " 2, 17, 3155, 5191, 3, 849, 1926, 90, 99, 348,\n", + " 845, 167, 33, 132, 38, 3, 9, 741, 13, 3,\n", + " 2, 1792, 179, 3110, 106, 725])}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "if USE_PREPROCSSED_DATA:\n", - " download_path = CNNDMBertSumProcessedData.download(local_path=PROCESSED_DATA_PATH)\n", - " ext_sum_train, ext_sum_test = ExtSumProcessedData().splits(root=download_path, train_iterable=True)\n", - " " + "abs_sum_train[0]" ] }, { @@ -711,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 32, "metadata": { "tags": [ "parameters" @@ -721,17 +398,13 @@ "source": [ "BATCH_SIZE = 5 # batch size, unit is the number of samples\n", "MAX_POS_LENGTH = 512\n", - "if USE_PREPROCSSED_DATA: #if bertsum published data is used\n", - " BATCH_SIZE = 3000 # batch size, unit is the number of tokens\n", - " MAX_POS_LENGTH = 512\n", + "\n", " \n", "\n", "\n", "# GPU used for training\n", "NUM_GPUS = torch.cuda.device_count()\n", "\n", - "# Encoder name. Options are: 1. baseline, classifier, transformer, rnn.\n", - "ENCODER = \"transformer\"\n", "\n", "# Learning rate\n", "LEARNING_RATE=2e-3\n", @@ -752,24 +425,42 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": { "scrolled": true }, "outputs": [], "source": [ - "summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)" + "summarizer = AbstractiveSummarizer(processor, MODEL_NAME)" ] }, { "cell_type": "code", "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 15, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\n\\nsummarizer.fit(\\n ext_sum_train,\\n num_gpus=NUM_GPUS,\\n batch_size=BATCH_SIZE,\\n gradient_accumulation_steps=2,\\n max_steps=MAX_STEPS,\\n learning_rate=LEARNING_RATE,\\n warmup_steps=WARMUP_STEPS,\\n verbose=True,\\n report_every=REPORT_EVERY,\\n clip_grad_norm=False,\\n use_preprocessed_data=USE_PREPROCSSED_DATA\\n )\\n\\n'" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "#\"\"\"\n", + "\"\"\"\n", "\n", "summarizer.fit(\n", " ext_sum_train,\n", @@ -785,15 +476,27 @@ " use_preprocessed_data=USE_PREPROCSSED_DATA\n", " )\n", "\n", - "#\"\"\"\n" + "\"\"\"\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 16, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nsummarizer.save_model(\\n os.path.join(\\n CACHE_DIR,\\n \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\\n MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\\n ),\\n )\\n)\\n'" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ + "\"\"\"\n", "summarizer.save_model(\n", " os.path.join(\n", " CACHE_DIR,\n", @@ -801,14 +504,26 @@ " MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\n", " ),\n", " )\n", - ")" + ")\n", + "\"\"\"\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'\\nimport torch\\nmodel_path = os.path.join(\\n CACHE_DIR,\\n \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\\n MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\\n ))\\nsummarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)\\nsummarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\\n'" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# for loading a previous saved model\n", "\"\"\"\n", @@ -834,58 +549,351 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "ext_sum_test[0].keys()" + "abs_sum_test[0].keys()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ - "if \"segs\" in ext_sum_test[0]: # preprocessed_data\n", - " source = [i['src_txt'] for i in ext_sum_test]\n", - " target = [\"\\n\".join(i['tgt_txt'].split(\"\")) for i in ext_sum_test]\n", - "else:\n", - " source = []\n", - " temp_target = []\n", - " for i in ext_sum_test:\n", - " source.append(i[\"src_txt\"]) \n", - " temp_target.append(\" \".join(j) for j in i['tgt']) \n", - " target = [''.join(i) for i in list(temp_target)]" + "source = []\n", + "target = []\n", + "for i in abs_sum_test:\n", + " source.append(i[\"src_txt\"]) \n", + " target.append(i['tgt']) " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "target[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 19, "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Generating summary: 0%| | 0/2 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mrouge_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_rouge_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcand\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mref\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrouge_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/eval/rouge/compute_rouge.py\u001b[0m in \u001b[0;36mcompute_rouge_python\u001b[0;34m(cand, ref, is_input_files, language)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Number of candidates: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcandidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Number of references: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreferences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcandidates\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreferences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlanguage\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"en\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], "source": [ "rouge_scores = compute_rouge_python(cand=prediction, ref=target)\n", "pprint.pprint(rouge_scores)" diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index bbdf3389c..97c5ff8a7 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -33,7 +33,11 @@ parallelize_model, ) from utils_nlp.eval import compute_rouge_python -from utils_nlp.models.transformers.common import TOKENIZER_CLASS, Transformer +from utils_nlp.models.transformers.common import Transformer +# from utils_nlp.models.transformers.common import TOKENIZER_CLASS + + + #from transformers.modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP @@ -48,41 +52,37 @@ "language-modeling": AutoModelWithLMHead, } +from transformers import BartForConditionalGeneration, BartTokenizer +from transformers import T5ForConditionalGeneration, T5Tokenizer + +MODEL_CLASS = { + "bart-large-cnn": BartForConditionalGeneration, + "t5-large":T5ForConditionalGeneration +} +TOKENIZER_CLASS = { + "bart-large-cnn": BartTokenizer, + "t5-large": T5Tokenizer +} + logger = logging.getLogger(__name__) import os - import torch from torch.utils.data import Dataset from transformers.tokenization_utils import trim_batch - -def encode_file(tokenizer, data_path, max_length, pad_to_max_length=True, return_tensors="pt"): - examples = [] - with open(data_path, "r") as f: - for text in f.readlines(): - tokenized = tokenizer.batch_encode_plus( - [text], max_length=max_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, - ) - examples.append(tokenized) - return examples - -def encode_example(example, tokenizer=None, max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): - #examples = [] - #with open(data_path, "r") as f: - # for text in f.readlines(): - #for text in text_lines: +def encode_example(example, tokenizer=None, prefix="", max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): ## add to the dataset tokenized_source = tokenizer.batch_encode_plus( - [example['src']], max_length=max_source_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, + [prefix + example['src']], max_length=max_source_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, ) source_ids = tokenized_source["input_ids"].squeeze() src_mask = tokenized_source["attention_mask"].squeeze() example["source_ids"] = source_ids - example["src_mask"] = src_mask + example["source_mask"] = src_mask if 'tgt' in example: tokenized_target = tokenizer.batch_encode_plus( [example['tgt']], max_length=max_target_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, @@ -127,13 +127,27 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1): class SummarizationProcessor: def __init__( self, - tokenizer, - #with_target=False, + model_name, + cache_dir="./", max_source_length=1024, max_target_length=56, ): #super().__init__() - self.tokenizer = tokenizer + self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(model_name, cache_dir=cache_dir) # b + config = AutoConfig.from_pretrained( + model_name, + #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, + #**({"num_labels": num_labels} if num_labels is not None else {}), + cache_dir=cache_dir, + #**config_kwargs, + ) + if model_name.startswith("t5"): + # update config with summarization specific params + task_specific_params = config.task_specific_params + if task_specific_params is not None: + config.update(task_specific_params.get("summarization", {})) + + self.prefix = config.prefix #self.source = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) self.with_target = False self.max_source_length = max_source_length @@ -144,7 +158,7 @@ def __init__( def preprocess(self, input_data_list): preprocess = functools.partial( - encode_example, tokenizer=self.tokenizer, max_source_length=self.max_source_length, max_target_length=self.max_target_length + encode_example, tokenizer=self.tokenizer, prefix=self.prefix, max_source_length=self.max_source_length, max_target_length=self.max_target_length ) return parallel_preprocess(input_data_list, preprocess, num_pool=-1) @@ -155,17 +169,17 @@ def trim_seq2seq_batch(batch, pad_token_id): source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) return source_ids, source_mask, y - def collate_fn(self, batch, with_target=False): + def collate_fn(self, batch, device, train_mode=False): input_ids = torch.stack([x["source_ids"] for x in batch]) masks = torch.stack([x["source_mask"] for x in batch]) pad_token_id = self.tokenizer.pad_token_id source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) - if with_target: + if train_mode: target_ids = torch.stack([x["target_ids"] for x in batch]) y = trim_batch(target_ids, pad_token_id) - return {"source_ids": source_ids, "source_mask": source_mask, "target_ids": y} + return {"source_ids": source_ids.to(device), "source_mask": source_mask.to(device), "target_ids": y.to(device)} else: - return {"source_ids": source_ids, "source_mask": source_mask} + return {"source_ids": source_ids.to(device), "source_mask": source_mask.to(device)} class AbstractiveSummarizer(Transformer): @@ -174,7 +188,8 @@ class AbstractiveSummarizer(Transformer): def __init__( self, - model_name="bert-base-uncased", + processor, + model_name="bart-large-cnn", cache_dir=".", max_source_length=1024, max_target_length=240 @@ -190,13 +205,14 @@ def __init__( input. Defaults to 768. """ - super().__init__( + """super().__init__( model_class=AutoModelWithLMHead, model_name=model_name, num_labels=0, cache_dir=cache_dir, ) """ + """ if model_name not in self.list_supported_models(): raise ValueError( "Model name {} is not supported by BertSumAbs. " @@ -204,6 +220,7 @@ def __init__( "names.".format(value) ) """ + self.processor = processor self.config = AutoConfig.from_pretrained( model_name, #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, @@ -218,7 +235,8 @@ def __init__( cache_dir=cache_dir, ) - self.model = MODEL_MODES[mode].from_pretrained( + self._model_name = model_name + self.model = MODEL_MODES["language-modeling"].from_pretrained( self.model_name, #from_tf=bool(".ckpt" in self.hparams.model_name_or_path), config=self.config, @@ -230,9 +248,10 @@ def __init__( self.max_source_length = max_source_length self.max_target_length = max_target_length + @staticmethod def list_supported_models(): - return list(MODEL_CLASS.keys()) + return list(MODEL_CLASS) def fit( self, @@ -459,29 +478,12 @@ def predict( num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank ) - # move model to devices - def this_model_move_callback(model, device): - model = move_model_to_device(model, device) - return parallelize_model( - model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank - ) - if fp16: self.model = self.model.half() self.model = move_model_to_device(self.model, device) self.model.eval() - predictor = build_predictor( - self.processor.tokenizer, - self.processor.symbols, - self.model, - alpha=alpha, - beam_size=beam_size, - min_length=min_length, - max_length=max_length, - ) - predictor = this_model_move_callback(predictor, device) self.model = parallelize_model( self.model, device, @@ -493,8 +495,8 @@ def this_model_move_callback(model, device): test_sampler = SequentialSampler(test_dataset) def collate_fn(data): - return self.processor.collate( - data, self.max_pos_length, device, train_mode=False + return self.processor.collate_fn( + data, device, train_mode=False ) test_dataloader = DataLoader( @@ -504,52 +506,23 @@ def collate_fn(data): collate_fn=collate_fn, ) print("dataset length is {}".format(len(test_dataset))) - - def format_summary(translation): - """ Transforms the output of the `from_batch` function - into nicely formatted summaries. - """ - raw_summary = translation - summary = ( - raw_summary.replace("[unused0]", "") - .replace("[unused3]", "") - .replace("[CLS]", "") - .replace("[SEP]", "") - .replace("[PAD]", "") - .replace("[unused1]", "") - .replace(r" +", " ") - .replace(" [unused2] ", ".") - .replace("[unused2]", "") - .strip() - ) - - return summary - - def generate_summary_from_tokenid(preds, pred_score): - batch_size = preds.size()[0] # batch.batch_size - translations = [] - for b in range(batch_size): - if len(preds[b]) < 1: - pred_sents = "" - else: - pred_sents = self.processor.tokenizer.convert_ids_to_tokens( - [int(n) for n in preds[b] if int(n) != 0] - ) - pred_sents = " ".join(pred_sents).replace(" ##", "") - translations.append(pred_sents) - return translations - generated_summaries = [] for batch in tqdm( test_dataloader, desc="Generating summary", disable=not verbose ): - input = self.processor.get_inputs(batch, device, "bert", train_mode=False) - translations, scores = predictor(**input) - - translations_text = generate_summary_from_tokenid(translations, scores) - summaries = [format_summary(t) for t in translations_text] - generated_summaries.extend(summaries) + #if self.model_name.startswith("t5"): + # batch = [self.model.config.prefix + text for text in batch] + #dct = self.tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) + print(batch) + summaries = self.model.module.generate( + input_ids=batch["source_ids"], + attention_mask=batch["source_mask"], + min_length=min_length, + max_length=max_length + ) + decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + generated_summaries.extend(decoded_summaries) # release GPU memories self.model.cpu() From ee334d9fdc88bfefb56955480bc19fab581155a1 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Fri, 1 May 2020 18:32:22 +0000 Subject: [PATCH 04/14] save work --- ...ive_summarization_cnndm_transformers.ipynb | 32 ++++++++++------ .../abstractive_summarization_bartt5.py | 38 +++++++++++++++++++ 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index 914dab416..2937be122 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -68,7 +68,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 2, "metadata": { "scrolled": true }, @@ -123,7 +123,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 3, "metadata": { "tags": [ "parameters" @@ -137,13 +137,13 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "56813bb1ba0e43d598bd19918bef2a80", + "model_id": "c8a2c62de1354321a25c5c1d5fce2b36", "version_major": 2, "version_minor": 0 }, @@ -164,7 +164,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "b487ea9a9776423faa51b7d3990f8cb4", + "model_id": "de2a32d4c70844cd913cb5f8cbf833e7", "version_major": 2, "version_minor": 0 }, @@ -202,7 +202,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 5, "metadata": { "tags": [ "parameters" @@ -220,7 +220,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 6, "metadata": { "scrolled": false }, @@ -229,7 +229,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 489k/489k [00:08<00:00, 56.3kKB/s] \n" + "100%|██████████| 489k/489k [00:08<00:00, 59.9kKB/s] \n" ] } ], @@ -246,7 +246,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 7, "metadata": { "scrolled": false }, @@ -388,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 8, "metadata": { "tags": [ "parameters" @@ -425,7 +425,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 11, "metadata": { "scrolled": true }, @@ -600,6 +600,16 @@ "target[0]" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + " " + ] + }, { "cell_type": "code", "execution_count": 19, diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index 97c5ff8a7..5965f7f33 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -21,6 +21,7 @@ ) import torch.multiprocessing torch.multiprocessing.set_sharing_strategy('file_system') +from torch import nn from torch.utils.data.distributed import DistributedSampler from transformers import BertModel @@ -123,6 +124,30 @@ def parallel_preprocess(input_data, preprocess, num_pool=-1): return results +class Predictor(nn.Module): + def __init__( + self, + model, + tokenizer, + min_length, + max_length): + super(Translator, self).__init__() + self.model = model.module if hasattr(model, "module") else model + self.tokenizer = tokenizer + self.min_length = min_length + self.max_length = max_length + + def forward(self, src, src_mask): + device = src.device + with torch.no_grad(): + summaries = self.model.generate( + input_ids=src, + attention_mask=src_mask, + min_length=min_length, + max_length=max_length + ) + decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + return decoded_summaries class SummarizationProcessor: def __init__( @@ -506,6 +531,16 @@ def collate_fn(data): collate_fn=collate_fn, ) print("dataset length is {}".format(len(test_dataset))) + + predictor = Predictor(self.model, self.tokenizer, min_length, max_length) + # move model to devices + def this_model_move_callback(model, device): + model = move_model_to_device(model, device) + return parallelize_model( + model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank + ) + predictor = this_model_move_callback(predictor, device) + generated_summaries = [] for batch in tqdm( @@ -515,6 +550,8 @@ def collate_fn(data): # batch = [self.model.config.prefix + text for text in batch] #dct = self.tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) print(batch) + decoded_summaries = predictor(batch["source_ids"], batch["source_mask"]) + """ summaries = self.model.module.generate( input_ids=batch["source_ids"], attention_mask=batch["source_mask"], @@ -522,6 +559,7 @@ def collate_fn(data): max_length=max_length ) decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + """ generated_summaries.extend(decoded_summaries) # release GPU memories From cef3da356ff08485ea56f29c642e46cac4f4f318 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Fri, 1 May 2020 20:48:41 +0000 Subject: [PATCH 05/14] enable multi-GPU inference --- ...ive_summarization_cnndm_transformers.ipynb | 482 +++++++++++++++--- .../abstractive_summarization_bartt5.py | 29 +- 2 files changed, 436 insertions(+), 75 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index 2937be122..c75cc4da7 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -72,7 +72,38 @@ "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", + "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", + " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" + ] + } + ], "source": [ "import os\n", "import shutil\n", @@ -123,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": { "tags": [ "parameters" @@ -137,18 +168,18 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "c8a2c62de1354321a25c5c1d5fce2b36", + "model_id": "e18ad04d5d6c472c88c306c15e6f16c4", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" + "HBox(children=(IntProgress(value=0, description='Downloading', max=791656, style=ProgressStyle(description_wid…" ] }, "metadata": {}, @@ -164,12 +195,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "de2a32d4c70844cd913cb5f8cbf833e7", + "model_id": "9b49de228e3e4d619182ce0081b6d953", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1200.0, style=ProgressStyle(description…" + "HBox(children=(IntProgress(value=0, description='Downloading', max=1200, style=ProgressStyle(description_width…" ] }, "metadata": {}, @@ -202,7 +233,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": { "tags": [ "parameters" @@ -220,7 +251,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": { "scrolled": false }, @@ -229,7 +260,7 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|██████████| 489k/489k [00:08<00:00, 59.9kKB/s] \n" + "100%|██████████| 489k/489k [00:07<00:00, 63.4kKB/s] \n" ] } ], @@ -246,7 +277,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": { "scrolled": false }, @@ -347,7 +378,7 @@ " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", " 'tgt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'source_ids': tensor([6005, 3, 31, ..., 307, 18, 1987]),\n", + " 'source_ids': tensor([21603, 10, 6005, ..., 2027, 2957, 307]),\n", " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", " 'target_ids': tensor([ 3, 2, 17, 3155, 19367, 3, 1092, 16, 11171, 16,\n", " 1337, 3690, 33, 629, 26, 30, 8, 3, 2, 11821,\n", @@ -388,7 +419,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 13, "metadata": { "tags": [ "parameters" @@ -425,7 +456,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 14, "metadata": { "scrolled": true }, @@ -602,27 +633,17 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], - "source": [ - "\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": { - "scrolled": false - }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\r", - "Generating summary: 0%| | 0/2 [00:00'" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prediction[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -704,7 +1015,7 @@ "20" ] }, - "execution_count": 22, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -713,6 +1024,45 @@ "len(prediction)" ] }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[\"marseille , france - cnn - the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that so far no videos were used in the crash investigation . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of\",\n", + " \"the formal accession was marked with a ceremony at the hague , in the netherlands , where the court is based . the palestinians signed the icc's founding rome statute in january . as members of the court , palestinians may be subject to counter-charges as well .\",\n", + " \"amnesty international's annual report catalogs the use of state-sanctioned killing as a punitive measure across the globe . at least 607 people were executed around the world in 2014 . at least 2,466 people globally are confirmed to have been handed the sentence .\",\n", + " 'amnesty international releases annual review of death penalty worldwide . death sentences up 500% on previous year, mostly because of pakistan . countries using death penalty to tackle crime and terrorism are deceiving themselves .',\n", + " 'anne frank died of typhus in a nazi concentration camp at the age of 15 . new research shows that she and her older sister probably did not survive to march 1945 . the exact dates of death for anne and margot remain unclear .',\n", + " \"a duke student has admitted to hanging a noose made of rope from a tree near a student union , university officials said thursday . the prestigious private school did n't identify the student , citing federal privacy laws . in a news release , it said the student was no longer on campus and will face student conduct review .\",\n", + " \"the rev. robert h. schuller , california televangelist and founder of the television ministry hour of power , '' died thursday , according to his family . he was 88 years old . schuller , also the founder of crystal cathedral megachurch , had been diagnosed with esophageal cancer in august 2013 .\",\n", + " 'the dog was hit by a car, apparently killed with a hammer and buried in a field . she staggered to a nearby farm, dirt-covered and emaciated . the dog was found by a worker who took her to a vet for help . she suffered a dislocated jaw, leg injuries and a caved-in sinus cavity .',\n", + " \"mohammad javad zarif is the iranian foreign minister . he has been john kerry's opposite number in securing nuclear breakthrough . zarif is 54, but his official biography says he was born in 1960 . he was investigated by the feds over his alleged role in controlling a charitable organization .\",\n", + " \"for the first time in eight years , a tv legend returned to doing what he does best . contestants told to come on down ! '' on the april 1 edition of the price is right '' encountered not host drew carey but another familiar face in charge of the proceedings . instead , there was bob barker , who hosted the tv game show for 35 years before stepping down in 2007 .\",\n", + " \"-lrb- he 's a blue chip college basketball recruit . she 's a high school freshman with down syndrome . at first glance trey moses and ellie meredith could n't be more different . but all that changed thursday when trey asked ellie to be his prom date .\",\n", + " \"michele bachmann compared president obama to the co-pilot of the doomed germanwings flight . ''with his iran deal , barack obama is for the 300 million souls of the united states,'' she wrote in a facebook comment posted march 31 . many comments posted on her facebook page blasted the former representative .\",\n", + " 'california is a breadbasket to the nation . california is growing more than a third of its vegetables and nearly two-thirds of its fruits and nuts . the drought is in its fourth year .',\n", + " \"walmart's staunch criticism of a religious freedom law in its home state of arkansas came after the company said in february it would boost pay for about 500,000 workers well above the federal minimum wage . the company is emerging as a bellwether for shifting public opinion on hot-button political issues that divide conservatives and liberals . former minnesota gov. tim pawlenty said walmart's actions foreshadow where the republican party will need to move \",\n", + " \"five americans who were monitored for three weeks at an omaha hospital have been released . one of the five had a heart-related issue on saturday and has been discharged but has n't left the area . they were exposed to ebola in sierra leone in march .\",\n", + " \"andrew getty, 47, was found dead in his los angeles home . the coroner's preliminary assessment is there was no foul play involved in the death of getty . getty , grandson of oil tycoon j. paul getty , was found dead near a bathroom in his home .\",\n", + " \"mike pence signed a religious freedom law last week that opens the door to discrimination against gays and lesbians . pence: 'i foolishly hoped this kind of backlash . ''there is no way a republican can get through the pending primary without denouncing lgbt rights,'' he says .\",\n", + " \"filipinos are being warned to be on guard for flash floods and landslides as tropical storm maysak approached the asian island nation saturday . just a few days ago , maysak gained super typhoon status thanks to its sustained 150 mph winds . it has since lost a lot of steam as it has spun west in the pacific ocean . it 's now classified as a tropical storm , according to the philippine national weather\",\n", + " \"norfolk , virginia - the second mate of the houston express probably could n't believe what he was seeing . hundreds of miles from land there was a small boat nearby . at first it looked abandoned . it was in bad shape , listing to one side . the crew of the 1,000-foot long container ship thought it was a yacht that had wrecked . incredibly , as they got closer , they saw there was a man on it \",\n", + " \"walker died in november 2013 after a fiery car crash . the release of furious 7 '' on friday offers fans the opportunity to remember -- and possibly grieve again -- the man that so many have praised as one of the nicest guys in hollywood .\"]" + ] + }, + "execution_count": 30, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prediction" + ] + }, { "cell_type": "code", "execution_count": 20, @@ -757,9 +1107,7 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "len()" - ] + "source": [] }, { "cell_type": "code", @@ -1036,7 +1384,7 @@ "metadata": { "celltoolbar": "Tags", "kernelspec": { - "display_name": "nlp_gpu", + "display_name": "Python (nlp_gpu)", "language": "python", "name": "nlp_gpu" }, diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index 5965f7f33..813bcd230 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -131,7 +131,7 @@ def __init__( tokenizer, min_length, max_length): - super(Translator, self).__init__() + super(Predictor, self).__init__() self.model = model.module if hasattr(model, "module") else model self.tokenizer = tokenizer self.min_length = min_length @@ -143,11 +143,23 @@ def forward(self, src, src_mask): summaries = self.model.generate( input_ids=src, attention_mask=src_mask, - min_length=min_length, - max_length=max_length + min_length=self.min_length, + max_length=self.max_length ) - decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] - return decoded_summaries + print(summaries) + predictions = torch.tensor( + [ + i.tolist()[0 : self.max_length] + + [0] * (self.max_length - i.size()[0]) + for i in summaries + ], + device=device, + ) + + return predictions + #decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + #print(decoded_summaries) + #return decoded_summaries class SummarizationProcessor: def __init__( @@ -549,8 +561,8 @@ def this_model_move_callback(model, device): #if self.model_name.startswith("t5"): # batch = [self.model.config.prefix + text for text in batch] #dct = self.tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) - print(batch) - decoded_summaries = predictor(batch["source_ids"], batch["source_mask"]) + # print(batch) + summaries = predictor(batch["source_ids"], batch["source_mask"]) """ summaries = self.model.module.generate( input_ids=batch["source_ids"], @@ -558,8 +570,9 @@ def this_model_move_callback(model, device): min_length=min_length, max_length=max_length ) - decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] """ + decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + generated_summaries.extend(decoded_summaries) # release GPU memories From f3b5bd73bc6dd0973b4357e35625e932277285c9 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 5 May 2020 15:33:13 +0000 Subject: [PATCH 06/14] multi-gpu inferencing --- ...ive_summarization_cnndm_transformers.ipynb | 901 +++++------------- .../abstractive_summarization_bartt5.py | 72 +- 2 files changed, 269 insertions(+), 704 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index c75cc4da7..8f844a048 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -56,7 +56,7 @@ "\n", "%autoreload 2\n", "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", - "QUICK_RUN = True\n" + "QUICK_RUN = False\n" ] }, { @@ -154,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "tags": [ "parameters" @@ -163,23 +163,45 @@ "outputs": [], "source": [ "# Transformer model being used\n", - "MODEL_NAME = \"t5-large\"" + "#MODEL_NAME = \"t5-large\"\n", + "MODEL_NAME = \"bart-large-cnn\"" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "e18ad04d5d6c472c88c306c15e6f16c4", + "model_id": "f2ea5f9a1088431282d2df7823104710", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0, description='Downloading', max=898823, style=ProgressStyle(description_wid…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "91031adf816a423da2966031a9731278", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=791656, style=ProgressStyle(description_wid…" + "HBox(children=(IntProgress(value=0, description='Downloading', max=456318, style=ProgressStyle(description_wid…" ] }, "metadata": {}, @@ -195,12 +217,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9b49de228e3e4d619182ce0081b6d953", + "model_id": "bda49783badc4f069f0e57cbd2e1b8c8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=1200, style=ProgressStyle(description_width…" + "HBox(children=(IntProgress(value=0, description='Downloading', max=1300, style=ProgressStyle(description_width…" ] }, "metadata": {}, @@ -233,7 +255,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "tags": [ "parameters" @@ -242,7 +264,7 @@ "outputs": [], "source": [ "# the data path used to save the downloaded data file\n", - "DATA_PATH = TemporaryDirectory().name\n", + "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", "TOP_N = 1000\n", "if not QUICK_RUN:\n", @@ -251,21 +273,33 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "scrolled": false }, + "outputs": [], + "source": [ + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, "outputs": [ { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 489k/489k [00:07<00:00, 63.4kKB/s] \n" - ] + "data": { + "text/plain": [ + "11490" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + "len(test_dataset)" ] }, { @@ -277,58 +311,138 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "scrolled": false }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 2 µs, sys: 2 µs, total: 4 µs\n", + "Wall time: 9.78 µs\n" + ] + } + ], + "source": [ + "%time\n", + "abs_sum_train = processor.preprocess(train_dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'abs_sum_train' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'abs_sum_train' is not defined" + ] + } + ], + "source": [ + "abs_sum_train[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'abs_sum_train' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_train\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;31mNameError\u001b[0m: name 'abs_sum_train' is not defined" + ] + } + ], + "source": [ + "abs_sum_train" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, "outputs": [], "source": [ - "\n", - "abs_sum_train = processor.preprocess(train_dataset)\n", - "abs_sum_test = processor.preprocess(test_dataset)\n" + "abs_sum_test = processor.preprocess(test_dataset)" ] }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "scrolled": true - }, + "execution_count": 11, + "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "'\\n# save and load preprocessed data\\nsave_path = os.path.join(DATA_PATH, \"processed\")\\ntorch.save(ext_sum_train, os.path.join(save_path, \"train_full.pt\"))\\ntorch.save(ext_sum_test, os.path.join(save_path, \"test_full.pt\"))\\n\\n'" + "{'src': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", + " 'src_txt': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", + " 'tgt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", + " 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", + " 'source_ids': tensor([ 0, 4401, 1090, ..., 604, 1725, 2]),\n", + " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", + " 'target_ids': tensor([ 0, 28696, 90, 15698, 4401, 1090, 4061, 5644, 161, 45518,\n", + " 98, 444, 117, 3424, 58, 341, 11, 5, 2058, 803,\n", + " 12801, 1135, 433, 690, 479, 49703, 90, 15698, 28696, 90,\n", + " 15698, 4225, 23, 741, 9683, 8, 2242, 354, 914, 32,\n", + " 45518, 182, 3230, 12801, 5, 569, 7200, 16, 588, 2156,\n", + " 41, 4474, 161, 479, 49703, 2])}" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], + "source": [ + "abs_sum_test[0]" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ "\"\"\"\n", "# save and load preprocessed data\n", - "save_path = os.path.join(DATA_PATH, \"processed\")\n", - "torch.save(ext_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", - "torch.save(ext_sum_test, os.path.join(save_path, \"test_full.pt\"))\n", + "save_path = DATA_PATH\n", + "torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", + "torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))\n", "\n", "\"\"\"\n", - "# ext_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", - "# ext_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" + "save_path = DATA_PATH\n", + "#abs_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", + "abs_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "1000\n", - "1000\n" + "287227\n", + "11490\n" ] } ], @@ -337,6 +451,17 @@ "print(len(abs_sum_test))" ] }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "#save_path = os.path.join(DATA_PATH, \"processed\")\n", + "#torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", + "#torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -346,53 +471,20 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "abs_sum_train[0].keys()" ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "scrolled": false }, - "outputs": [ - { - "data": { - "text/plain": [ - "{'src': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", - " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", - " 'tgt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'source_ids': tensor([21603, 10, 6005, ..., 2027, 2957, 307]),\n", - " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", - " 'target_ids': tensor([ 3, 2, 17, 3155, 19367, 3, 1092, 16, 11171, 16,\n", - " 1337, 3690, 33, 629, 26, 30, 8, 3, 2, 11821,\n", - " 1501, 3, 31, 31, 3, 2, 87, 17, 3155, 3,\n", - " 2, 17, 3155, 5191, 3, 849, 1926, 90, 99, 348,\n", - " 845, 167, 33, 132, 38, 3, 9, 741, 13, 3,\n", - " 2, 1792, 179, 3110, 106, 725])}" - ] - }, - "execution_count": 12, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "abs_sum_train[0]" ] @@ -419,7 +511,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 9, "metadata": { "tags": [ "parameters" @@ -456,7 +548,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 10, "metadata": { "scrolled": true }, @@ -474,22 +566,11 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "scrolled": true }, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\n\\nsummarizer.fit(\\n ext_sum_train,\\n num_gpus=NUM_GPUS,\\n batch_size=BATCH_SIZE,\\n gradient_accumulation_steps=2,\\n max_steps=MAX_STEPS,\\n learning_rate=LEARNING_RATE,\\n warmup_steps=WARMUP_STEPS,\\n verbose=True,\\n report_every=REPORT_EVERY,\\n clip_grad_norm=False,\\n use_preprocessed_data=USE_PREPROCSSED_DATA\\n )\\n\\n'" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "\"\"\"\n", "\n", @@ -512,20 +593,9 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\nsummarizer.save_model(\\n os.path.join(\\n CACHE_DIR,\\n \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\\n MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\\n ),\\n )\\n)\\n'" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "\"\"\"\n", "summarizer.save_model(\n", @@ -541,20 +611,9 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'\\nimport torch\\nmodel_path = os.path.join(\\n CACHE_DIR,\\n \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\\n MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\\n ))\\nsummarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)\\nsummarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\\n'" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "# for loading a previous saved model\n", "\"\"\"\n", @@ -580,27 +639,16 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ "abs_sum_test[0].keys()" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": {}, "outputs": [], "source": [ @@ -608,21 +656,21 @@ "target = []\n", "for i in abs_sum_test:\n", " source.append(i[\"src_txt\"]) \n", - " target.append(i['tgt']) " + " target.append(i['tgt'].replace(\"\",\"\").replace(\"\", \"\").replace(\"\\n\", \"\")) " ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 13, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" + "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \"" ] }, - "execution_count": 20, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -633,473 +681,127 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\n", - "\n", - "\n", - "Generating summary: 0%| | 0/3 [00:00'" - ] - }, - "execution_count": 37, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prediction[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ + }, { - "data": { - "text/plain": [ - "20" - ] - }, - "execution_count": 29, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(prediction)" - ] - }, - { - "cell_type": "code", - "execution_count": 30, - "metadata": {}, - "outputs": [ + "name": "stdout", + "output_type": "stream", + "text": [ + "dataset length is 1024\n" + ] + }, { - "data": { - "text/plain": [ - "[\"marseille , france - cnn - the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that so far no videos were used in the crash investigation . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of\",\n", - " \"the formal accession was marked with a ceremony at the hague , in the netherlands , where the court is based . the palestinians signed the icc's founding rome statute in january . as members of the court , palestinians may be subject to counter-charges as well .\",\n", - " \"amnesty international's annual report catalogs the use of state-sanctioned killing as a punitive measure across the globe . at least 607 people were executed around the world in 2014 . at least 2,466 people globally are confirmed to have been handed the sentence .\",\n", - " 'amnesty international releases annual review of death penalty worldwide . death sentences up 500% on previous year, mostly because of pakistan . countries using death penalty to tackle crime and terrorism are deceiving themselves .',\n", - " 'anne frank died of typhus in a nazi concentration camp at the age of 15 . new research shows that she and her older sister probably did not survive to march 1945 . the exact dates of death for anne and margot remain unclear .',\n", - " \"a duke student has admitted to hanging a noose made of rope from a tree near a student union , university officials said thursday . the prestigious private school did n't identify the student , citing federal privacy laws . in a news release , it said the student was no longer on campus and will face student conduct review .\",\n", - " \"the rev. robert h. schuller , california televangelist and founder of the television ministry hour of power , '' died thursday , according to his family . he was 88 years old . schuller , also the founder of crystal cathedral megachurch , had been diagnosed with esophageal cancer in august 2013 .\",\n", - " 'the dog was hit by a car, apparently killed with a hammer and buried in a field . she staggered to a nearby farm, dirt-covered and emaciated . the dog was found by a worker who took her to a vet for help . she suffered a dislocated jaw, leg injuries and a caved-in sinus cavity .',\n", - " \"mohammad javad zarif is the iranian foreign minister . he has been john kerry's opposite number in securing nuclear breakthrough . zarif is 54, but his official biography says he was born in 1960 . he was investigated by the feds over his alleged role in controlling a charitable organization .\",\n", - " \"for the first time in eight years , a tv legend returned to doing what he does best . contestants told to come on down ! '' on the april 1 edition of the price is right '' encountered not host drew carey but another familiar face in charge of the proceedings . instead , there was bob barker , who hosted the tv game show for 35 years before stepping down in 2007 .\",\n", - " \"-lrb- he 's a blue chip college basketball recruit . she 's a high school freshman with down syndrome . at first glance trey moses and ellie meredith could n't be more different . but all that changed thursday when trey asked ellie to be his prom date .\",\n", - " \"michele bachmann compared president obama to the co-pilot of the doomed germanwings flight . ''with his iran deal , barack obama is for the 300 million souls of the united states,'' she wrote in a facebook comment posted march 31 . many comments posted on her facebook page blasted the former representative .\",\n", - " 'california is a breadbasket to the nation . california is growing more than a third of its vegetables and nearly two-thirds of its fruits and nuts . the drought is in its fourth year .',\n", - " \"walmart's staunch criticism of a religious freedom law in its home state of arkansas came after the company said in february it would boost pay for about 500,000 workers well above the federal minimum wage . the company is emerging as a bellwether for shifting public opinion on hot-button political issues that divide conservatives and liberals . former minnesota gov. tim pawlenty said walmart's actions foreshadow where the republican party will need to move \",\n", - " \"five americans who were monitored for three weeks at an omaha hospital have been released . one of the five had a heart-related issue on saturday and has been discharged but has n't left the area . they were exposed to ebola in sierra leone in march .\",\n", - " \"andrew getty, 47, was found dead in his los angeles home . the coroner's preliminary assessment is there was no foul play involved in the death of getty . getty , grandson of oil tycoon j. paul getty , was found dead near a bathroom in his home .\",\n", - " \"mike pence signed a religious freedom law last week that opens the door to discrimination against gays and lesbians . pence: 'i foolishly hoped this kind of backlash . ''there is no way a republican can get through the pending primary without denouncing lgbt rights,'' he says .\",\n", - " \"filipinos are being warned to be on guard for flash floods and landslides as tropical storm maysak approached the asian island nation saturday . just a few days ago , maysak gained super typhoon status thanks to its sustained 150 mph winds . it has since lost a lot of steam as it has spun west in the pacific ocean . it 's now classified as a tropical storm , according to the philippine national weather\",\n", - " \"norfolk , virginia - the second mate of the houston express probably could n't believe what he was seeing . hundreds of miles from land there was a small boat nearby . at first it looked abandoned . it was in bad shape , listing to one side . the crew of the 1,000-foot long container ship thought it was a yacht that had wrecked . incredibly , as they got closer , they saw there was a man on it \",\n", - " \"walker died in november 2013 after a fiery car crash . the release of furious 7 '' on friday offers fans the opportunity to remember -- and possibly grieve again -- the man that so many have praised as one of the nicest guys in hollywood .\"]" - ] - }, - "execution_count": 30, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ + "name": "stderr", + "output_type": "stream", + "text": [ + "Generating summary: 100%|██████████| 16/16 [04:59<00:00, 17.81s/it]\n" + ] + }, { - "data": { - "text/plain": [ - "[\"marseille , france - cnn - the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that so far no videos were used in the crash investigation . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of\",\n", - " \"the formal accession was marked with a ceremony at the hague , in the netherlands , where the court is based . the palestinians signed the icc's founding rome statute in january . as members of the court , palestinians may be subject to counter-charges as well .\",\n", - " \"amnesty international's annual report catalogs the use of state-sanctioned killing as a punitive measure across the globe . at least 607 people were executed around the world in 2014 . at least 2,466 people globally are confirmed to have been handed the sentence .\",\n", - " 'amnesty international releases annual review of death penalty worldwide . death sentences up 500% on previous year, mostly because of pakistan . countries using death penalty to tackle crime and terrorism are deceiving themselves .',\n", - " 'anne frank died of typhus in a nazi concentration camp at the age of 15 . new research shows that she and her older sister probably did not survive to march 1945 . the exact dates of death for anne and margot remain unclear .',\n", - " \"a duke student has admitted to hanging a noose made of rope from a tree near a student union , university officials said thursday . the prestigious private school did n't identify the student , citing federal privacy laws . in a news release , it said the student was no longer on campus and will face student conduct review .\",\n", - " \"the rev. robert h. schuller , california televangelist and founder of the television ministry hour of power , '' died thursday , according to his family . he was 88 years old . schuller , also the founder of crystal cathedral megachurch , had been diagnosed with esophageal cancer in august 2013 .\",\n", - " 'the dog was hit by a car, apparently killed with a hammer and buried in a field . she staggered to a nearby farm, dirt-covered and emaciated . the dog was found by a worker who took her to a vet for help . she suffered a dislocated jaw, leg injuries and a caved-in sinus cavity .',\n", - " \"mohammad javad zarif is the iranian foreign minister . he has been john kerry's opposite number in securing nuclear breakthrough . zarif is 54, but his official biography says he was born in 1960 . he was investigated by the feds over his alleged role in controlling a charitable organization .\",\n", - " \"for the first time in eight years , a tv legend returned to doing what he does best . contestants told to come on down ! '' on the april 1 edition of the price is right '' encountered not host drew carey but another familiar face in charge of the proceedings . instead , there was bob barker , who hosted the tv game show for 35 years before stepping down in 2007 .\",\n", - " \"-lrb- he 's a blue chip college basketball recruit . she 's a high school freshman with down syndrome . at first glance trey moses and ellie meredith could n't be more different . but all that changed thursday when trey asked ellie to be his prom date .\",\n", - " \"michele bachmann compared president obama to the co-pilot of the doomed germanwings flight . ''with his iran deal , barack obama is for the 300 million souls of the united states,'' she wrote in a facebook comment posted march 31 . many comments posted on her facebook page blasted the former representative .\",\n", - " 'california is a breadbasket to the nation . california is growing more than a third of its vegetables and nearly two-thirds of its fruits and nuts . the drought is in its fourth year .',\n", - " \"walmart's staunch criticism of a religious freedom law in its home state of arkansas came after the company said in february it would boost pay for about 500,000 workers well above the federal minimum wage . the company is emerging as a bellwether for shifting public opinion on hot-button political issues that divide conservatives and liberals . former minnesota gov. tim pawlenty said walmart's actions foreshadow where the republican party will need to move \",\n", - " \"five americans who were monitored for three weeks at an omaha hospital have been released . one of the five had a heart-related issue on saturday and has been discharged but has n't left the area . they were exposed to ebola in sierra leone in march .\",\n", - " \"andrew getty, 47, was found dead in his los angeles home . the coroner's preliminary assessment is there was no foul play involved in the death of getty . getty , grandson of oil tycoon j. paul getty , was found dead near a bathroom in his home .\",\n", - " \"mike pence signed a religious freedom law last week that opens the door to discrimination against gays and lesbians . pence: 'i foolishly hoped this kind of backlash . ''there is no way a republican can get through the pending primary without denouncing lgbt rights,'' he says .\",\n", - " \"filipinos are being warned to be on guard for flash floods and landslides as tropical storm maysak approached the asian island nation saturday . just a few days ago , maysak gained super typhoon status thanks to its sustained 150 mph winds . it has since lost a lot of steam as it has spun west in the pacific ocean . it 's now classified as a tropical storm , according to the philippine national weather\",\n", - " \"norfolk , virginia - the second mate of the houston express probably could n't believe what he was seeing . hundreds of miles from land there was a small boat nearby . at first it looked abandoned . it was in bad shape , listing to one side . the crew of the 1,000-foot long container ship thought it was a yacht that had wrecked . incredibly , as they got closer , they saw there was a man on it \",\n", - " \"walker died in november 2013 after a fiery car crash . the release of furious 7 '' on friday offers fans the opportunity to remember -- and possibly grieve again -- the man that so many have praised as one of the nicest guys in hollywood .\"]" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 9min 11s, sys: 1min 25s, total: 10min 37s\n", + "Wall time: 5min\n" + ] } ], "source": [ - "prediction" + "\n", + "%%time\n", + "prediction = summarizer.predict(abs_sum_test[0:256*4], num_gpus=NUM_GPUS, batch_size=64) " ] }, { @@ -1107,148 +809,39 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['marseille , france - cnn - the french prosecutor leading',\n", - " 'the formal accession was marked with a ceremony at the hague , in the net',\n", - " \"amnesty international's annual report catalogs the use of state-sanction\",\n", - " 'amnesty international releases annual review of death penalty worldwide . death sentences up 500%',\n", - " 'anne frank died of typhus in a nazi concentration camp at the',\n", - " 'a duke student has admitted to hanging a noose made of rope from',\n", - " 'the rev. robert h. schuller , californ',\n", - " 'the dog was hit by a car, apparently killed with a hammer and ',\n", - " 'mohammad javad zarif is the iranian foreign minister ',\n", - " 'for the first time in eight years , a tv legend returned to doing what',\n", - " \"-lrb- he 's a blue chip college basketball recruit \",\n", - " 'michele bachmann compared president obama to the co-pilot of',\n", - " 'california is a breadbasket to the nation . cali',\n", - " \"walmart's staunch criticism of a religious freedom law in its home state of ark\",\n", - " 'five americans who were monitored for three weeks at an omaha hospital have been',\n", - " 'andrew getty, 47, was found dead in his los angeles home',\n", - " 'mike pence signed a religious freedom law last week that opens the door to discrimin',\n", - " 'filipinos are being warned to be on guard for flash floods and landsl',\n", - " 'norfolk , virginia - the second mate of the ',\n", - " 'walker died in november 2013 after a fiery car crash . the release']" - ] - }, - "execution_count": 18, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "prediction" + "torch.save(prediction, \"prediction.pt\")" ] }, { "cell_type": "code", - "execution_count": 43, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['bild and paris match claim to have found a video of the crash site .',\n", - " 'palestinian authority officially became the 123rd member of the international criminal court on we',\n", - " \"amnesty international's annual report catalogs the use of state-sanction\",\n", - " '. in egypt and nigeria , courts imposed',\n", - " 'anne frank died of typhus in a nazi concentration camp at the',\n", - " 'student admitted to hanging the noose from a tree near a student union ',\n", - " \". schuller 's legacy is a televangelist \",\n", - " '. a dog in california , found seemingly dead after',\n", - " 'zarif is the iranian foreign minister . he has been ',\n", - " \", he didn't seem to miss a beat . bob bark\",\n", - " '. . . . . . . . .',\n", - " 'michele bachmann compared president obama to the co-pilot of',\n", - " '. . . . . . . . .',\n", - " \"'' the republican party will have to better stand for '' ideas\",\n", - " 'of an infected person . the last of 17 patients who were being monitored are',\n", - " 'found dead in his home . he was found on his side near a bathroom',\n", - " '. . . . . . . . ',\n", - " \". ''we do not know what the impact will be once it will make\",\n", - " '. he was rescued . he spent most of his days in the',\n", - " \"walker 's death .walker 's death is not the first actor\"]" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "prediction" - ] - }, - { - "cell_type": "code", - "execution_count": 55, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['marseille , france - cnn - the french prosecutor leading',\n", - " 'the formal accession was marked with a ceremony at the hague , in the net',\n", - " \"amnesty international's annual report catalogs the use of state-sanction\",\n", - " 'amnesty international releases annual review of death penalty worldwide . death sentences up 500%',\n", - " 'anne frank died of typhus in a nazi concentration camp at the',\n", - " 'a duke student has admitted to hanging a noose made of rope from',\n", - " 'the rev. robert h. schuller , californ',\n", - " 'the dog was hit by a car, apparently killed with a hammer and ',\n", - " 'mohammad javad zarif is the iranian foreign minister ',\n", - " 'for the first time in eight years , a tv legend returned to doing what',\n", - " \"-lrb- he 's a blue chip college basketball recruit \",\n", - " 'michele bachmann compared president obama to the co-pilot of',\n", - " 'california is a breadbasket to the nation . cali',\n", - " \"walmart's staunch criticism of a religious freedom law in its home state of ark\",\n", - " 'five americans who were monitored for three weeks at an omaha hospital have been',\n", - " 'andrew getty, 47, was found dead in his los angeles home',\n", - " 'mike pence signed a religious freedom law last week that opens the door to discrimin',\n", - " 'filipinos are being warned to be on guard for flash floods and landsl',\n", - " 'norfolk , virginia - the second mate of the ',\n", - " 'walker died in november 2013 after a fiery car crash . the release']" - ] - }, - "execution_count": 55, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "prediction" + "prediction = torch.load(\"prediction.pt\")" ] }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Number of candidates: 20\n", - "Number of references: 1000\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mrouge_scores\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcompute_rouge_python\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcand\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mprediction\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mref\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtarget\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 2\u001b[0m \u001b[0mpprint\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrouge_scores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/eval/rouge/compute_rouge.py\u001b[0m in \u001b[0;36mcompute_rouge_python\u001b[0;34m(cand, ref, is_input_files, language)\u001b[0m\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Number of candidates: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcandidates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 117\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Number of references: {}\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreferences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 118\u001b[0;31m \u001b[0;32massert\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcandidates\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreferences\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 120\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlanguage\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0;34m\"en\"\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mAssertionError\u001b[0m: " + "Number of candidates: 11490\n", + "Number of references: 11490\n", + "{'rouge-1': {'f': 0.43366650568906373,\n", + " 'p': 0.3878160218865652,\n", + " 'r': 0.5256684651435454},\n", + " 'rouge-2': {'f': 0.2013283037622797,\n", + " 'p': 0.18073246915601657,\n", + " 'r': 0.24341426947272182},\n", + " 'rouge-l': {'f': 0.2945220878967815,\n", + " 'p': 0.26331547921506626,\n", + " 'r': 0.3573184457546521}}\n" ] } ], diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index 813bcd230..7bfb99c26 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -74,7 +74,8 @@ from transformers.tokenization_utils import trim_batch -def encode_example(example, tokenizer=None, prefix="", max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): +from tempfile import TemporaryDirectory +def encode_example(example, tokenizer, prefix="", max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): ## add to the dataset tokenized_source = tokenizer.batch_encode_plus( [prefix + example['src']], max_length=max_source_length, pad_to_max_length=pad_to_max_length, return_tensors=return_tensors, @@ -92,50 +93,19 @@ def encode_example(example, tokenizer=None, prefix="", max_source_length=None, m example["target_ids"] = target_ids return example -def parallel_preprocess(input_data, preprocess, num_pool=-1): - """ - Process data in parallel using multiple GPUs. - - Args: - input_data (list): List if input strings to process. - preprocess (function): function to apply on the input data. - word_tokenize (func, optional): A tokenization function used to tokenize - the results from preprocess_pipeline. - num_pool (int, optional): Number of CPUs to use. Defaults to -1 and all - available CPUs are used. - - Returns: - list: list of processed text strings. - - """ - if num_pool == -1: - num_pool = cpu_count() - - num_pool = min(num_pool, len(input_data)) - - result = None - with Pool(num_pool) as p: - results = p.map( - preprocess, input_data, chunksize=max(1, int(len(input_data) / num_pool)), - ) - - p.close() - #p.join() - - return results class Predictor(nn.Module): def __init__( self, model, - tokenizer, min_length, - max_length): + max_length, + **kwargs): super(Predictor, self).__init__() self.model = model.module if hasattr(model, "module") else model - self.tokenizer = tokenizer self.min_length = min_length self.max_length = max_length + self.config = kwargs def forward(self, src, src_mask): device = src.device @@ -144,9 +114,9 @@ def forward(self, src, src_mask): input_ids=src, attention_mask=src_mask, min_length=self.min_length, - max_length=self.max_length + max_length=self.max_length, + **self.config ) - print(summaries) predictions = torch.tensor( [ i.tolist()[0 : self.max_length] @@ -157,9 +127,6 @@ def forward(self, src, src_mask): ) return predictions - #decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] - #print(decoded_summaries) - #return decoded_summaries class SummarizationProcessor: def __init__( @@ -194,11 +161,12 @@ def __init__( # self.target = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) def preprocess(self, input_data_list): - preprocess = functools.partial( - encode_example, tokenizer=self.tokenizer, prefix=self.prefix, max_source_length=self.max_source_length, max_target_length=self.max_target_length - ) - - return parallel_preprocess(input_data_list, preprocess, num_pool=-1) + result = [] + for i in input_data_list: + result.append(encode_example(i, tokenizer=self.tokenizer, prefix=self.prefix, + max_source_length=self.max_source_length, max_target_length=self.max_target_length )) + return result + @staticmethod def trim_seq2seq_batch(batch, pad_token_id): @@ -474,10 +442,13 @@ def predict( gpu_ids=None, local_rank=-1, batch_size=16, - alpha=0.6, - beam_size=5, - min_length=15, - max_length=150, + length_penalty=0.95, + beam_size=4, + min_length=50, + max_length=200, + no_repeat_ngram_size=3, + early_stopping=True, + fp16=False, verbose=True, ): @@ -544,7 +515,7 @@ def collate_fn(data): ) print("dataset length is {}".format(len(test_dataset))) - predictor = Predictor(self.model, self.tokenizer, min_length, max_length) + predictor = Predictor(self.model, min_length, max_length) # move model to devices def this_model_move_callback(model, device): model = move_model_to_device(model, device) @@ -577,6 +548,7 @@ def this_model_move_callback(model, device): # release GPU memories self.model.cpu() + del batch torch.cuda.empty_cache() return generated_summaries From 1ae0a03e7570013b0db0cbf495a3f1b2f48b1d9b Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 5 May 2020 21:51:21 +0000 Subject: [PATCH 07/14] initial finetune code --- examples/text_summarization/test_bartt5.py | 82 ++++++++++++++ .../abstractive_summarization_bartt5.py | 103 +++++++++++------- 2 files changed, 146 insertions(+), 39 deletions(-) create mode 100644 examples/text_summarization/test_bartt5.py diff --git a/examples/text_summarization/test_bartt5.py b/examples/text_summarization/test_bartt5.py new file mode 100644 index 000000000..64ac5e15c --- /dev/null +++ b/examples/text_summarization/test_bartt5.py @@ -0,0 +1,82 @@ +QUICK_RUN = False +import os +import shutil +import sys +from tempfile import TemporaryDirectory +import torch + +nlp_path = os.path.abspath("../../") +if nlp_path not in sys.path: + sys.path.insert(0, nlp_path) + +from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset +from utils_nlp.eval import compute_rouge_python, compute_rouge_perl +from utils_nlp.models.transformers.abstractive_summarization_bartt5 import AbstractiveSummarizer, SummarizationProcessor + +from utils_nlp.models.transformers.datasets import SummarizationDataset +import nltk +from nltk import tokenize + +import pandas as pd +import scrapbook as sb +import pprint + +QUICK_RUN = True +MODEL_NAME = "bart-large" +CACHE_DIR = "./bartt5_cache" #TemporaryDirectory().name + +#processor = SummarizationProcessor(MODEL_NAME,cache_dir=CACHE_DIR ) #tokenizer, config.prefix) +DATA_PATH = "./bartt5_cnndm" #TemporaryDirectory().name +# The number of lines at the head of data file used for preprocessing. -1 means all the lines. +TOP_N = 1000 +if not QUICK_RUN: + TOP_N = -1 +#train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True) +#abs_sum_train = processor.preprocess(train_dataset) +#torch.save(abs_sum_train, os.path.join(DATA_PATH, "train_full_2.pt")) +abs_sum_train = torch.load(os.path.join(DATA_PATH, "train_full.pt")) + + +BATCH_SIZE = 8 # batch size, unit is the number of samples +MAX_POS_LENGTH = 512 +# GPU used for training +NUM_GPUS = torch.cuda.device_count() +# Learning rate +LEARNING_RATE=3e-5 +# How often the statistics reports show up in training, unit is step. +REPORT_EVERY=20 +# total number of steps for training +MAX_STEPS=1e2 +# number of steps for warm up +WARMUP_STEPS=5e2 +if not QUICK_RUN: + MAX_STEPS=5e4 + WARMUP_STEPS=5e3 + + +summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR) +processor = summarizer.processor +train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True) +abs_sum_train = processor.preprocess(train_dataset) + +summarizer.fit( + abs_sum_train, + num_gpus=NUM_GPUS, + batch_size=BATCH_SIZE, + gradient_accumulation_steps=1, + max_steps=MAX_STEPS, + learning_rate=LEARNING_RATE, + warmup_steps=WARMUP_STEPS, + verbose=True, + report_every=REPORT_EVERY, + clip_grad_norm=False, + ) +summarizer.save_model( + os.path.join( + CACHE_DIR, + "abssum_modelname_{0}_steps_{1}.pt".format( + MODEL_NAME, MAX_STEPS + ), + ) +) + diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index 7bfb99c26..b4ac35a73 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -53,17 +53,18 @@ "language-modeling": AutoModelWithLMHead, } -from transformers import BartForConditionalGeneration, BartTokenizer -from transformers import T5ForConditionalGeneration, T5Tokenizer +from transformers import BartForConditionalGeneration, BartTokenizer, BART_PRETRAINED_MODEL_ARCHIVE_MAP +from transformers import T5ForConditionalGeneration, T5Tokenizer, T5_PRETRAINED_MODEL_ARCHIVE_MAP -MODEL_CLASS = { - "bart-large-cnn": BartForConditionalGeneration, - "t5-large":T5ForConditionalGeneration -} -TOKENIZER_CLASS = { - "bart-large-cnn": BartTokenizer, - "t5-large": T5Tokenizer -} +MODEL_CLASS = {} +MODEL_CLASS.update({k: BartForConditionalGeneration for k in BART_PRETRAINED_MODEL_ARCHIVE_MAP}) +MODEL_CLASS.update({k: T5ForConditionalGeneration for k in T5_PRETRAINED_MODEL_ARCHIVE_MAP}) + +""" +TOKENIZER_CLASS = {} +TOKENIZER_CLASS.update({k: BartTokenizer for k in BART_PRETRAINED_MODEL_ARCHIVE_MAP }) +TOKENIZER_CLASS.update({k: T5Tokenizer for k in T5_PRETRAINED_MODEL_ARCHIVE_MAP}) +""" logger = logging.getLogger(__name__) @@ -137,7 +138,13 @@ def __init__( max_target_length=56, ): #super().__init__() - self.tokenizer = TOKENIZER_CLASS[model_name].from_pretrained(model_name, cache_dir=cache_dir) # b + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + #self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, + cache_dir=cache_dir, + ) + + #TOKENIZER_CLASS[model_name].from_pretrained(model_name, cache_dir=cache_dir) # b config = AutoConfig.from_pretrained( model_name, #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, @@ -152,13 +159,9 @@ def __init__( config.update(task_specific_params.get("summarization", {})) self.prefix = config.prefix - #self.source = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".source"), max_source_length) self.with_target = False self.max_source_length = max_source_length self.max_target_length = max_target_length - #if with_target: - # self.with_target = True - # self.target = source_examples #encode_file(tokenizer, os.path.join(data_dir, type_path + ".target"), max_target_length) def preprocess(self, input_data_list): result = [] @@ -166,13 +169,26 @@ def preprocess(self, input_data_list): result.append(encode_example(i, tokenizer=self.tokenizer, prefix=self.prefix, max_source_length=self.max_source_length, max_target_length=self.max_target_length )) return result - - + @staticmethod - def trim_seq2seq_batch(batch, pad_token_id): - y = trim_batch(batch["target_ids"], pad_token_id) - source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) - return source_ids, source_mask, y + def get_inputs(batch, device, model_name, tokenizer=None, train_mode=True): + pad_token_id = tokenizer.pad_token_id + source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"] + y_ids = y[:, :-1].contiguous() + lm_labels = y[:, 1:].clone() + lm_labels[y[:, 1:] == pad_token_id] = -100 + #outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,) + + if train_mode: + return {"input_ids": source_ids, + "attention_mask": source_mask, + "decoder_input_ids": y_ids, + "lm_labels": lm_labels + } + else: + return {"input_ids": source_ids, + "attention_mask": source_mask, + } def collate_fn(self, batch, device, train_mode=False): input_ids = torch.stack([x["source_ids"] for x in batch]) @@ -193,8 +209,8 @@ class AbstractiveSummarizer(Transformer): def __init__( self, - processor, - model_name="bart-large-cnn", + #processor, + model_name="bart-large", cache_dir=".", max_source_length=1024, max_target_length=240 @@ -225,7 +241,7 @@ def __init__( "names.".format(value) ) """ - self.processor = processor + self.processor = SummarizationProcessor(model_name, cache_dir, max_source_length, max_target_length) self.config = AutoConfig.from_pretrained( model_name, #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, @@ -253,6 +269,9 @@ def __init__( self.max_source_length = max_source_length self.max_target_length = max_target_length + self.amp = None + self.optimizer = None + self.scheduler = None @staticmethod def list_supported_models(): @@ -351,7 +370,7 @@ def fit( """ # move model to devices - print("device is {}".format(device)) + checkpoint_state_dict = None if checkpoint: # checkpoint should have "model", "optimizer", "amp" checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") @@ -366,11 +385,13 @@ def fit( weight_decay=weight_decay, learning_rate=learning_rate, adam_epsilon=adam_epsilon, - checkpoint_state_dict=checkpoint_state_dict, + # checkpoint_state_dict=checkpoint_state_dict, ) + + self.amp = amp global_step = 0 - if "global_step" in checkpoint_state_dict and checkpoint_state_dict["global_step"]: + if checkpoint_state_dict and "global_step" in checkpoint_state_dict and checkpoint_state_dict["global_step"]: global_step = checkpoint_state_dict["global_step"] / world_size print("global_step is {}".format(global_step)) @@ -389,11 +410,13 @@ def fit( train_dataset, num_replicas=world_size, rank=rank ) + def collate_fn(data): - return self.processor.collate( - data, block_size=self.max_pos_length, device=device + return self.processor.collate_fn( + data, device, train_mode=True ) + train_dataloader = DataLoader( train_dataset, sampler=sampler, @@ -407,10 +430,11 @@ def collate_fn(data): max_steps=max_steps, gradient_accumulation_steps=gradient_accumulation_steps, ) - + import functools + get_inputs = functools.partial(self.processor.get_inputs, tokenizer=self.processor.tokenizer) super().fine_tune( train_dataloader=train_dataloader, - get_inputs=xxxx.get_inputs, + get_inputs=get_inputs, device=device, num_gpus=num_gpus, max_steps=max_steps, @@ -421,12 +445,11 @@ def collate_fn(data): seed=seed, report_every=report_every, save_every=save_every, - clip_grad_norm=False, - optimizer=optimizers, - scheduler=None, + optimizer=self.optimizer, + scheduler=self.scheduler, fp16=fp16, - amp=self.amp, - validation_function=validation_function, + amp=amp, + validation_function=None, ) # release GPU memories @@ -572,18 +595,20 @@ def save_model(self, global_step=None, full_name=None): output_model_dir = os.path.join(self.cache_dir, "fine_tuned") os.makedirs(self.cache_dir, exist_ok=True) os.makedirs(output_model_dir, exist_ok=True) - full_name = os.path.join(output_model_dir, "bertsumabs.pt") + full_name = os.path.join(output_model_dir, "abssum_{}.pt".format(self.model_name)) else: path, filename = os.path.split(full_name) print(path) os.makedirs(path, exist_ok=True) checkpoint = { - "optimizers": [self.optim_bert.state_dict(), self.optim_dec.state_dict()], + "optimizer": self.optimizer.state_dict(), + "lr_scheduler": self.scheduler.state_dict(), "model": model_to_save.state_dict(), "amp": self.amp.state_dict() if self.amp else None, "global_step": global_step, - "max_pos_length": self.max_pos_length, + "max_source_length": self.max_source_length, + "max_target_length": self.max_target_length, } logger.info("Saving model checkpoint to %s", full_name) From db3676764d7c92d04045f2913c6f5f8087722fb1 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Mon, 11 May 2020 21:43:38 +0000 Subject: [PATCH 08/14] dataset preprocessing change; added validation function --- ...ive_summarization_cnndm_transformers.ipynb | 3018 +++++++++++++++-- examples/text_summarization/test_bartt5.py | 54 +- utils_nlp/dataset/cnndm.py | 17 + .../abstractive_summarization_bartt5.py | 194 +- utils_nlp/models/transformers/common.py | 10 +- 5 files changed, 3002 insertions(+), 291 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index 8f844a048..f2765c72c 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -44,19 +44,35 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 47, "metadata": { "tags": [ "parameters" ] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The autoreload extension is already loaded. To reload it, use:\n", + " %reload_ext autoreload\n" + ] + } + ], "source": [ "%load_ext autoreload\n", "\n", "%autoreload 2\n", "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", - "QUICK_RUN = False\n" + "QUICK_RUN = False" ] }, { @@ -117,7 +133,8 @@ "\n", "from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n", "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n", - "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import AbstractiveSummarizer, SummarizationProcessor\n", + "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (\n", + " AbstractiveSummarizer, SummarizationProcessor, validate)\n", "\n", "from utils_nlp.models.transformers.datasets import SummarizationDataset\n", "import nltk\n", @@ -154,7 +171,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 28, "metadata": { "tags": [ "parameters" @@ -164,7 +181,72 @@ "source": [ "# Transformer model being used\n", "#MODEL_NAME = \"t5-large\"\n", - "MODEL_NAME = \"bart-large-cnn\"" + "MODEL_NAME = \"bart-large-cnn\"\n", + "# notebook parameters\n", + "# the cache data path during find tuning\n", + "CACHE_DIR = \"./bart_cache\" #TemporaryDirectory().name\n", + "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['hello', 'Ġfrench', 's', 'df', 'a']" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# bart-large\n", + "summarizer.tokenizer.tokenize(\"hello frenchsdfa \")" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['hello', 'Ġn', 'lp', 'Ġam', 'azon', 'Ġch', 'ina']" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summarizer.tokenizer.tokenize(\"hello nlp amazon china\")" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summarizer.tokenizer" ] }, { @@ -174,34 +256,519 @@ "outputs": [ { "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f2ea5f9a1088431282d2df7823104710", - "version_major": 2, - "version_minor": 0 - }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=898823, style=ProgressStyle(description_wid…" + "BartForConditionalGeneration(\n", + " (model): BartModel(\n", + " (shared): Embedding(50265, 1024, padding_idx=1)\n", + " (encoder): BartEncoder(\n", + " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + ")" ] }, + "execution_count": 4, "metadata": {}, - "output_type": "display_data" - }, + "output_type": "execute_result" + } + ], + "source": [ + "summarizer.model" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, + "data": { + "text/plain": [ + "BartConfig {\n", + " \"_num_labels\": 3,\n", + " \"activation_dropout\": 0.0,\n", + " \"activation_function\": \"gelu\",\n", + " \"add_final_layer_norm\": false,\n", + " \"architectures\": [\n", + " \"BartModel\",\n", + " \"BartForMaskedLM\",\n", + " \"BartForSequenceClassification\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bad_words_ids\": null,\n", + " \"bos_token_id\": 0,\n", + " \"classif_dropout\": 0.0,\n", + " \"d_model\": 1024,\n", + " \"decoder_attention_heads\": 16,\n", + " \"decoder_ffn_dim\": 4096,\n", + " \"decoder_layerdrop\": 0.0,\n", + " \"decoder_layers\": 12,\n", + " \"decoder_start_token_id\": 2,\n", + " \"do_sample\": false,\n", + " \"dropout\": 0.1,\n", + " \"early_stopping\": false,\n", + " \"encoder_attention_heads\": 16,\n", + " \"encoder_ffn_dim\": 4096,\n", + " \"encoder_layerdrop\": 0.0,\n", + " \"encoder_layers\": 12,\n", + " \"eos_token_id\": 2,\n", + " \"finetuning_task\": null,\n", + " \"id2label\": {\n", + " \"0\": \"LABEL_0\",\n", + " \"1\": \"LABEL_1\",\n", + " \"2\": \"LABEL_2\"\n", + " },\n", + " \"init_std\": 0.02,\n", + " \"is_decoder\": false,\n", + " \"is_encoder_decoder\": true,\n", + " \"label2id\": {\n", + " \"LABEL_0\": 0,\n", + " \"LABEL_1\": 1,\n", + " \"LABEL_2\": 2\n", + " },\n", + " \"length_penalty\": 1.0,\n", + " \"max_length\": 20,\n", + " \"max_position_embeddings\": 1024,\n", + " \"min_length\": 0,\n", + " \"model_type\": \"bart\",\n", + " \"no_repeat_ngram_size\": 0,\n", + " \"normalize_before\": false,\n", + " \"num_beams\": 1,\n", + " \"num_hidden_layers\": 12,\n", + " \"num_return_sequences\": 1,\n", + " \"output_attentions\": false,\n", + " \"output_hidden_states\": false,\n", + " \"output_past\": false,\n", + " \"pad_token_id\": 1,\n", + " \"prefix\": \" \",\n", + " \"pruned_heads\": {},\n", + " \"repetition_penalty\": 1.0,\n", + " \"scale_embedding\": false,\n", + " \"task_specific_params\": {\n", + " \"summarization\": {\n", + " \"early_stopping\": true,\n", + " \"length_penalty\": 2.0,\n", + " \"max_length\": 142,\n", + " \"min_length\": 56,\n", + " \"no_repeat_ngram_size\": 3,\n", + " \"num_beams\": 4\n", + " }\n", + " },\n", + " \"temperature\": 1.0,\n", + " \"top_k\": 50,\n", + " \"top_p\": 1.0,\n", + " \"torchscript\": false,\n", + " \"use_bfloat16\": false,\n", + " \"vocab_size\": 50265\n", + "}" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summarizer.config" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "91031adf816a423da2966031a9731278", + "model_id": "fe888b456fcb438a8c330d3f7328cfe8", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=456318, style=ProgressStyle(description_wid…" + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1300.0, style=ProgressStyle(description…" ] }, "metadata": {}, @@ -217,12 +784,12 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "bda49783badc4f069f0e57cbd2e1b8c8", + "model_id": "ebe6e048029042029ef1d24008ebe7d5", "version_major": 2, "version_minor": 0 }, "text/plain": [ - "HBox(children=(IntProgress(value=0, description='Downloading', max=1300, style=ProgressStyle(description_width…" + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1625270765.0, style=ProgressStyle(descr…" ] }, "metadata": {}, @@ -234,361 +801,2346 @@ "text": [ "\n" ] + }, + { + "data": { + "text/plain": [ + "BartForConditionalGeneration(\n", + " (model): BartModel(\n", + " (shared): Embedding(50264, 1024, padding_idx=1)\n", + " (encoder): BartEncoder(\n", + " (embed_tokens): Embedding(50264, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50264, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ + "# Transformer model being used\n", + "#MODEL_NAME = \"t5-large\"\n", + "MODEL_NAME = \"bart-large-cnn\"\n", "# notebook parameters\n", "# the cache data path during find tuning\n", - "CACHE_DIR = TemporaryDirectory().name\n", - " \n", - "processor = SummarizationProcessor(MODEL_NAME,cache_dir=CACHE_DIR ) #tokenizer, config.prefix)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Data Preprocessing\n", - "\n", - "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" + "CACHE_DIR = \"./bart_cache\" #TemporaryDirectory().name\n", + "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\n", + "summarizer.model" ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "tags": [ - "parameters" - ] - }, - "outputs": [], + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BartConfig {\n", + " \"_num_labels\": 3,\n", + " \"activation_dropout\": 0.0,\n", + " \"activation_function\": \"gelu\",\n", + " \"add_final_layer_norm\": false,\n", + " \"architectures\": null,\n", + " \"attention_dropout\": 0.0,\n", + " \"bad_words_ids\": null,\n", + " \"bos_token_id\": 0,\n", + " \"classif_dropout\": 0.0,\n", + " \"d_model\": 1024,\n", + " \"decoder_attention_heads\": 16,\n", + " \"decoder_ffn_dim\": 4096,\n", + " \"decoder_layerdrop\": 0.0,\n", + " \"decoder_layers\": 12,\n", + " \"decoder_start_token_id\": 2,\n", + " \"do_sample\": false,\n", + " \"dropout\": 0.1,\n", + " \"early_stopping\": true,\n", + " \"encoder_attention_heads\": 16,\n", + " \"encoder_ffn_dim\": 4096,\n", + " \"encoder_layerdrop\": 0.0,\n", + " \"encoder_layers\": 12,\n", + " \"eos_token_id\": 2,\n", + " \"finetuning_task\": null,\n", + " \"id2label\": {\n", + " \"0\": \"LABEL_0\",\n", + " \"1\": \"LABEL_1\",\n", + " \"2\": \"LABEL_2\"\n", + " },\n", + " \"init_std\": 0.02,\n", + " \"is_decoder\": false,\n", + " \"is_encoder_decoder\": true,\n", + " \"label2id\": {\n", + " \"LABEL_0\": 0,\n", + " \"LABEL_1\": 1,\n", + " \"LABEL_2\": 2\n", + " },\n", + " \"length_penalty\": 2.0,\n", + " \"max_length\": 142,\n", + " \"max_position_embeddings\": 1024,\n", + " \"min_length\": 56,\n", + " \"model_type\": \"bart\",\n", + " \"no_repeat_ngram_size\": 3,\n", + " \"normalize_before\": false,\n", + " \"num_beams\": 4,\n", + " \"num_hidden_layers\": 12,\n", + " \"num_return_sequences\": 1,\n", + " \"output_attentions\": false,\n", + " \"output_hidden_states\": false,\n", + " \"output_past\": true,\n", + " \"pad_token_id\": 1,\n", + " \"prefix\": \" \",\n", + " \"pruned_heads\": {},\n", + " \"repetition_penalty\": 1.0,\n", + " \"scale_embedding\": false,\n", + " \"task_specific_params\": {\n", + " \"summarization\": {\n", + " \"early_stopping\": true,\n", + " \"length_penalty\": 2.0,\n", + " \"max_length\": 142,\n", + " \"min_length\": 56,\n", + " \"no_repeat_ngram_size\": 3,\n", + " \"num_beams\": 4\n", + " }\n", + " },\n", + " \"temperature\": 1.0,\n", + " \"top_k\": 50,\n", + " \"top_p\": 1.0,\n", + " \"torchscript\": false,\n", + " \"use_bfloat16\": false,\n", + " \"vocab_size\": 50264\n", + "}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# the data path used to save the downloaded data file\n", - "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", - "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", - "TOP_N = 1000\n", - "if not QUICK_RUN:\n", - " TOP_N = -1" + "summarizer.config" ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "scrolled": false - }, + "execution_count": 19, + "metadata": {}, "outputs": [], "source": [ - "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + "task_specific_params = summarizer.config.task_specific_params" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "11490" + "{'summarization': {'early_stopping': True,\n", + " 'length_penalty': 2.0,\n", + " 'max_length': 142,\n", + " 'min_length': 56,\n", + " 'no_repeat_ngram_size': 3,\n", + " 'num_beams': 4}}" ] }, - "execution_count": 7, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "len(test_dataset)" + "task_specific_params" ] }, { - "cell_type": "markdown", + "cell_type": "code", + "execution_count": null, "metadata": {}, + "outputs": [], "source": [ - "Preprocess the data." + "tokens = bart.encode('Hello world!')\n", + "assert tokens.tolist() == [0, 31414, 232, 328, 2]\n", + "bart.decode(tokens) # 'Hello world!" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, + "execution_count": 44, + "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 2 µs, sys: 2 µs, total: 4 µs\n", - "Wall time: 9.78 µs\n" - ] + "data": { + "text/plain": [ + "[0, 20920, 7619, 611, 328, 2]" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "%time\n", - "abs_sum_train = processor.preprocess(train_dataset)\n" + "summarizer.tokenizer.encode('Hello frech!')" ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 46, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'abs_sum_train' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_train\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'abs_sum_train' is not defined" - ] + "data": { + "text/plain": [ + "' Hello frech!'" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "abs_sum_train[0].keys()" + "summarizer.tokenizer.decode([0, 20920, 7619, 611, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 36, "metadata": {}, "outputs": [ { - "ename": "NameError", - "evalue": "name 'abs_sum_train' is not defined", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mabs_sum_train\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", - "\u001b[0;31mNameError\u001b[0m: name 'abs_sum_train' is not defined" - ] + "data": { + "text/plain": [ + "'Hello world!'" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "abs_sum_train" + "summarizer.tokenizer.decode([0, 31414, 232, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ - "abs_sum_test = processor.preprocess(test_dataset)" + "summarizer.config.update(task_specific_params.get(\"summarization\", {}))" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'src': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", - " 'src_txt': \"marseille , france -lrb- cnn -rrb- the french prosecutor leading an investigation into the crash of germanwings flight 9525 insisted wednesday that he was not aware of any video footage from on board the plane . marseille prosecutor brice robin told cnn that `` so far no videos were used in the crash investigation . '' he added , `` a person who has such a video needs to immediately give it to the investigators . '' robin 's comments follow claims by two magazines , german daily bild and french paris match , of a cell phone video showing the harrowing final seconds from on board germanwings flight 9525 as it crashed into the french alps . all 150 on board were killed . paris match and bild reported that the video was recovered from a phone at the wreckage site . the two publications described the supposed video , but did not post it on their websites . the publications said that they watched the video , which was found by a source close to the investigation . `` one can hear cries of ` my god ' in several languages , '' paris match reported . `` metallic banging can also be heard more than three times , perhaps of the pilot trying to open the cockpit door with a heavy object . towards the end , after a heavy shake , stronger than the others , the screaming intensifies . then nothing . '' `` it is a very disturbing scene , '' said julian reichelt , editor-in-chief of bild online . an official with france 's accident investigation agency , the bea , said the agency is not aware of any such video . lt. col. jean-marc menichini , a french gendarmerie spokesman in charge of communications on rescue efforts around the germanwings crash site , told cnn that the reports were `` completely wrong '' and `` unwarranted . '' cell phones have been collected at the site , he said , but that they `` had n't been exploited yet . '' menichini said he believed the cell phones would need to be sent to the criminal research institute in rosny sous-bois , near paris , in order to be analyzed by specialized technicians working hand-in-hand with investigators . but none of the cell phones found so far have been sent to the institute , menichini said . asked whether staff involved in the search could have leaked a memory card to the media , menichini answered with a categorical `` no . '' reichelt told `` erin burnett : outfront '' that he had watched the video and stood by the report , saying bild and paris match are `` very confident '' that the clip is real . he noted that investigators only revealed they 'd recovered cell phones from the crash site after bild and paris match published their reports . `` that is something we did not know before . ... overall we can say many things of the investigation were n't revealed by the investigation at the beginning , '' he said . what was mental state of germanwings co-pilot ? german airline lufthansa confirmed tuesday that co-pilot andreas lubitz had battled depression years before he took the controls of germanwings flight 9525 , which he 's accused of deliberately crashing last week in the french alps . lubitz told his lufthansa flight training school in 2009 that he had a `` previous episode of severe depression , '' the airline said tuesday . email correspondence between lubitz and the school discovered in an internal investigation , lufthansa said , included medical documents he submitted in connection with resuming his flight training . the announcement indicates that lufthansa , the parent company of germanwings , knew of lubitz 's battle with depression , allowed him to continue training and ultimately put him in the cockpit . lufthansa , whose ceo carsten spohr previously said lubitz was 100 % fit to fly , described its statement tuesday as a `` swift and seamless clarification '' and said it was sharing the information and documents -- including training and medical records -- with public prosecutors . spohr traveled to the crash site wednesday , where recovery teams have been working for the past week to recover human remains and plane debris scattered across a steep mountainside . he saw the crisis center set up in seyne-les-alpes , laid a wreath in the village of le vernet , closer to the crash site , where grieving families have left flowers at a simple stone memorial . menichini told cnn late tuesday that no visible human remains were left at the site but recovery teams would keep searching . french president francois hollande , speaking tuesday , said that it should be possible to identify all the victims using dna analysis by the end of the week , sooner than authorities had previously suggested . in the meantime , the recovery of the victims ' personal belongings will start wednesday , menichini said . among those personal belongings could be more cell phones belonging to the 144 passengers and six crew on board . check out the latest from our correspondents . the details about lubitz 's correspondence with the flight school during his training were among several developments as investigators continued to delve into what caused the crash and lubitz 's possible motive for downing the jet . a lufthansa spokesperson told cnn on tuesday that lubitz had a valid medical certificate , had passed all his examinations and `` held all the licenses required . '' earlier , a spokesman for the prosecutor 's office in dusseldorf , christoph kumpa , said medical records reveal lubitz suffered from suicidal tendencies at some point before his aviation career and underwent psychotherapy before he got his pilot 's license . kumpa emphasized there 's no evidence suggesting lubitz was suicidal or acting aggressively before the crash . investigators are looking into whether lubitz feared his medical condition would cause him to lose his pilot 's license , a european government official briefed on the investigation told cnn on tuesday . while flying was `` a big part of his life , '' the source said , it 's only one theory being considered . another source , a law enforcement official briefed on the investigation , also told cnn that authorities believe the primary motive for lubitz to bring down the plane was that he feared he would not be allowed to fly because of his medical problems . lubitz 's girlfriend told investigators he had seen an eye doctor and a neuropsychologist , both of whom deemed him unfit to work recently and concluded he had psychological issues , the european government official said . but no matter what details emerge about his previous mental health struggles , there 's more to the story , said brian russell , a forensic psychologist . `` psychology can explain why somebody would turn rage inward on themselves about the fact that maybe they were n't going to keep doing their job and they 're upset about that and so they 're suicidal , '' he said . `` but there is no mental illness that explains why somebody then feels entitled to also take that rage and turn it outward on 149 other people who had nothing to do with the person 's problems . '' germanwings crash compensation : what we know . who was the captain of germanwings flight 9525 ? cnn 's margot haddad reported from marseille and pamela brown from dusseldorf , while laura smith-spark wrote from london . cnn 's frederik pleitgen , pamela boykoff , antonia mortensen , sandrine amiel and anna-maja rappard contributed to this report .\\n\",\n", - " 'tgt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", - " 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\",\n", - " 'source_ids': tensor([ 0, 4401, 1090, ..., 604, 1725, 2]),\n", - " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", - " 'target_ids': tensor([ 0, 28696, 90, 15698, 4401, 1090, 4061, 5644, 161, 45518,\n", - " 98, 444, 117, 3424, 58, 341, 11, 5, 2058, 803,\n", - " 12801, 1135, 433, 690, 479, 49703, 90, 15698, 28696, 90,\n", - " 15698, 4225, 23, 741, 9683, 8, 2242, 354, 914, 32,\n", - " 45518, 182, 3230, 12801, 5, 569, 7200, 16, 588, 2156,\n", - " 41, 4474, 161, 479, 49703, 2])}" + "BartConfig {\n", + " \"_num_labels\": 3,\n", + " \"activation_dropout\": 0.0,\n", + " \"activation_function\": \"gelu\",\n", + " \"add_final_layer_norm\": false,\n", + " \"architectures\": [\n", + " \"BartModel\",\n", + " \"BartForMaskedLM\",\n", + " \"BartForSequenceClassification\"\n", + " ],\n", + " \"attention_dropout\": 0.0,\n", + " \"bad_words_ids\": null,\n", + " \"bos_token_id\": 0,\n", + " \"classif_dropout\": 0.0,\n", + " \"d_model\": 1024,\n", + " \"decoder_attention_heads\": 16,\n", + " \"decoder_ffn_dim\": 4096,\n", + " \"decoder_layerdrop\": 0.0,\n", + " \"decoder_layers\": 12,\n", + " \"decoder_start_token_id\": 2,\n", + " \"do_sample\": false,\n", + " \"dropout\": 0.1,\n", + " \"early_stopping\": true,\n", + " \"encoder_attention_heads\": 16,\n", + " \"encoder_ffn_dim\": 4096,\n", + " \"encoder_layerdrop\": 0.0,\n", + " \"encoder_layers\": 12,\n", + " \"eos_token_id\": 2,\n", + " \"finetuning_task\": null,\n", + " \"id2label\": {\n", + " \"0\": \"LABEL_0\",\n", + " \"1\": \"LABEL_1\",\n", + " \"2\": \"LABEL_2\"\n", + " },\n", + " \"init_std\": 0.02,\n", + " \"is_decoder\": false,\n", + " \"is_encoder_decoder\": true,\n", + " \"label2id\": {\n", + " \"LABEL_0\": 0,\n", + " \"LABEL_1\": 1,\n", + " \"LABEL_2\": 2\n", + " },\n", + " \"length_penalty\": 2.0,\n", + " \"max_length\": 142,\n", + " \"max_position_embeddings\": 1024,\n", + " \"min_length\": 56,\n", + " \"model_type\": \"bart\",\n", + " \"no_repeat_ngram_size\": 3,\n", + " \"normalize_before\": false,\n", + " \"num_beams\": 4,\n", + " \"num_hidden_layers\": 12,\n", + " \"num_return_sequences\": 1,\n", + " \"output_attentions\": false,\n", + " \"output_hidden_states\": false,\n", + " \"output_past\": false,\n", + " \"pad_token_id\": 1,\n", + " \"prefix\": \" \",\n", + " \"pruned_heads\": {},\n", + " \"repetition_penalty\": 1.0,\n", + " \"scale_embedding\": false,\n", + " \"task_specific_params\": {\n", + " \"summarization\": {\n", + " \"early_stopping\": true,\n", + " \"length_penalty\": 2.0,\n", + " \"max_length\": 142,\n", + " \"min_length\": 56,\n", + " \"no_repeat_ngram_size\": 3,\n", + " \"num_beams\": 4\n", + " }\n", + " },\n", + " \"temperature\": 1.0,\n", + " \"top_k\": 50,\n", + " \"top_p\": 1.0,\n", + " \"torchscript\": false,\n", + " \"use_bfloat16\": false,\n", + " \"vocab_size\": 50265\n", + "}" ] }, - "execution_count": 11, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "abs_sum_test[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "\"\"\"\n", - "# save and load preprocessed data\n", - "save_path = DATA_PATH\n", - "torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", - "torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))\n", - "\n", - "\"\"\"\n", - "save_path = DATA_PATH\n", - "#abs_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", - "abs_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" + "summarizer.config" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 21, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "287227\n", - "11490\n" - ] + "data": { + "text/plain": [ + "{'early_stopping': True,\n", + " 'length_penalty': 2.0,\n", + " 'max_length': 142,\n", + " 'min_length': 56,\n", + " 'no_repeat_ngram_size': 3,\n", + " 'num_beams': 4}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "print(len(abs_sum_train))\n", - "print(len(abs_sum_test))" + "task_specific_params.get(\"summarization\", {})" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 25, "metadata": {}, "outputs": [], "source": [ - "#save_path = os.path.join(DATA_PATH, \"processed\")\n", - "#torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", - "#torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))" + "summarizer.config.update({\"vocab_size\": 50264})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "#### Inspect Data" + "### Data Preprocessing\n", + "\n", + "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" ] }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, + "execution_count": 49, + "metadata": { + "tags": [ + "parameters" + ] + }, "outputs": [], "source": [ - "abs_sum_train[0].keys()" + "# the data path used to save the downloaded data file\n", + "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", + "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", + "TOP_N = 1000\n", + "if not QUICK_RUN:\n", + " TOP_N = -1" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 50, "metadata": { "scrolled": false }, "outputs": [], "source": [ - "abs_sum_train[0]" + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_dataset[0]['tgt_txt']" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "11490" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(test_dataset)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "### Model training\n", - "To start model training, we need to create a instance of ExtractiveSummarizer.\n", - "#### Choose the transformer model.\n", - "Currently ExtractiveSummarizer support two models:\n", - "- distilbert-base-uncase, \n", - "- bert-base-uncase\n", - "\n", - "Potentionally, roberta-based model and xlnet can be supported but needs to be tested.\n", - "#### Choose the encoder algorithm.\n", - "There are four options:\n", - "- baseline: it used a smaller transformer model to replace the bert model and with transformer summarization layer\n", - "- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer\n", - "- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer\n", - "- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer" + "Preprocess the data." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "metadata": { - "tags": [ - "parameters" - ] + "scrolled": false }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 5 µs, sys: 2 µs, total: 7 µs\n", + "Wall time: 20.3 µs\n" + ] + } + ], + "source": [ + "%time\n", + "abs_sum_train = summarizer.processor.preprocess(train_dataset)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, "outputs": [], "source": [ - "BATCH_SIZE = 5 # batch size, unit is the number of samples\n", - "MAX_POS_LENGTH = 512\n", - "\n", - " \n", - "\n", - "\n", - "# GPU used for training\n", - "NUM_GPUS = torch.cuda.device_count()\n", - "\n", - "\n", - "# Learning rate\n", - "LEARNING_RATE=2e-3\n", - "\n", - "# How often the statistics reports show up in training, unit is step.\n", - "REPORT_EVERY=100\n", - "\n", - "# total number of steps for training\n", - "MAX_STEPS=1e2\n", - "# number of steps for warm up\n", - "WARMUP_STEPS=5e2\n", - " \n", - "if not QUICK_RUN:\n", - " MAX_STEPS=5e4\n", - " WARMUP_STEPS=5e3\n", - " " + "# torch.save(abs_sum_train, os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "scrolled": true - }, + "execution_count": 16, + "metadata": {}, "outputs": [], "source": [ - "summarizer = AbstractiveSummarizer(processor, MODEL_NAME)" + "# torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 51, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "abs_sum_test = summarizer.processor.preprocess(test_dataset)" + ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 58, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'input_ids': tensor([[ 0, 4401, 1090, ..., 167, 1081, 2],\n", + " [ 0, 36, 740, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", + " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'), 'decoder_input_ids': tensor([[ 0, 4401, 1090, 4061, 5644, 161, 22, 98, 444, 117,\n", + " 3424, 58, 341, 11, 5, 2058, 803, 22, 1135, 433,\n", + " 690, 479, 1437, 1437, 4225, 23, 741, 9683, 8, 2242,\n", + " 354, 914, 32, 22, 182, 3230, 22, 5, 569, 7200,\n", + " 16, 588, 2156, 41, 4474, 161, 479, 1437, 1437, 8,\n", + " 241, 281, 784, 1792, 4494, 56, 3978, 39, 784, 2951,\n", + " 212, 1253, 102, 1058, 334, 9, 41, 3238, 9, 3814,\n", + " 6943, 2156, 5195, 161, 479],\n", + " [ 0, 6332, 2029, 5, 41591, 438, 10542, 81, 1697, 3474,\n", + " 2021, 11, 8750, 990, 28307, 13560, 187, 94, 1236, 4438,\n", + " 479, 1437, 1437, 16, 37715, 8, 5, 10409, 982, 4340,\n", + " 5, 517, 2156, 61, 115, 490, 5, 1883, 7, 997,\n", + " 3474, 4941, 136, 16, 37715, 354, 479, 2, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1]], device='cuda:0'), 'lm_labels': tensor([[ 4401, 1090, 4061, 5644, 161, 22, 98, 444, 117, 3424,\n", + " 58, 341, 11, 5, 2058, 803, 22, 1135, 433, 690,\n", + " 479, 1437, 1437, 4225, 23, 741, 9683, 8, 2242, 354,\n", + " 914, 32, 22, 182, 3230, 22, 5, 569, 7200, 16,\n", + " 588, 2156, 41, 4474, 161, 479, 1437, 1437, 8, 241,\n", + " 281, 784, 1792, 4494, 56, 3978, 39, 784, 2951, 212,\n", + " 1253, 102, 1058, 334, 9, 41, 3238, 9, 3814, 6943,\n", + " 2156, 5195, 161, 479, 2],\n", + " [ 6332, 2029, 5, 41591, 438, 10542, 81, 1697, 3474, 2021,\n", + " 11, 8750, 990, 28307, 13560, 187, 94, 1236, 4438, 479,\n", + " 1437, 1437, 16, 37715, 8, 5, 10409, 982, 4340, 5,\n", + " 517, 2156, 61, 115, 490, 5, 1883, 7, 997, 3474,\n", + " 4941, 136, 16, 37715, 354, 479, 2, -100, -100, -100,\n", + " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", + " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", + " -100, -100, -100, -100, -100]], device='cuda:0')}\n" + ] + } + ], + "source": [ + "a = summarizer.processor.collate_fn(abs_sum_test[0:2], \"cuda:0\", True)\n", + "c = summarizer.processor.get_inputs(a, \"cuda:0\", MODEL_NAME, summarizer.tokenizer, True)\n", + "print(c)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "abs_sum_train = torch.load( os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Generating summary: 0%| | 0/1 [00:00 marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\", 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\", 'source_ids': tensor([ 0, 4401, 1090, ..., 604, 1725, 2]), 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]), 'target_ids': tensor([ 0, 28696, 90, 15698, 4401, 1090, 4061, 5644, 161, 45518,\n", + " 98, 444, 117, 3424, 58, 341, 11, 5, 2058, 803,\n", + " 12801, 1135, 433, 690, 479, 49703, 90, 15698, 28696, 90,\n", + " 15698, 4225, 23, 741, 9683, 8, 2242, 354, 914, 32,\n", + " 45518, 182, 3230, 12801, 5, 569, 7200, 16, 588, 2156,\n", + " 41, 4474, 161, 479, 49703, 90, 15698, 28696, 90, 15698,\n", + " 8, 241, 281, 784, 1792, 4494, 56, 3978, 39, 784,\n", + " 2951, 212, 1253, 102, 1058, 334, 9, 41, 3238, 9,\n", + " 3814, 6943, 2156, 5195, 161, 479, 49703, 90, 15698, 2,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", + " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}\n" + ] + } + ], + "source": [ + "print(abs_sum_test[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nnnnennennsnnnnnnnnsnnennennennnnnnennsnnnnnnnnnnnnsnnnnnnnnnnnnnnnnennennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn\n" + ] + } + ], + "source": [ + "print(prediction[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "BartForConditionalGeneration(\n", + " (model): BartModel(\n", + " (shared): Embedding(50265, 1024, padding_idx=1)\n", + " (encoder): BartEncoder(\n", + " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): EncoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (decoder): BartDecoder(\n", + " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", + " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", + " (layers): ModuleList(\n", + " (0): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (1): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (2): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (3): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (4): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (5): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (6): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (7): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (8): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (9): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (10): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " (11): DecoderLayer(\n", + " (self_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (encoder_attn): SelfAttention(\n", + " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", + " )\n", + " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", + " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", + " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", + " )\n", + " )\n", + ")" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "summarizer.model.module" + ] + }, + { + "cell_type": "code", + "execution_count": 8, "metadata": { "scrolled": true }, "outputs": [], "source": [ - "\"\"\"\n", + "\"\"\"\n", + "# save and load preprocessed data\n", + "save_path = DATA_PATH\n", + "torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", + "torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))\n", + "\n", + "\"\"\"\n", + "save_path = DATA_PATH\n", + "#abs_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", + "abs_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "287227\n", + "11490\n" + ] + } + ], + "source": [ + "print(len(abs_sum_train))\n", + "print(len(abs_sum_test))" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "#save_path = os.path.join(DATA_PATH, \"processed\")\n", + "#torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", + "#torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inspect Data" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "abs_sum_train[0].keys()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "abs_sum_train[0]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model training\n", + "To start model training, we need to create a instance of ExtractiveSummarizer.\n", + "#### Choose the transformer model.\n", + "Currently ExtractiveSummarizer support two models:\n", + "- distilbert-base-uncase, \n", + "- bert-base-uncase\n", + "\n", + "Potentionally, roberta-based model and xlnet can be supported but needs to be tested.\n", + "#### Choose the encoder algorithm.\n", + "There are four options:\n", + "- baseline: it used a smaller transformer model to replace the bert model and with transformer summarization layer\n", + "- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer\n", + "- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer\n", + "- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "BATCH_SIZE = 8 # batch size, unit is the number of samples\n", + "MAX_POS_LENGTH = 512\n", + "\n", + " \n", + "\n", + "\n", + "# GPU used for training\n", + "NUM_GPUS = torch.cuda.device_count()\n", + "\n", + "\n", + "# Learning rate\n", + "LEARNING_RATE=3e-5\n", + "\n", + "# How often the statistics reports show up in training, unit is step.\n", + "REPORT_EVERY=100\n", + "\n", + "# total number of steps for training\n", + "MAX_STEPS=1e2\n", + "# number of steps for warm up\n", + "WARMUP_STEPS=5e2\n", + " \n", + "if not QUICK_RUN:\n", + " MAX_STEPS=5e3\n", + " WARMUP_STEPS=5e2\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Iteration: 0%| | 0/35904 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mwarmup_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mWARMUP_STEPS\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mreport_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mREPORT_EVERY\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, train_dataset, num_gpus, gpu_ids, batch_size, local_rank, max_steps, warmup_steps, learning_rate, weight_decay, adam_epsilon, max_grad_norm, gradient_accumulation_steps, report_every, save_every, verbose, seed, fp16, fp16_opt_level, world_size, rank, validation_function, checkpoint, **kwargs)\u001b[0m\n\u001b[1;32m 447\u001b[0m \u001b[0mfp16\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[0mamp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 449\u001b[0;31m \u001b[0mvalidation_function\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 450\u001b[0m )\n\u001b[1;32m 451\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/common.py\u001b[0m in \u001b[0;36mfine_tune\u001b[0;34m(self, train_dataloader, get_inputs, device, num_gpus, max_steps, global_step, max_grad_norm, gradient_accumulation_steps, optimizer, scheduler, fp16, amp, local_rank, verbose, seed, report_every, save_every, clip_grad_norm, validation_function)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 239\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 240\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0mreplicas\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 152\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 153\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mthread\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthreads\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m \u001b[0mthread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 76\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mthread\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthreads\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mthread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 849\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0m_limbo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 851\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_started\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 852\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 853\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_flag\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 295\u001b[0;31m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], + "source": [ + "#\"\"\"\n", "\n", "summarizer.fit(\n", - " ext_sum_train,\n", + " abs_sum_train,\n", " num_gpus=NUM_GPUS,\n", " batch_size=BATCH_SIZE,\n", - " gradient_accumulation_steps=2,\n", + " gradient_accumulation_steps=1,\n", " max_steps=MAX_STEPS,\n", " learning_rate=LEARNING_RATE,\n", " warmup_steps=WARMUP_STEPS,\n", " verbose=True,\n", " report_every=REPORT_EVERY,\n", - " clip_grad_norm=False,\n", - " use_preprocessed_data=USE_PREPROCSSED_DATA\n", " )\n", "\n", - "\"\"\"\n" + "#\"\"\"\n" ] }, { @@ -720,6 +3272,58 @@ "prediction = summarizer.predict(abs_sum_test[0:256*4], num_gpus=1, batch_size=16) " ] }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Generating summary: 0%| | 0/1 [00:00", "", line) return line +def _remove_tags(line): + line = re.sub(r"", "", line) + # change to + # pyrouge test requires as sentence splitter + line = re.sub(r"", "", line) + return line + def _target_sentence_tokenization(line): return line.split("") @@ -91,10 +98,20 @@ def _setup_datasets( SummarizationDataset( train_source_file, target_file=train_target_file, + source_preprocessing=[_clean,], + target_preprocessing=[ + _clean, + _remove_tags, + ], top_n=top_n ), SummarizationDataset( test_source_file, + source_preprocessing=[_clean,], + target_preprocessing=[ + _clean, + _remove_tags, + ], target_file=test_target_file, top_n=top_n ), diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index b4ac35a73..97f002349 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -37,11 +37,6 @@ from utils_nlp.models.transformers.common import Transformer # from utils_nlp.models.transformers.common import TOKENIZER_CLASS - - - -#from transformers.modeling_bart import BART_PRETRAINED_MODEL_ARCHIVE_MAP - from transformers import ( AutoConfig, AutoModel, @@ -76,6 +71,17 @@ from transformers.tokenization_utils import trim_batch from tempfile import TemporaryDirectory + +def trim_batch( + input_ids, pad_token_id, attention_mask=None, +): + """Remove columns that are populated exclusively by pad_token_id""" + keep_column_mask = input_ids.ne(pad_token_id).any(dim=0) + if attention_mask is None: + return input_ids[:, keep_column_mask] + else: + return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask]) + def encode_example(example, tokenizer, prefix="", max_source_length=None, max_target_length=None, pad_to_max_length=True, return_tensors="pt"): ## add to the dataset tokenized_source = tokenizer.batch_encode_plus( @@ -96,11 +102,17 @@ def encode_example(example, tokenizer, prefix="", max_source_length=None, max_ta class Predictor(nn.Module): + """ + Predictor which can run on multi-GPUs. + + Args: + model (): + """ def __init__( self, model, - min_length, - max_length, + min_length=55, + max_length=140, **kwargs): super(Predictor, self).__init__() self.model = model.module if hasattr(model, "module") else model @@ -116,6 +128,13 @@ def forward(self, src, src_mask): attention_mask=src_mask, min_length=self.min_length, max_length=self.max_length, + # todo + num_beams=1, + repetition_penalty=2.5, + length_penalty=1.0, + early_stopping=True, + no_repeat_ngram_size=3, + #use_cache=True, **self.config ) predictions = torch.tensor( @@ -129,40 +148,85 @@ def forward(self, src, src_mask): return predictions +def validate(summarizer, validate_dataset, num_gpus=1, TOP_N=2): + """ validation function to be used optionally in fine tuning. + + Args: + summarizer(BertSumAbs): The summarizer under fine tuning. + validate_dataset (SummarizationDataset): dataset for validation. + + Returns: + string: A string which contains the rouge score on a subset of + the validation dataset. + + """ + #TOP_N = 2 + shortened_dataset = validate_dataset[0:TOP_N] + a = summarizer.processor.collate_fn(shortened_dataset, "cuda:0", True) + c = summarizer.processor.get_inputs(a, "cuda:0", summarizer.model_name, summarizer.tokenizer, True) + + # reference_summaries = [] + # for i in shortened_dataset: + # reference_summaries.append(i['tgt'].replace("","").replace("", "").replace("\n", "")) + output = summarizer.model(**c) + generated_summaries = summarizer.predict( + shortened_dataset, num_gpus=num_gpus, batch_size=TOP_N + ) + #assert len(generated_summaries) == len(reference_summaries) + #print("###################") + print("validation loss is {}".format(output[0])) + print("prediction is {}".format(generated_summaries[0])) + #print("reference is {}".format(reference_summaries[0])) + + #rouge_score = compute_rouge_python( + # cand=generated_summaries, ref=reference_summaries + #) + #return "rouge score: {}".format(rouge_score) + class SummarizationProcessor: def __init__( self, - model_name, - cache_dir="./", + tokenizer, + config, + #model_name, + #cache_dir="./", max_source_length=1024, - max_target_length=56, + max_target_length=140, ): #super().__init__() - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, + # self.tokenizer = AutoTokenizer.from_pretrained( + # model_name, #self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, - cache_dir=cache_dir, - ) + # cache_dir=cache_dir, + #) #TOKENIZER_CLASS[model_name].from_pretrained(model_name, cache_dir=cache_dir) # b - config = AutoConfig.from_pretrained( - model_name, - #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, + #config = AutoConfig.from_pretrained( + # model_name, + # #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, #**({"num_labels": num_labels} if num_labels is not None else {}), - cache_dir=cache_dir, + # cache_dir=cache_dir, #**config_kwargs, - ) - if model_name.startswith("t5"): + #) + # if model_name.startswith("t5"): # update config with summarization specific params - task_specific_params = config.task_specific_params - if task_specific_params is not None: - config.update(task_specific_params.get("summarization", {})) + # task_specific_params = config.task_specific_params + # if task_specific_params is not None: + # config.update(task_specific_params.get("summarization", {})) + self.tokenizer = tokenizer + self.config = config self.prefix = config.prefix self.with_target = False self.max_source_length = max_source_length self.max_target_length = max_target_length + @staticmethod + def trim_seq2seq_batch(batch, pad_token_id): + y = trim_batch(batch["target_ids"], pad_token_id) + source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) + return source_ids, source_mask, y + def preprocess(self, input_data_list): result = [] for i in input_data_list: @@ -173,7 +237,10 @@ def preprocess(self, input_data_list): @staticmethod def get_inputs(batch, device, model_name, tokenizer=None, train_mode=True): pad_token_id = tokenizer.pad_token_id - source_ids, source_mask, y = batch["source_ids"], batch["source_mask"], batch["target_ids"] + if not train_mode: + source_ids, source_mask = batch["source_ids"], batch["source_mask"] + else: + source_ids, source_mask, y = SummarizationProcessor.trim_seq2seq_batch(batch, pad_token_id) y_ids = y[:, :-1].contiguous() lm_labels = y[:, 1:].clone() lm_labels[y[:, 1:] == pad_token_id] = -100 @@ -218,30 +285,20 @@ def __init__( """Initialize an object of BertSumAbs. Args: - model_name (str, optional:) Name of the pretrained model which is used - to initialize the encoder of the BertSumAbs model. - check MODEL_CLASS for supported models. Defaults to "bert-base-uncased". + model_name (str, optional:) Name of the pretrained model which is used . + `S2SAbsSumProcessor.list_supported_models()` to see all supported model names.. Defaults to "bart-large". cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".". max_pos_length (int, optional): maximum postional embedding length for the input. Defaults to 768. """ - """super().__init__( - model_class=AutoModelWithLMHead, - model_name=model_name, - num_labels=0, - cache_dir=cache_dir, - ) - """ - """ if model_name not in self.list_supported_models(): raise ValueError( - "Model name {} is not supported by BertSumAbs. " - "Call 'BertSumAbs.list_supported_models()' to get all supported model " + "Model name {} is not supported by AbstractiveSummarizer. " + "Call 'AbstractiveSummarizer.list_supported_models()' to get all supported model " "names.".format(value) ) - """ - self.processor = SummarizationProcessor(model_name, cache_dir, max_source_length, max_target_length) + self.config = AutoConfig.from_pretrained( model_name, #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, @@ -249,6 +306,12 @@ def __init__( cache_dir=cache_dir, #**config_kwargs, ) + task_specific_params = self.config.task_specific_params + if task_specific_params is not None: + self.config.update(task_specific_params.get("summarization", {})) + #self.config.update({"vocab_size": 50264}) + self.config.update({"max_length": max_target_length}) + self.config.update({"attention_dropout": 0.1}) self.tokenizer = AutoTokenizer.from_pretrained( model_name, @@ -256,8 +319,11 @@ def __init__( cache_dir=cache_dir, ) + self.processor = SummarizationProcessor(self.tokenizer, self.config, max_source_length, max_target_length) + self._model_name = model_name - self.model = MODEL_MODES["language-modeling"].from_pretrained( + # self.model = MODEL_CLASS[model_name].from_pretrained( + self.model = MODEL_MODES["language-modeling"].from_pretrained( self.model_name, #from_tf=bool(".ckpt" in self.hparams.model_name_or_path), config=self.config, @@ -385,7 +451,7 @@ def fit( weight_decay=weight_decay, learning_rate=learning_rate, adam_epsilon=adam_epsilon, - # checkpoint_state_dict=checkpoint_state_dict, + checkpoint_state_dict=checkpoint_state_dict, ) self.amp = amp @@ -449,14 +515,14 @@ def collate_fn(data): scheduler=self.scheduler, fp16=fp16, amp=amp, - validation_function=None, + validation_function=validation_function, ) # release GPU memories self.model.cpu() torch.cuda.empty_cache() - self.save_model(max_steps) + #self.save_model(max_steps) def predict( self, @@ -467,13 +533,13 @@ def predict( batch_size=16, length_penalty=0.95, beam_size=4, - min_length=50, - max_length=200, + min_length=56, + max_length=140, no_repeat_ngram_size=3, early_stopping=True, - fp16=False, verbose=True, + checkpoint=None, ): """ Predict the summarization for the input data iterator. @@ -505,23 +571,32 @@ def predict( List of strings which are the summaries """ + device, num_gpus = get_device( num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank ) + model = move_model_to_device(self.model, device) + + checkpoint_state_dict = None + if checkpoint: + # checkpoint should have "model", "optimizer", "amp" + checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") + model.load_state_dict(checkpoint_state_dict["model"]) - if fp16: - self.model = self.model.half() - self.model = move_model_to_device(self.model, device) - self.model.eval() + model.eval() - self.model = parallelize_model( - self.model, + model = parallelize_model( + model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank, ) + + + if fp16: + model = model.half() test_sampler = SequentialSampler(test_dataset) @@ -538,7 +613,7 @@ def collate_fn(data): ) print("dataset length is {}".format(len(test_dataset))) - predictor = Predictor(self.model, min_length, max_length) + predictor = Predictor(model, min_length, max_length) # move model to devices def this_model_move_callback(model, device): model = move_model_to_device(model, device) @@ -550,13 +625,14 @@ def this_model_move_callback(model, device): generated_summaries = [] for batch in tqdm( - test_dataloader, desc="Generating summary", disable=not verbose + test_dataloader, desc="Generating summary", disable=True #not verbose ): #if self.model_name.startswith("t5"): # batch = [self.model.config.prefix + text for text in batch] #dct = self.tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) # print(batch) - summaries = predictor(batch["source_ids"], batch["source_mask"]) + input_ids, masks = trim_batch(batch["source_ids"], self.tokenizer.pad_token_id, attention_mask=batch["source_mask"]) + summaries = predictor(input_ids, masks) """ summaries = self.model.module.generate( input_ids=batch["source_ids"], @@ -565,12 +641,12 @@ def this_model_move_callback(model, device): max_length=max_length ) """ - decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=False) for g in summaries] + decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summaries] generated_summaries.extend(decoded_summaries) # release GPU memories - self.model.cpu() + # self.model.cpu() del batch torch.cuda.empty_cache() @@ -602,8 +678,8 @@ def save_model(self, global_step=None, full_name=None): os.makedirs(path, exist_ok=True) checkpoint = { - "optimizer": self.optimizer.state_dict(), - "lr_scheduler": self.scheduler.state_dict(), + "optimizer": self.optimizer.state_dict() if self.optimizer else None, + "lr_scheduler": self.scheduler.state_dict() if self.scheduler else None, "model": model_to_save.state_dict(), "amp": self.amp.state_dict() if self.amp else None, "global_step": global_step, diff --git a/utils_nlp/models/transformers/common.py b/utils_nlp/models/transformers/common.py index cbe845c5f..ec193e7ed 100755 --- a/utils_nlp/models/transformers/common.py +++ b/utils_nlp/models/transformers/common.py @@ -232,7 +232,7 @@ def fine_tune( epoch_iterator = tqdm( train_dataloader, desc="Iteration", - disable=local_rank not in [-1, 0] or not verbose, + disable=True #local_rank not in [-1, 0] or not verbose, ) for step, batch in enumerate(epoch_iterator): inputs = get_inputs(batch, device, self.model_name) @@ -291,6 +291,10 @@ def fine_tune( ) logger.info(log_line) print(log_line) + if validation_function: + validation_log = validation_function(self) + logger.info(validation_log) + print(validation_log) accum_loss = 0 train_size = 0 start = end @@ -318,10 +322,6 @@ def fine_tune( self.cache_dir, f"{self.model_name}_step_{global_step}.pt" ) self.save_model(global_step, saved_model_path) - if validation_function: - validation_log = validation_function(self) - logger.info(validation_log) - print(validation_log) if global_step > max_steps: epoch_iterator.close() break From fd5d2ea9076752012ea648378556a5b4689c6b4d Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Wed, 20 May 2020 18:53:44 +0000 Subject: [PATCH 09/14] clean up and add documentation --- ...ive_summarization_cnndm_transformers.ipynb | 3445 ++--------------- .../abstractive_summarization_bartt5.py | 481 +-- 2 files changed, 530 insertions(+), 3396 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb index f2765c72c..34067c1c1 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb @@ -18,20 +18,15 @@ "\n", "### Summary\n", "\n", - "This notebook demonstrates how to fine tune Transformers for extractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.\n", + "This notebook demonstrates how to fine tune Transformers models like [BART](https://arxiv.org/abs/1910.13461) and [T5](https://arxiv.org/abs/1910.10683) together with HuggingFace's [transformers library](https://github.com/huggingface/transformers)for abstractive text summarization. Utility functions and classes in the NLP Best Practices repo are used to facilitate data preprocessing, model training, model scoring, result postprocessing, and model evaluation.\n", "\n", "\n", "\n", "\n", "### Before You Start\n", "\n", - "The running time shown in this notebook is on a Standard_NC24s_v3 Azure Ubuntu Virtual Machine with 4 NVIDIA Tesla V100 GPUs. \n", - "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n", - "\n", - "Using only 1 NVIDIA Tesla V100 GPUs, 16GB GPU memory configuration,\n", - "- for data preprocessing, it takes around 1 minutes to preprocess the data for quick run. Otherwise it takes ~20 minutes to finish the data preprocessing. This time estimation assumes that the chosen transformer model is \"distilbert-base-uncased\" and the sentence selection method is \"greedy\", which is the default. The preprocessing time can be significantly longer if the sentence selection method is \"combination\", which can achieve better model performance.\n", "\n", - "- for model fine tuning, it takes around 2 minutes for quick run. Otherwise, it takes around ~3 hours to finish. This estimation assumes the chosen encoder method is \"transformer\". The model fine tuning time can be shorter if other encoder method is chosen, which may result in worse model performance. \n", + "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n", "\n", "### Additional Notes\n", "\n", @@ -44,35 +39,19 @@ }, { "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 47, + "execution_count": 1, "metadata": { "tags": [ "parameters" ] }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "The autoreload extension is already loaded. To reload it, use:\n", - " %reload_ext autoreload\n" - ] - } - ], + "outputs": [], "source": [ "%load_ext autoreload\n", "\n", "%autoreload 2\n", "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", - "QUICK_RUN = False" + "QUICK_RUN = True" ] }, { @@ -93,30 +72,8 @@ "name": "stderr", "output_type": "stream", "text": [ - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:516: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:517: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:518: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:519: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:520: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorflow/python/framework/dtypes.py:525: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:541: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:542: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:543: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:544: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:545: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n", - "/home/daden/.local/lib/python3.6/site-packages/tensorboard/compat/tensorflow_stub/dtypes.py:550: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n", - " np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n" + "/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/dask/dataframe/utils.py:15: FutureWarning: pandas.util.testing is deprecated. Use the functions in the public API at pandas.testing instead.\n", + " import pandas.util.testing as tm\n" ] } ], @@ -125,13 +82,14 @@ "import shutil\n", "import sys\n", "from tempfile import TemporaryDirectory\n", + "import time\n", "import torch\n", "\n", "nlp_path = os.path.abspath(\"../../\")\n", "if nlp_path not in sys.path:\n", " sys.path.insert(0, nlp_path)\n", "\n", - "from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset\n", + "from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset\n", "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n", "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (\n", " AbstractiveSummarizer, SummarizationProcessor, validate)\n", @@ -142,7 +100,8 @@ "\n", "import pandas as pd\n", "import scrapbook as sb\n", - "import pprint" + "import pprint\n", + "start_time = time.time()" ] }, { @@ -157,21 +116,104 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For extractive summarization, the following pretrained models are supported. " + "Several pretrained models have been made available by [Hugging Face](https://github.com/huggingface/transformers). For abstractive summarization, the following pretrained models are supported. " ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
model_name
0bart-large
1bart-large-mnli
2bart-large-cnn
3bart-large-xsum
4t5-small
5t5-base
6t5-large
7t5-3b
8t5-11b
\n", + "
" + ], + "text/plain": [ + " model_name\n", + "0 bart-large\n", + "1 bart-large-mnli\n", + "2 bart-large-cnn\n", + "3 bart-large-xsum\n", + "4 t5-small\n", + "5 t5-base\n", + "6 t5-large\n", + "7 t5-3b\n", + "8 t5-11b" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "#pd.DataFrame({\"model_name\": ExtractiveSummarizer.list_supported_models()})" + "pd.DataFrame({\"model_name\": AbstractiveSummarizer.list_supported_models()})" ] }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 4, "metadata": { "tags": [ "parameters" @@ -180,3003 +222,316 @@ "outputs": [], "source": [ "# Transformer model being used\n", - "#MODEL_NAME = \"t5-large\"\n", - "MODEL_NAME = \"bart-large-cnn\"\n", + "# MODEL_NAME = \"bart-large\"\n", + "MODEL_NAME = \"t5-small\"\n", "# notebook parameters\n", "# the cache data path during find tuning\n", - "CACHE_DIR = \"./bart_cache\" #TemporaryDirectory().name\n", + "CACHE_DIR = \"./t5_cache\" #TemporaryDirectory().name\n", "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)" ] }, { - "cell_type": "code", - "execution_count": 27, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['hello', 'Ġfrench', 's', 'df', 'a']" - ] - }, - "execution_count": 27, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "# bart-large\n", - "summarizer.tokenizer.tokenize(\"hello frenchsdfa \")" + "### Data Preprocessing\n", + "\n", + "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" ] }, { "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "['hello', 'Ġn', 'lp', 'Ġam', 'azon', 'Ġch', 'ina']" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": 5, + "metadata": { + "tags": [ + "parameters" + ] + }, + "outputs": [], "source": [ - "summarizer.tokenizer.tokenize(\"hello nlp amazon china\")" + "# the data path used to save the downloaded data file\n", + "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", + "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", + "TOP_N = 100\n", + "if not QUICK_RUN:\n", + " TOP_N = -1" ] }, { "cell_type": "code", - "execution_count": 32, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": 6, + "metadata": { + "scrolled": false + }, + "outputs": [], "source": [ - "summarizer.tokenizer" + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "BartForConditionalGeneration(\n", - " (model): BartModel(\n", - " (shared): Embedding(50265, 1024, padding_idx=1)\n", - " (encoder): BartEncoder(\n", - " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (decoder): BartDecoder(\n", - " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - ")" + "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" ] }, - "execution_count": 4, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summarizer.model" + "test_dataset[0]['tgt_txt']" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "BartConfig {\n", - " \"_num_labels\": 3,\n", - " \"activation_dropout\": 0.0,\n", - " \"activation_function\": \"gelu\",\n", - " \"add_final_layer_norm\": false,\n", - " \"architectures\": [\n", - " \"BartModel\",\n", - " \"BartForMaskedLM\",\n", - " \"BartForSequenceClassification\"\n", - " ],\n", - " \"attention_dropout\": 0.0,\n", - " \"bad_words_ids\": null,\n", - " \"bos_token_id\": 0,\n", - " \"classif_dropout\": 0.0,\n", - " \"d_model\": 1024,\n", - " \"decoder_attention_heads\": 16,\n", - " \"decoder_ffn_dim\": 4096,\n", - " \"decoder_layerdrop\": 0.0,\n", - " \"decoder_layers\": 12,\n", - " \"decoder_start_token_id\": 2,\n", - " \"do_sample\": false,\n", - " \"dropout\": 0.1,\n", - " \"early_stopping\": false,\n", - " \"encoder_attention_heads\": 16,\n", - " \"encoder_ffn_dim\": 4096,\n", - " \"encoder_layerdrop\": 0.0,\n", - " \"encoder_layers\": 12,\n", - " \"eos_token_id\": 2,\n", - " \"finetuning_task\": null,\n", - " \"id2label\": {\n", - " \"0\": \"LABEL_0\",\n", - " \"1\": \"LABEL_1\",\n", - " \"2\": \"LABEL_2\"\n", - " },\n", - " \"init_std\": 0.02,\n", - " \"is_decoder\": false,\n", - " \"is_encoder_decoder\": true,\n", - " \"label2id\": {\n", - " \"LABEL_0\": 0,\n", - " \"LABEL_1\": 1,\n", - " \"LABEL_2\": 2\n", - " },\n", - " \"length_penalty\": 1.0,\n", - " \"max_length\": 20,\n", - " \"max_position_embeddings\": 1024,\n", - " \"min_length\": 0,\n", - " \"model_type\": \"bart\",\n", - " \"no_repeat_ngram_size\": 0,\n", - " \"normalize_before\": false,\n", - " \"num_beams\": 1,\n", - " \"num_hidden_layers\": 12,\n", - " \"num_return_sequences\": 1,\n", - " \"output_attentions\": false,\n", - " \"output_hidden_states\": false,\n", - " \"output_past\": false,\n", - " \"pad_token_id\": 1,\n", - " \"prefix\": \" \",\n", - " \"pruned_heads\": {},\n", - " \"repetition_penalty\": 1.0,\n", - " \"scale_embedding\": false,\n", - " \"task_specific_params\": {\n", - " \"summarization\": {\n", - " \"early_stopping\": true,\n", - " \"length_penalty\": 2.0,\n", - " \"max_length\": 142,\n", - " \"min_length\": 56,\n", - " \"no_repeat_ngram_size\": 3,\n", - " \"num_beams\": 4\n", - " }\n", - " },\n", - " \"temperature\": 1.0,\n", - " \"top_k\": 50,\n", - " \"top_p\": 1.0,\n", - " \"torchscript\": false,\n", - " \"use_bfloat16\": false,\n", - " \"vocab_size\": 50265\n", - "}" + "100" ] }, - "execution_count": 11, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summarizer.config" + "len(test_dataset)" ] }, { - "cell_type": "code", - "execution_count": 5, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "fe888b456fcb438a8c330d3f7328cfe8", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1300.0, style=ProgressStyle(description…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "ebe6e048029042029ef1d24008ebe7d5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1625270765.0, style=ProgressStyle(descr…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "text/plain": [ - "BartForConditionalGeneration(\n", - " (model): BartModel(\n", - " (shared): Embedding(50264, 1024, padding_idx=1)\n", - " (encoder): BartEncoder(\n", - " (embed_tokens): Embedding(50264, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (decoder): BartDecoder(\n", - " (embed_tokens): Embedding(50264, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "# Transformer model being used\n", - "#MODEL_NAME = \"t5-large\"\n", - "MODEL_NAME = \"bart-large-cnn\"\n", - "# notebook parameters\n", - "# the cache data path during find tuning\n", - "CACHE_DIR = \"./bart_cache\" #TemporaryDirectory().name\n", - "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\n", - "summarizer.model" + "Preprocess the data." ] }, { "cell_type": "code", "execution_count": 9, - "metadata": {}, + "metadata": { + "scrolled": false + }, "outputs": [ { - "data": { - "text/plain": [ - "BartConfig {\n", - " \"_num_labels\": 3,\n", - " \"activation_dropout\": 0.0,\n", - " \"activation_function\": \"gelu\",\n", - " \"add_final_layer_norm\": false,\n", - " \"architectures\": null,\n", - " \"attention_dropout\": 0.0,\n", - " \"bad_words_ids\": null,\n", - " \"bos_token_id\": 0,\n", - " \"classif_dropout\": 0.0,\n", - " \"d_model\": 1024,\n", - " \"decoder_attention_heads\": 16,\n", - " \"decoder_ffn_dim\": 4096,\n", - " \"decoder_layerdrop\": 0.0,\n", - " \"decoder_layers\": 12,\n", - " \"decoder_start_token_id\": 2,\n", - " \"do_sample\": false,\n", - " \"dropout\": 0.1,\n", - " \"early_stopping\": true,\n", - " \"encoder_attention_heads\": 16,\n", - " \"encoder_ffn_dim\": 4096,\n", - " \"encoder_layerdrop\": 0.0,\n", - " \"encoder_layers\": 12,\n", - " \"eos_token_id\": 2,\n", - " \"finetuning_task\": null,\n", - " \"id2label\": {\n", - " \"0\": \"LABEL_0\",\n", - " \"1\": \"LABEL_1\",\n", - " \"2\": \"LABEL_2\"\n", - " },\n", - " \"init_std\": 0.02,\n", - " \"is_decoder\": false,\n", - " \"is_encoder_decoder\": true,\n", - " \"label2id\": {\n", - " \"LABEL_0\": 0,\n", - " \"LABEL_1\": 1,\n", - " \"LABEL_2\": 2\n", - " },\n", - " \"length_penalty\": 2.0,\n", - " \"max_length\": 142,\n", - " \"max_position_embeddings\": 1024,\n", - " \"min_length\": 56,\n", - " \"model_type\": \"bart\",\n", - " \"no_repeat_ngram_size\": 3,\n", - " \"normalize_before\": false,\n", - " \"num_beams\": 4,\n", - " \"num_hidden_layers\": 12,\n", - " \"num_return_sequences\": 1,\n", - " \"output_attentions\": false,\n", - " \"output_hidden_states\": false,\n", - " \"output_past\": true,\n", - " \"pad_token_id\": 1,\n", - " \"prefix\": \" \",\n", - " \"pruned_heads\": {},\n", - " \"repetition_penalty\": 1.0,\n", - " \"scale_embedding\": false,\n", - " \"task_specific_params\": {\n", - " \"summarization\": {\n", - " \"early_stopping\": true,\n", - " \"length_penalty\": 2.0,\n", - " \"max_length\": 142,\n", - " \"min_length\": 56,\n", - " \"no_repeat_ngram_size\": 3,\n", - " \"num_beams\": 4\n", - " }\n", - " },\n", - " \"temperature\": 1.0,\n", - " \"top_k\": 50,\n", - " \"top_p\": 1.0,\n", - " \"torchscript\": false,\n", - " \"use_bfloat16\": false,\n", - " \"vocab_size\": 50264\n", - "}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "CPU times: user 540 ms, sys: 0 ns, total: 540 ms\n", + "Wall time: 539 ms\n" + ] } ], "source": [ - "summarizer.config" + "%%time\n", + "abs_sum_train = summarizer.processor.preprocess(train_dataset)\n", + "# torch.save(abs_sum_train, os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))\n", + "# abs_sum_train = torch.load(os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 10, "metadata": {}, "outputs": [], "source": [ - "task_specific_params = summarizer.config.task_specific_params" + "abs_sum_test = summarizer.processor.preprocess(test_dataset)\n", + "# torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))\n", + "# abs_sum_test = torch.load(os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 11, "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "{'summarization': {'early_stopping': True,\n", - " 'length_penalty': 2.0,\n", - " 'max_length': 142,\n", - " 'min_length': 56,\n", - " 'no_repeat_ngram_size': 3,\n", - " 'num_beams': 4}}" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "100\n", + "100\n" + ] } ], "source": [ - "task_specific_params" + "print(len(abs_sum_train))\n", + "print(len(abs_sum_test))" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "tokens = bart.encode('Hello world!')\n", - "assert tokens.tolist() == [0, 31414, 232, 328, 2]\n", - "bart.decode(tokens) # 'Hello world!" + "#### Inspect Data" ] }, { "cell_type": "code", - "execution_count": 44, + "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "[0, 20920, 7619, 611, 328, 2]" + "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" ] }, - "execution_count": 44, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summarizer.tokenizer.encode('Hello frech!')" + "abs_sum_train[0].keys()" ] }, { "cell_type": "code", - "execution_count": 46, - "metadata": {}, + "execution_count": 13, + "metadata": { + "scrolled": false + }, "outputs": [ { "data": { "text/plain": [ - "' Hello frech!'" + "{'src': 'editor \\'s note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o\\'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the \" forgotten floor , \" where many mentally ill inmates are housed in miami before trial . miami , florida ( cnn ) -- the ninth floor of the miami-dade pretrial detention facility is dubbed the \" forgotten floor . \" here , inmates with the most severe mental illnesses are incarcerated until they \\'re ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually \" avoidable felonies . \" he says the arrests often result from confrontations with police . mentally ill people often wo n\\'t do what they \\'re told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they \\'re in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor \\' \" at first , it \\'s hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that \\'s kind of what they look like . they \\'re designed to keep the mentally ill patients from injuring themselves . that \\'s also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it \\'s not supposed to be warm and comforting , but the lights glare , the cells are tiny and it \\'s loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . \" i am the son of the president . you need to get me out of here ! \" one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it \\'s brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered \" lunatics \" and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he \\'s working to change this . starting in 2008 , many inmates who would otherwise have been brought to the \" forgotten floor \" will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it \\'s not the complete answer , but it \\'s a start . leifman says the best part is that it \\'s a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n',\n", + " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", + " 'tgt': ' mentally ill inmates in miami are housed on the \" forgotten floor \" judge steven leifman says most are there as a result of \" avoidable felonies \" while cnn tours facility , patient shouts : \" i am the son of the president \" leifman says the system is unjust and he \\'s fighting for change . \\n',\n", + " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", + " 'source_ids': tensor([21603, 10, 6005, ..., 3, 31, 7]),\n", + " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", + " 'target_ids': tensor([19367, 3, 1092, 16, 11171, 16, 1337, 3690, 33, 629,\n", + " 26, 30, 8, 96, 11821, 1501, 96, 5191, 3, 849,\n", + " 1926, 90, 99, 348, 845, 167, 33, 132, 38, 3,\n", + " 9, 741, 13, 96, 1792, 179, 3110, 106, 725, 96,\n", + " 298, 3, 75, 29, 29, 8108, 3064, 3, 6, 1868,\n", + " 14314, 7, 3, 10, 96, 3, 23, 183, 8, 520,\n", + " 13, 8, 2753, 96, 90, 99, 348, 845, 8, 358,\n", + " 19, 73, 4998, 11, 3, 88, 3, 31, 7, 6237,\n", + " 21, 483, 3, 5, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", + " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}" ] }, - "execution_count": 46, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "summarizer.tokenizer.decode([0, 20920, 7619, 611, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)" + "abs_sum_train[0]" ] }, { - "cell_type": "code", - "execution_count": 36, + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "'Hello world!'" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "summarizer.tokenizer.decode([0, 31414, 232, 328, 2], skip_special_tokens=True, clean_up_tokenization_spaces=True)" + "## Fine tune model\n", + "To start model fine-tuning, we need to specify the paramters as follows." ] }, { "cell_type": "code", - "execution_count": 18, - "metadata": {}, + "execution_count": 14, + "metadata": { + "tags": [ + "parameters" + ] + }, "outputs": [], "source": [ - "summarizer.config.update(task_specific_params.get(\"summarization\", {}))" + "BATCH_SIZE_PER_GPU = 4\n", + "GRADIENT_ACCUMULATION_STEPS = 1\n", + "MAX_POS_LENGTH = 512\n", + "\n", + "# GPU used for training\n", + "NUM_GPUS = torch.cuda.device_count()\n", + "\n", + "\n", + "# Learning rate\n", + "LEARNING_RATE=3e-5\n", + "MAX_GRAD_NORM=0.1\n", + "\n", + "# How often the statistics reports show up in training, unit is step.\n", + "REPORT_EVERY=100\n", + "SAVE_EVERY=1000\n", + "\n", + "# total number of steps for training\n", + "MAX_STEPS=1000\n", + "# number of steps for warm up\n", + "WARMUP_STEPS=5e2\n", + " \n", + "if not QUICK_RUN:\n", + " MAX_STEPS=2e4\n", + " WARMUP_STEPS=5e3\n", + " " ] }, { "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "BartConfig {\n", - " \"_num_labels\": 3,\n", - " \"activation_dropout\": 0.0,\n", - " \"activation_function\": \"gelu\",\n", - " \"add_final_layer_norm\": false,\n", - " \"architectures\": [\n", - " \"BartModel\",\n", - " \"BartForMaskedLM\",\n", - " \"BartForSequenceClassification\"\n", - " ],\n", - " \"attention_dropout\": 0.0,\n", - " \"bad_words_ids\": null,\n", - " \"bos_token_id\": 0,\n", - " \"classif_dropout\": 0.0,\n", - " \"d_model\": 1024,\n", - " \"decoder_attention_heads\": 16,\n", - " \"decoder_ffn_dim\": 4096,\n", - " \"decoder_layerdrop\": 0.0,\n", - " \"decoder_layers\": 12,\n", - " \"decoder_start_token_id\": 2,\n", - " \"do_sample\": false,\n", - " \"dropout\": 0.1,\n", - " \"early_stopping\": true,\n", - " \"encoder_attention_heads\": 16,\n", - " \"encoder_ffn_dim\": 4096,\n", - " \"encoder_layerdrop\": 0.0,\n", - " \"encoder_layers\": 12,\n", - " \"eos_token_id\": 2,\n", - " \"finetuning_task\": null,\n", - " \"id2label\": {\n", - " \"0\": \"LABEL_0\",\n", - " \"1\": \"LABEL_1\",\n", - " \"2\": \"LABEL_2\"\n", - " },\n", - " \"init_std\": 0.02,\n", - " \"is_decoder\": false,\n", - " \"is_encoder_decoder\": true,\n", - " \"label2id\": {\n", - " \"LABEL_0\": 0,\n", - " \"LABEL_1\": 1,\n", - " \"LABEL_2\": 2\n", - " },\n", - " \"length_penalty\": 2.0,\n", - " \"max_length\": 142,\n", - " \"max_position_embeddings\": 1024,\n", - " \"min_length\": 56,\n", - " \"model_type\": \"bart\",\n", - " \"no_repeat_ngram_size\": 3,\n", - " \"normalize_before\": false,\n", - " \"num_beams\": 4,\n", - " \"num_hidden_layers\": 12,\n", - " \"num_return_sequences\": 1,\n", - " \"output_attentions\": false,\n", - " \"output_hidden_states\": false,\n", - " \"output_past\": false,\n", - " \"pad_token_id\": 1,\n", - " \"prefix\": \" \",\n", - " \"pruned_heads\": {},\n", - " \"repetition_penalty\": 1.0,\n", - " \"scale_embedding\": false,\n", - " \"task_specific_params\": {\n", - " \"summarization\": {\n", - " \"early_stopping\": true,\n", - " \"length_penalty\": 2.0,\n", - " \"max_length\": 142,\n", - " \"min_length\": 56,\n", - " \"no_repeat_ngram_size\": 3,\n", - " \"num_beams\": 4\n", - " }\n", - " },\n", - " \"temperature\": 1.0,\n", - " \"top_k\": 50,\n", - " \"top_p\": 1.0,\n", - " \"torchscript\": false,\n", - " \"use_bfloat16\": false,\n", - " \"vocab_size\": 50265\n", - "}" - ] - }, - "execution_count": 23, - "metadata": {}, - "output_type": "execute_result" - } - ], + "execution_count": null, + "metadata": { + "scrolled": true + }, + "outputs": [], "source": [ - "summarizer.config" + "\n", + "summarizer.fit(\n", + " abs_sum_train,\n", + " num_gpus=NUM_GPUS,\n", + " batch_size=BATCH_SIZE_PER_GPU*NUM_GPUS,\n", + " gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,\n", + " max_steps=MAX_STEPS,\n", + " learning_rate=LEARNING_RATE,\n", + " max_grad_norm=MAX_GRAD_NORM,\n", + " warmup_steps=WARMUP_STEPS,\n", + " verbose=True,\n", + " report_every=REPORT_EVERY,\n", + " )\n" ] }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'early_stopping': True,\n", - " 'length_penalty': 2.0,\n", - " 'max_length': 142,\n", - " 'min_length': 56,\n", - " 'no_repeat_ngram_size': 3,\n", - " 'num_beams': 4}" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "task_specific_params.get(\"summarization\", {})" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "summarizer.config.update({\"vocab_size\": 50264})" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Data Preprocessing\n", - "\n", - "The dataset we used for this notebook is CNN/DM dataset which contains the documents and accompanying questions from the news articles of CNN and Daily mail. The highlights in each article are used as summary. The dataset consits of ~289K training examples, ~11K valiation examples and ~11K test examples. The code in following cell will download the CNN/DM dataset listed at https://github.com/harvardnlp/sent-summary/.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "metadata": { - "tags": [ - "parameters" - ] - }, - "outputs": [], - "source": [ - "# the data path used to save the downloaded data file\n", - "DATA_PATH = \"./bartt5_cnndm\" #TemporaryDirectory().name\n", - "# The number of lines at the head of data file used for preprocessing. -1 means all the lines.\n", - "TOP_N = 1000\n", - "if not QUICK_RUN:\n", - " TOP_N = -1" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\"" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "test_dataset[0]['tgt_txt']" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "11490" - ] - }, - "execution_count": 10, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "len(test_dataset)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "Preprocess the data." - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": { - "scrolled": false - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CPU times: user 5 µs, sys: 2 µs, total: 7 µs\n", - "Wall time: 20.3 µs\n" - ] - } - ], - "source": [ - "%time\n", - "abs_sum_train = summarizer.processor.preprocess(train_dataset)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "# torch.save(abs_sum_train, os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))" - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "metadata": {}, - "outputs": [], - "source": [ - "abs_sum_test = summarizer.processor.preprocess(test_dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": 58, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "{'input_ids': tensor([[ 0, 4401, 1090, ..., 167, 1081, 2],\n", - " [ 0, 36, 740, ..., 1, 1, 1]], device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", - " [1, 1, 1, ..., 0, 0, 0]], device='cuda:0'), 'decoder_input_ids': tensor([[ 0, 4401, 1090, 4061, 5644, 161, 22, 98, 444, 117,\n", - " 3424, 58, 341, 11, 5, 2058, 803, 22, 1135, 433,\n", - " 690, 479, 1437, 1437, 4225, 23, 741, 9683, 8, 2242,\n", - " 354, 914, 32, 22, 182, 3230, 22, 5, 569, 7200,\n", - " 16, 588, 2156, 41, 4474, 161, 479, 1437, 1437, 8,\n", - " 241, 281, 784, 1792, 4494, 56, 3978, 39, 784, 2951,\n", - " 212, 1253, 102, 1058, 334, 9, 41, 3238, 9, 3814,\n", - " 6943, 2156, 5195, 161, 479],\n", - " [ 0, 6332, 2029, 5, 41591, 438, 10542, 81, 1697, 3474,\n", - " 2021, 11, 8750, 990, 28307, 13560, 187, 94, 1236, 4438,\n", - " 479, 1437, 1437, 16, 37715, 8, 5, 10409, 982, 4340,\n", - " 5, 517, 2156, 61, 115, 490, 5, 1883, 7, 997,\n", - " 3474, 4941, 136, 16, 37715, 354, 479, 2, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1]], device='cuda:0'), 'lm_labels': tensor([[ 4401, 1090, 4061, 5644, 161, 22, 98, 444, 117, 3424,\n", - " 58, 341, 11, 5, 2058, 803, 22, 1135, 433, 690,\n", - " 479, 1437, 1437, 4225, 23, 741, 9683, 8, 2242, 354,\n", - " 914, 32, 22, 182, 3230, 22, 5, 569, 7200, 16,\n", - " 588, 2156, 41, 4474, 161, 479, 1437, 1437, 8, 241,\n", - " 281, 784, 1792, 4494, 56, 3978, 39, 784, 2951, 212,\n", - " 1253, 102, 1058, 334, 9, 41, 3238, 9, 3814, 6943,\n", - " 2156, 5195, 161, 479, 2],\n", - " [ 6332, 2029, 5, 41591, 438, 10542, 81, 1697, 3474, 2021,\n", - " 11, 8750, 990, 28307, 13560, 187, 94, 1236, 4438, 479,\n", - " 1437, 1437, 16, 37715, 8, 5, 10409, 982, 4340, 5,\n", - " 517, 2156, 61, 115, 490, 5, 1883, 7, 997, 3474,\n", - " 4941, 136, 16, 37715, 354, 479, 2, -100, -100, -100,\n", - " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", - " -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,\n", - " -100, -100, -100, -100, -100]], device='cuda:0')}\n" - ] - } - ], - "source": [ - "a = summarizer.processor.collate_fn(abs_sum_test[0:2], \"cuda:0\", True)\n", - "c = summarizer.processor.get_inputs(a, \"cuda:0\", MODEL_NAME, summarizer.tokenizer, True)\n", - "print(c)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "abs_sum_train = torch.load( os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - "Generating summary: 0%| | 0/1 [00:00 marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\", 'tgt_txt': \" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \\n\", 'source_ids': tensor([ 0, 4401, 1090, ..., 604, 1725, 2]), 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]), 'target_ids': tensor([ 0, 28696, 90, 15698, 4401, 1090, 4061, 5644, 161, 45518,\n", - " 98, 444, 117, 3424, 58, 341, 11, 5, 2058, 803,\n", - " 12801, 1135, 433, 690, 479, 49703, 90, 15698, 28696, 90,\n", - " 15698, 4225, 23, 741, 9683, 8, 2242, 354, 914, 32,\n", - " 45518, 182, 3230, 12801, 5, 569, 7200, 16, 588, 2156,\n", - " 41, 4474, 161, 479, 49703, 90, 15698, 28696, 90, 15698,\n", - " 8, 241, 281, 784, 1792, 4494, 56, 3978, 39, 784,\n", - " 2951, 212, 1253, 102, 1058, 334, 9, 41, 3238, 9,\n", - " 3814, 6943, 2156, 5195, 161, 479, 49703, 90, 15698, 2,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", - " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])}\n" - ] - } - ], - "source": [ - "print(abs_sum_test[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "nnnnennennsnnnnnnnnsnnennennennnnnnennsnnnnnnnnnnnnsnnnnnnnnnnnnnnnnennennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnennnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnnn\n" - ] - } - ], - "source": [ - "print(prediction[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "BartForConditionalGeneration(\n", - " (model): BartModel(\n", - " (shared): Embedding(50265, 1024, padding_idx=1)\n", - " (encoder): BartEncoder(\n", - " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): EncoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (decoder): BartDecoder(\n", - " (embed_tokens): Embedding(50265, 1024, padding_idx=1)\n", - " (embed_positions): LearnedPositionalEmbedding(1026, 1024, padding_idx=1)\n", - " (layers): ModuleList(\n", - " (0): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (1): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (2): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (3): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (4): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (5): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (6): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (7): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (8): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (9): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (10): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " (11): DecoderLayer(\n", - " (self_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (self_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (encoder_attn): SelfAttention(\n", - " (k_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (v_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (q_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " (out_proj): Linear(in_features=1024, out_features=1024, bias=True)\n", - " )\n", - " (encoder_attn_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " (fc1): Linear(in_features=1024, out_features=4096, bias=True)\n", - " (fc2): Linear(in_features=4096, out_features=1024, bias=True)\n", - " (final_layer_norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - " (layernorm_embedding): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)\n", - " )\n", - " )\n", - ")" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "summarizer.model.module" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [ - "\"\"\"\n", - "# save and load preprocessed data\n", - "save_path = DATA_PATH\n", - "torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", - "torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))\n", - "\n", - "\"\"\"\n", - "save_path = DATA_PATH\n", - "#abs_sum_train = torch.load(os.path.join(save_path, \"train_full.pt\"))\n", - "abs_sum_test = torch.load(os.path.join(save_path, \"test_full.pt\"))" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "287227\n", - "11490\n" - ] - } - ], - "source": [ - "print(len(abs_sum_train))\n", - "print(len(abs_sum_test))" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [], - "source": [ - "#save_path = os.path.join(DATA_PATH, \"processed\")\n", - "#torch.save(abs_sum_train, os.path.join(save_path, \"train_full.pt\"))\n", - "#torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_full.pt\"))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "#### Inspect Data" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "abs_sum_train[0].keys()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "scrolled": false - }, - "outputs": [], - "source": [ - "abs_sum_train[0]" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Model training\n", - "To start model training, we need to create a instance of ExtractiveSummarizer.\n", - "#### Choose the transformer model.\n", - "Currently ExtractiveSummarizer support two models:\n", - "- distilbert-base-uncase, \n", - "- bert-base-uncase\n", - "\n", - "Potentionally, roberta-based model and xlnet can be supported but needs to be tested.\n", - "#### Choose the encoder algorithm.\n", - "There are four options:\n", - "- baseline: it used a smaller transformer model to replace the bert model and with transformer summarization layer\n", - "- classifier: it uses pretrained BERT and fine-tune BERT with **simple logistic classification** summarization layer\n", - "- transformer: it uses pretrained BERT and fine-tune BERT with **transformer** summarization layer\n", - "- RNN: it uses pretrained BERT and fine-tune BERT with **LSTM** summarization layer" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "tags": [ - "parameters" - ] - }, - "outputs": [], - "source": [ - "BATCH_SIZE = 8 # batch size, unit is the number of samples\n", - "MAX_POS_LENGTH = 512\n", - "\n", - " \n", - "\n", - "\n", - "# GPU used for training\n", - "NUM_GPUS = torch.cuda.device_count()\n", - "\n", - "\n", - "# Learning rate\n", - "LEARNING_RATE=3e-5\n", - "\n", - "# How often the statistics reports show up in training, unit is step.\n", - "REPORT_EVERY=100\n", - "\n", - "# total number of steps for training\n", - "MAX_STEPS=1e2\n", - "# number of steps for warm up\n", - "WARMUP_STEPS=5e2\n", - " \n", - "if not QUICK_RUN:\n", - " MAX_STEPS=5e3\n", - " WARMUP_STEPS=5e2\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": { - "scrolled": true - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Iteration: 0%| | 0/35904 [00:00\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mwarmup_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mWARMUP_STEPS\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m \u001b[0mreport_every\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mREPORT_EVERY\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 13\u001b[0m )\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/abstractive_summarization_bartt5.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, train_dataset, num_gpus, gpu_ids, batch_size, local_rank, max_steps, warmup_steps, learning_rate, weight_decay, adam_epsilon, max_grad_norm, gradient_accumulation_steps, report_every, save_every, verbose, seed, fp16, fp16_opt_level, world_size, rank, validation_function, checkpoint, **kwargs)\u001b[0m\n\u001b[1;32m 447\u001b[0m \u001b[0mfp16\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mfp16\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 448\u001b[0m \u001b[0mamp\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mamp\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 449\u001b[0;31m \u001b[0mvalidation_function\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 450\u001b[0m )\n\u001b[1;32m 451\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/nlp-recipes/utils_nlp/models/transformers/common.py\u001b[0m in \u001b[0;36mfine_tune\u001b[0;34m(self, train_dataloader, get_inputs, device, num_gpus, max_steps, global_step, max_grad_norm, gradient_accumulation_steps, optimizer, scheduler, fp16, amp, local_rank, verbose, seed, report_every, save_every, clip_grad_norm, validation_function)\u001b[0m\n\u001b[1;32m 237\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mstep\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mepoch_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 238\u001b[0m \u001b[0minputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mget_inputs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel_name\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 239\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 240\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 241\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 530\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 531\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 532\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 533\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 534\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 150\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 151\u001b[0m \u001b[0mreplicas\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 152\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 153\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 74\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mthread\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthreads\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 75\u001b[0;31m \u001b[0mthread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstart\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 76\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mthread\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mthreads\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0mthread\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjoin\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mstart\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 849\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0m_limbo\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 850\u001b[0m \u001b[0;32mraise\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 851\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_started\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 852\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 853\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mrun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_flag\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 550\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 551\u001b[0;31m \u001b[0msignaled\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_cond\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 552\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0msignaled\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 553\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/dadendev/anaconda3/envs/nlp_gpu/lib/python3.6/threading.py\u001b[0m in \u001b[0;36mwait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 293\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 294\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mtimeout\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 295\u001b[0;31m \u001b[0mwaiter\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0macquire\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 296\u001b[0m \u001b[0mgotit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 297\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], - "source": [ - "#\"\"\"\n", - "\n", - "summarizer.fit(\n", - " abs_sum_train,\n", - " num_gpus=NUM_GPUS,\n", - " batch_size=BATCH_SIZE,\n", - " gradient_accumulation_steps=1,\n", - " max_steps=MAX_STEPS,\n", - " learning_rate=LEARNING_RATE,\n", - " warmup_steps=WARMUP_STEPS,\n", - " verbose=True,\n", - " report_every=REPORT_EVERY,\n", - " )\n", - "\n", - "#\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "\"\"\"\n", - "summarizer.save_model(\n", - " os.path.join(\n", - " CACHE_DIR,\n", - " \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\n", - " MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\n", - " ),\n", - " )\n", - ")\n", - "\"\"\"\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# for loading a previous saved model\n", + "# save a finetuned model and load a previous saved model\n", "\"\"\"\n", "import torch\n", "model_path = os.path.join(\n", " CACHE_DIR,\n", - " \"extsum_modelname_{0}_usepreprocess{1}_steps_{2}.pt\".format(\n", - " MODEL_NAME, USE_PREPROCSSED_DATA, MAX_STEPS\n", + " \"abssum_modelname_{0}_steps_{1}.pt\".format(\n", + " MODEL_NAME, MAX_STEPS\n", " ))\n", - "summarizer = ExtractiveSummarizer(processor, MODEL_NAME, ENCODER, MAX_POS_LENGTH, CACHE_DIR)\n", - "summarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\"))\n", + "summarizer.save_model(global_step=MAX_STEPS, full_name=model_path)\n", + "\n", + "summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\n", + "summarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\")['model'])\n", "\"\"\"" ] }, @@ -3186,7 +541,7 @@ "source": [ "### Model Evaluation\n", "\n", - "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization." + "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization. " ] }, { @@ -3194,15 +549,6 @@ "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "abs_sum_test[0].keys()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], "source": [ "source = []\n", "target = []\n", @@ -3211,258 +557,28 @@ " target.append(i['tgt'].replace(\"\",\"\").replace(\"\", \"\").replace(\"\\n\", \"\")) " ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "\" marseille prosecutor says `` so far no videos were used in the crash investigation '' despite media reports . journalists at bild and paris match are `` very confident '' the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . \"" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "target[0]" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\r", - "Generating summary: 0%| | 0/64 [00:00","").replace("", "").replace("\n", "")) - output = summarizer.model(**c) + output = summarizer.model(**c) generated_summaries = summarizer.predict( shortened_dataset, num_gpus=num_gpus, batch_size=TOP_N ) - #assert len(generated_summaries) == len(reference_summaries) - #print("###################") print("validation loss is {}".format(output[0])) print("prediction is {}".format(generated_summaries[0])) - #print("reference is {}".format(reference_summaries[0])) - #rouge_score = compute_rouge_python( - # cand=generated_summaries, ref=reference_summaries - #) - #return "rouge score: {}".format(rouge_score) class SummarizationProcessor: + """ Class for preprocessing abstractive summarization data for BART/T5 models. + + Args: + tokenizer(AutoTokenizer): tokenizer for the model used for preprocessing. + config(AutoConfig): config for the model used for preprocessing. + max_source_length (int, optional): Max number of tokens that be used + as input. Defaults to 1024. + max_target_length (int, optional): Max number of tokens that be used + as in target. Defaults to 140. + + """ def __init__( - self, - tokenizer, - config, - #model_name, - #cache_dir="./", - max_source_length=1024, - max_target_length=140, + self, tokenizer, config, max_source_length=1024, max_target_length=140, ): - #super().__init__() - # self.tokenizer = AutoTokenizer.from_pretrained( - # model_name, - #self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, - # cache_dir=cache_dir, - #) - - #TOKENIZER_CLASS[model_name].from_pretrained(model_name, cache_dir=cache_dir) # b - #config = AutoConfig.from_pretrained( - # model_name, - # #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, - #**({"num_labels": num_labels} if num_labels is not None else {}), - # cache_dir=cache_dir, - #**config_kwargs, - #) - # if model_name.startswith("t5"): - # update config with summarization specific params - # task_specific_params = config.task_specific_params - # if task_specific_params is not None: - # config.update(task_specific_params.get("summarization", {})) + self.tokenizer = tokenizer self.config = config @@ -223,114 +214,128 @@ def __init__( @staticmethod def trim_seq2seq_batch(batch, pad_token_id): + y = trim_batch(batch["target_ids"], pad_token_id) - source_ids, source_mask = trim_batch(batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"]) + source_ids, source_mask = trim_batch( + batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"] + ) return source_ids, source_mask, y def preprocess(self, input_data_list): result = [] for i in input_data_list: - result.append(encode_example(i, tokenizer=self.tokenizer, prefix=self.prefix, - max_source_length=self.max_source_length, max_target_length=self.max_target_length )) + result.append( + encode_example( + i, + tokenizer=self.tokenizer, + prefix=self.prefix, + max_source_length=self.max_source_length, + max_target_length=self.max_target_length, + ) + ) return result - + @staticmethod def get_inputs(batch, device, model_name, tokenizer=None, train_mode=True): pad_token_id = tokenizer.pad_token_id if not train_mode: source_ids, source_mask = batch["source_ids"], batch["source_mask"] else: - source_ids, source_mask, y = SummarizationProcessor.trim_seq2seq_batch(batch, pad_token_id) + source_ids, source_mask, y = SummarizationProcessor.trim_seq2seq_batch( + batch, pad_token_id + ) y_ids = y[:, :-1].contiguous() lm_labels = y[:, 1:].clone() lm_labels[y[:, 1:] == pad_token_id] = -100 - #outputs = self(source_ids, attention_mask=source_mask, decoder_input_ids=y_ids, lm_labels=lm_labels,) if train_mode: - return {"input_ids": source_ids, - "attention_mask": source_mask, - "decoder_input_ids": y_ids, - "lm_labels": lm_labels + return { + "input_ids": source_ids, + "attention_mask": source_mask, + "decoder_input_ids": y_ids, + "lm_labels": lm_labels, } else: - return {"input_ids": source_ids, - "attention_mask": source_mask, + return { + "input_ids": source_ids, + "attention_mask": source_mask, } def collate_fn(self, batch, device, train_mode=False): input_ids = torch.stack([x["source_ids"] for x in batch]) masks = torch.stack([x["source_mask"] for x in batch]) pad_token_id = self.tokenizer.pad_token_id - source_ids, source_mask = trim_batch(input_ids, pad_token_id, attention_mask=masks) + source_ids, source_mask = trim_batch( + input_ids, pad_token_id, attention_mask=masks + ) if train_mode: target_ids = torch.stack([x["target_ids"] for x in batch]) y = trim_batch(target_ids, pad_token_id) - return {"source_ids": source_ids.to(device), "source_mask": source_mask.to(device), "target_ids": y.to(device)} + return { + "source_ids": source_ids.to(device), + "source_mask": source_mask.to(device), + "target_ids": y.to(device), + } else: - return {"source_ids": source_ids.to(device), "source_mask": source_mask.to(device)} + return { + "source_ids": source_ids.to(device), + "source_mask": source_mask.to(device), + } class AbstractiveSummarizer(Transformer): """class which performs abstractive summarization fine tuning and - prediction based on BertSumAbs model """ + prediction based on BART and T5 model """ def __init__( self, - #processor, - model_name="bart-large", + # processor, + model_name="t5-small", cache_dir=".", max_source_length=1024, - max_target_length=240 + max_target_length=240, ): """Initialize an object of BertSumAbs. Args: model_name (str, optional:) Name of the pretrained model which is used . - `S2SAbsSumProcessor.list_supported_models()` to see all supported model names.. Defaults to "bart-large". - cache_dir (str, optional): Directory to cache the tokenizer. Defaults to ".". - max_pos_length (int, optional): maximum postional embedding length for the - input. Defaults to 768. + `AbstractiveSummarizer.list_supported_models()` to see all supported + model names. Defaults to "t5-small". + cache_dir (str, optional): Directory to cache the model. Defaults to ".". + max_source_length (int, optional): maximum source length for the + input. Defaults to 1024. + max_target_length (int, optional): maximum target length for the + training input. Defaults to 240. + """ if model_name not in self.list_supported_models(): raise ValueError( "Model name {} is not supported by AbstractiveSummarizer. " - "Call 'AbstractiveSummarizer.list_supported_models()' to get all supported model " + "Call 'AbstractiveSummarizer.list_supported_models()' to" + "get all supported model " "names.".format(value) ) - self.config = AutoConfig.from_pretrained( - model_name, - #self.hparams.config_name if self.hparams.config_name else self.hparams.model_name_or_path, - #**({"num_labels": num_labels} if num_labels is not None else {}), - cache_dir=cache_dir, - #**config_kwargs, - ) + self.config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir,) + self.config.output_past = True # to enable num_beams greater than 1 task_specific_params = self.config.task_specific_params if task_specific_params is not None: self.config.update(task_specific_params.get("summarization", {})) - #self.config.update({"vocab_size": 50264}) self.config.update({"max_length": max_target_length}) self.config.update({"attention_dropout": 0.1}) - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, - #self.hparams.tokenizer_name if self.hparams.tokenizer_name else self.hparams.model_name_or_path, - cache_dir=cache_dir, + self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir,) + + self.processor = SummarizationProcessor( + self.tokenizer, self.config, max_source_length, max_target_length ) - self.processor = SummarizationProcessor(self.tokenizer, self.config, max_source_length, max_target_length) - self._model_name = model_name - # self.model = MODEL_CLASS[model_name].from_pretrained( - self.model = MODEL_MODES["language-modeling"].from_pretrained( - self.model_name, - #from_tf=bool(".ckpt" in self.hparams.model_name_or_path), - config=self.config, - cache_dir=cache_dir, + self.model = MODEL_MODES["language-modeling"].from_pretrained( + self.model_name, config=self.config, cache_dir=cache_dir, ) - self.model_class = AutoModelWithLMHead #MODEL_CLASS[model_name] self.cache_dir = cache_dir self.max_source_length = max_source_length self.max_target_length = max_target_length @@ -351,7 +356,7 @@ def fit( batch_size=4, local_rank=-1, max_steps=5e4, - warmup_steps=20000, + warmup_steps=2e3, learning_rate=0.002, weight_decay=0.01, adam_epsilon=1e-8, @@ -375,42 +380,32 @@ def fit( Args: train_dataset (SummarizationDataset): Training dataset. num_gpus (int, optional): The number of GPUs to use. If None, all - available GPUs will be used. If set to 0 or GPUs are not available, - CPU device will be used. Defaults to None. + available GPUs will be used. If set to 0 or GPUs are + not available, CPU device will be used. Defaults to None. gpu_ids (list): List of GPU IDs to be used. If set to None, the first num_gpus GPUs will be used. Defaults to None. - batch_size (int, optional): Maximum number of tokens in each batch. + batch_size (int, optional): Maximum number of examples in each batch. local_rank (int, optional): Local_rank for distributed training on GPUs. Local rank means the ranking of the current GPU device on the current node. Defaults to -1, which means non-distributed training. - max_steps (int, optional): Maximum number of training steps. Defaults to 5e5. - warmup_steps_bert (int, optional): Number of steps taken to increase - learning rate from 0 to `learning_rate` for tuning the BERT encoder. - Defaults to 2e4. - warmup_steps_dec (int, optional): Number of steps taken to increase - learning rate from 0 to `learning_rate` for tuning the decoder. - Defaults to 1e4. - learning_rate_bert (float, optional): Learning rate of the optimizer - for the encoder. Defaults to 0.002. - learning_rate_dec (float, optional): Learning rate of the optimizer - for the decoder. Defaults to 0.2. - optimization_method (string, optional): Optimization method used in fine - tuning. Defaults to "adam". - max_grad_norm (float, optional): Maximum gradient norm for gradient clipping. - Defaults to 0. - beta1 (float, optional): The exponential decay rate for the first moment - estimates. Defaults to 0.9. - beta2 (float, optional): The exponential decay rate for the second-moment - estimates. This value should be set close to 1.0 on problems with - a sparse gradient. Defaults to 0.99. - decay_method (string, optional): learning rate decrease method. - Default to 'noam'. + max_steps (int, optional): Maximum number of training steps. + Defaults to 5e4. + warmup_steps (int, optional): Number of steps taken to increase + learning rate from 0 to `learning_rate`. Defaults to 2e3. + learning_rate (float, optional): Learning rate of the optimizer. + Defaults to 0.002. + weight_decay (float, optional): Weight decay to apply after each parameter + update. Defaults to 0.01. + adam_epsilon (float, optional): Epsilon of the AdamW optimizer. + Defaults to 1e-8. + max_grad_norm (float, optional): Maximum gradient norm for + gradient clipping. Defaults to 0. gradient_accumulation_steps (int, optional): Number of batches to accumulate gradients on between each model parameter update. Defaults to 1. report_every (int, optional): The interval by steps to print out the training log. Defaults to 10. - save_every (int, optional): The interval by steps to save the finetuned + save_every (int, optional): The interval by steps to save the finetuned model. Defaults to 100. verbose (bool, optional): Whether to print out the training log. Defaults to True. @@ -439,8 +434,8 @@ def fit( checkpoint_state_dict = None if checkpoint: # checkpoint should have "model", "optimizer", "amp" - checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") - + checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") + # init optimizer device, num_gpus, amp = self.prepare_model_and_optimizer( num_gpus=num_gpus, @@ -455,9 +450,13 @@ def fit( ) self.amp = amp - + global_step = 0 - if checkpoint_state_dict and "global_step" in checkpoint_state_dict and checkpoint_state_dict["global_step"]: + if ( + checkpoint_state_dict + and "global_step" in checkpoint_state_dict + and checkpoint_state_dict["global_step"] + ): global_step = checkpoint_state_dict["global_step"] / world_size print("global_step is {}".format(global_step)) @@ -476,12 +475,8 @@ def fit( train_dataset, num_replicas=world_size, rank=rank ) - def collate_fn(data): - return self.processor.collate_fn( - data, device, train_mode=True - ) - + return self.processor.collate_fn(data, device, train_mode=True) train_dataloader = DataLoader( train_dataset, @@ -496,8 +491,10 @@ def collate_fn(data): max_steps=max_steps, gradient_accumulation_steps=gradient_accumulation_steps, ) - import functools - get_inputs = functools.partial(self.processor.get_inputs, tokenizer=self.processor.tokenizer) + + get_inputs = functools.partial( + self.processor.get_inputs, tokenizer=self.processor.tokenizer + ) super().fine_tune( train_dataloader=train_dataloader, get_inputs=get_inputs, @@ -522,7 +519,7 @@ def collate_fn(data): self.model.cpu() torch.cuda.empty_cache() - #self.save_model(max_steps) + self.save_model(global_step=max_steps) def predict( self, @@ -531,15 +528,16 @@ def predict( gpu_ids=None, local_rank=-1, batch_size=16, - length_penalty=0.95, - beam_size=4, min_length=56, max_length=140, + num_beams=4, + length_penalty=2.0, no_repeat_ngram_size=3, early_stopping=True, fp16=False, verbose=True, checkpoint=None, + **predictor_kwargs ): """ Predict the summarization for the input data iterator. @@ -556,54 +554,56 @@ def predict( inferencing. Defaults to -1, which means non-distributed inferencing. batch_size (int, optional): The number of test examples in each batch. Defaults to 16. - alpha (float, optional): Length penalty. Defaults to 0.6. - beam_size (int, optional): Beam size of beam search. Defaults to 5. min_length (int, optional): Minimum number of tokens in the output sequence. - Defaults to 15. + Defaults to 140. max_length (int, optional): Maximum number of tokens in output sequence. Defaults to 150. + num_beams (int, optional): Beam size for beam search. Defaults to 4. + length_penalty (float, optional): Exponential penalty to the length. + Defaults to 2.0. + no_repeat_ngram_size (int, optional): If set to int >0, all ngrams of size + `no_repeat_ngram_size` can only occur once in the generated summary. + Defaults to 3. + early_stopping (bool, optional): If set to `True` beam search is stopped + when at least `num_beams` sentences finished per batch. Defautls to True. fp16 (bool, optional): Whether to use half-precision model for prediction. Defaults to False. verbose (bool, optional): Whether to print out the training log. Defaults to True. + checkpoint (str, optional): + predictor_kwargs (dict, optional): Additional kwargs that will be forwarded + to `Predictor`. Please consult the arguments in function + `PreTrainedModel::generate`. Returns: List of strings which are the summaries """ - + device, num_gpus = get_device( num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank ) model = move_model_to_device(self.model, device) - + checkpoint_state_dict = None if checkpoint: # checkpoint should have "model", "optimizer", "amp" - checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") + checkpoint_state_dict = torch.load(checkpoint, map_location="cpu") model.load_state_dict(checkpoint_state_dict["model"]) - model.eval() model = parallelize_model( - model, - device, - num_gpus=num_gpus, - gpu_ids=gpu_ids, - local_rank=local_rank, + model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank, ) - - + if fp16: model = model.half() test_sampler = SequentialSampler(test_dataset) def collate_fn(data): - return self.processor.collate_fn( - data, device, train_mode=False - ) + return self.processor.collate_fn(data, device, train_mode=False) test_dataloader = DataLoader( test_dataset, @@ -613,36 +613,45 @@ def collate_fn(data): ) print("dataset length is {}".format(len(test_dataset))) - predictor = Predictor(model, min_length, max_length) + + predictor = Predictor( + model, + min_length, + max_length, + num_beams=num_beams, + length_penalty=length_penalty, + no_repeat_ngram_size=no_repeat_ngram_size, + early_stopping=early_stopping, + **predictor_kwargs + ) + # move model to devices def this_model_move_callback(model, device): model = move_model_to_device(model, device) return parallelize_model( model, device, num_gpus=num_gpus, gpu_ids=gpu_ids, local_rank=local_rank ) + predictor = this_model_move_callback(predictor, device) generated_summaries = [] for batch in tqdm( - test_dataloader, desc="Generating summary", disable=True #not verbose + test_dataloader, desc="Generating summary", disable=True # not verbose ): - #if self.model_name.startswith("t5"): - # batch = [self.model.config.prefix + text for text in batch] - #dct = self.tokenizer.batch_encode_plus(batch, max_length=1024, return_tensors="pt", pad_to_max_length=True) - # print(batch) - input_ids, masks = trim_batch(batch["source_ids"], self.tokenizer.pad_token_id, attention_mask=batch["source_mask"]) - summaries = predictor(input_ids, masks) - """ - summaries = self.model.module.generate( - input_ids=batch["source_ids"], + input_ids, masks = trim_batch( + batch["source_ids"], + self.tokenizer.pad_token_id, attention_mask=batch["source_mask"], - min_length=min_length, - max_length=max_length ) - """ - decoded_summaries = [self.tokenizer.decode(g, skip_special_tokens=True, clean_up_tokenization_spaces=True) for g in summaries] - + summaries = predictor(input_ids, masks) + decoded_summaries = [ + self.tokenizer.decode( + g, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + for g in summaries + ] + generated_summaries.extend(decoded_summaries) # release GPU memories @@ -671,7 +680,9 @@ def save_model(self, global_step=None, full_name=None): output_model_dir = os.path.join(self.cache_dir, "fine_tuned") os.makedirs(self.cache_dir, exist_ok=True) os.makedirs(output_model_dir, exist_ok=True) - full_name = os.path.join(output_model_dir, "abssum_{}.pt".format(self.model_name)) + full_name = os.path.join( + output_model_dir, "abssum_{}.pt".format(self.model_name) + ) else: path, filename = os.path.split(full_name) print(path) From ed91324801c2ba43256a4a6e43eb27e073fb829f Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Wed, 20 May 2020 20:03:21 +0000 Subject: [PATCH 10/14] add unit test --- .../test_abstractive_summarization_bartt5.py | 160 ++++++++++++++++++ 1 file changed, 160 insertions(+) create mode 100644 tests/unit/test_abstractive_summarization_bartt5.py diff --git a/tests/unit/test_abstractive_summarization_bartt5.py b/tests/unit/test_abstractive_summarization_bartt5.py new file mode 100644 index 000000000..bc0c4b08d --- /dev/null +++ b/tests/unit/test_abstractive_summarization_bartt5.py @@ -0,0 +1,160 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. + +import argparse +import nltk +from nltk import tokenize +import os +import pytest +import torch + +torch.set_printoptions(threshold=5000) + +from utils_nlp.models.transformers.datasets import SummarizationDataset +from utils_nlp.models.transformers.abstractive_summarization_bartt5 import ( + AbstractiveSummarizer) + +# @pytest.fixture() +def source_data(): + return [ + [ + "Boston, MA welcome to Microsoft/nlp", + "Welcome to text summarization.", + "Welcome to Microsoft NERD.", + "Look outside, what a beautiful Charlse River fall view.", + ], + ["I am just another test case"], + ["want to test more"], + ] + + +# @pytest.fixture() +def target_data(): + return [ + [ + "welcome to microsoft/nlp.", + "Welcome to text summarization.", + "Welcome to Microsoft NERD.", + ], + ["I am just another test summary"], + ["yest, I agree"], + ] + + +#NUM_GPUS = 2 +os.environ["NCCL_IB_DISABLE"] = "0" + + +@pytest.fixture(scope="module") +def test_dataset_for_bartt5(tmp_module): + source = source_data() + target = target_data() + source_file = os.path.join(tmp_module, "source.txt") + target_file = os.path.join(tmp_module, "target.txt") + f = open(source_file, "w") + for i in source: + f.write(" ".join(i)) + f.write("\n") + f.close() + f = open(target_file, "w") + for i in target: + f.write(" ".join(i)) + f.write("\n") + f.close() + train_dataset = SummarizationDataset( + source_file = source_file, + target_file = target_file, + ) + test_dataset = SummarizationDataset( + source_file = source_file, + target_file = target_file, + ) + return train_dataset, test_dataset + + +@pytest.mark.gpu +@pytest.fixture() +def test_train_model(tmp_module, test_dataset_for_bartt5, batch_size=1): + CACHE_PATH = ( + tmp_module + ) + DATA_PATH = ( + tmp_module + ) + MODEL_PATH = ( + tmp_module + ) + + summarizer = AbstractiveSummarizer("t5-small", cache_dir=CACHE_PATH) + + checkpoint = None + train_sum_dataset, _ = test_dataset_for_bartt5 + abs_sum_train = summarizer.processor.preprocess(train_sum_dataset) + + MAX_STEP = 20 + TOP_N = 8 + summarizer.fit( + abs_sum_train, + batch_size=batch_size, + max_steps=MAX_STEP, + local_rank=-1, + learning_rate=0.002, + warmup_steps=20000, + num_gpus=None, + report_every=10, + save_every=100, + fp16=False, + checkpoint=checkpoint, + ) + saved_model_path = os.path.join( + MODEL_PATH, "summarizer_step_{}.pt".format(MAX_STEP) + ) + summarizer.save_model(MAX_STEP, saved_model_path) + + return saved_model_path + + +@pytest.mark.gpu +def test_finetuned_model( + tmp_module, + test_train_model, + test_dataset_for_bartt5, + top_n=8, + batch_size=1, +): + CACHE_PATH = ( + tmp_module + ) + DATA_PATH = ( + tmp_module + ) + MODEL_PATH = ( + tmp_module + ) + + _, test_sum_dataset = test_dataset_for_bartt5 + + summarizer = AbstractiveSummarizer("t5-small", cache_dir=CACHE_PATH) + abs_sum_test = summarizer.processor.preprocess(test_sum_dataset) + checkpoint = torch.load(test_train_model, map_location="cpu") + + summarizer.model.load_state_dict(checkpoint["model"]) + + reference_summaries = [ + "".join(i["tgt"]).rstrip("\n") for i in abs_sum_test + ] + print("start prediction") + generated_summaries = summarizer.predict( + abs_sum_test, batch_size=batch_size, num_gpus=None + ) + + def _write_list_to_file(list_items, filename): + with open(filename, "w") as filehandle: + # for cnt, line in enumerate(filehandle): + for item in list_items: + filehandle.write("%s\n" % item) + + print("writing generated summaries") + _write_list_to_file(generated_summaries, os.path.join(CACHE_PATH, "prediction.txt")) + + assert len(generated_summaries) == len(reference_summaries) From 4e50f68b16aba1f05632b35bcb92129ddf9c3e90 Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Wed, 20 May 2020 20:26:26 +0000 Subject: [PATCH 11/14] add integration test --- ...tractive_summarization_bartt5_cnndm.ipynb} | 378 +++++++++++++++--- examples/text_summarization/test_bartt5.py | 96 ----- tests/conftest.py | 6 + .../abstractive_summarization_bartt5.py | 157 ++++---- 4 files changed, 428 insertions(+), 209 deletions(-) rename examples/text_summarization/{abstractive_summarization_cnndm_transformers.ipynb => abstractive_summarization_bartt5_cnndm.ipynb} (76%) delete mode 100644 examples/text_summarization/test_bartt5.py diff --git a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb similarity index 76% rename from examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb rename to examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb index 34067c1c1..db3584c24 100644 --- a/examples/text_summarization/abstractive_summarization_cnndm_transformers.ipynb +++ b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb @@ -25,8 +25,7 @@ "\n", "### Before You Start\n", "\n", - "\n", - "> **Tip**: If you want to run through the notebook quickly, you can set the **`QUICK_RUN`** flag in the cell below to **`True`** to run the notebook on a small subset of the data and a smaller number of epochs. \n", + "Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of steps. If QUICK_RUN = True, the notebook takes about 5 minutes to run on a VM with 1 Tesla K80 GPUs with 12GB GPU memory.\n", "\n", "### Additional Notes\n", "\n", @@ -92,7 +91,7 @@ "from utils_nlp.dataset.cnndm import CNNDMSummarizationDataset\n", "from utils_nlp.eval import compute_rouge_python, compute_rouge_perl\n", "from utils_nlp.models.transformers.abstractive_summarization_bartt5 import (\n", - " AbstractiveSummarizer, SummarizationProcessor, validate)\n", + " AbstractiveSummarizer)\n", "\n", "from utils_nlp.models.transformers.datasets import SummarizationDataset\n", "import nltk\n", @@ -219,7 +218,71 @@ "parameters" ] }, - "outputs": [], + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "cb2a382fa292493f9bae5794a11848f4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1197.0, style=ProgressStyle(description…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f8630ee654a9488aaef1bcb41526253a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7d849fb5a1bf4bc0874708299e5b3712", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242136741.0, style=ProgressStyle(descri…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], "source": [ "# Transformer model being used\n", "# MODEL_NAME = \"bart-large\"\n", @@ -263,7 +326,15 @@ "metadata": { "scrolled": false }, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 489k/489k [00:07<00:00, 61.4kKB/s] \n" + ] + } + ], "source": [ "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" ] @@ -326,8 +397,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 540 ms, sys: 0 ns, total: 540 ms\n", - "Wall time: 539 ms\n" + "CPU times: user 485 ms, sys: 7.85 ms, total: 493 ms\n", + "Wall time: 492 ms\n" ] } ], @@ -481,23 +552,38 @@ "SAVE_EVERY=1000\n", "\n", "# total number of steps for training\n", - "MAX_STEPS=1000\n", + "MAX_STEPS=100\n", "# number of steps for warm up\n", - "WARMUP_STEPS=5e2\n", + "WARMUP_STEPS=5e1\n", " \n", "if not QUICK_RUN:\n", - " MAX_STEPS=2e4\n", - " WARMUP_STEPS=5e3\n", + " MAX_STEPS=1000\n", + " WARMUP_STEPS=5e2\n", + " \n", + "# inference parameters\n", + "TEST_PER_GPU_BATCH_SIZE = 32\n", + "BEAM_SIZE = 3\n", " " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": { "scrolled": true }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "timestamp: 20/05/2020 19:11:57, average loss: 2.734279, time duration: 84.708848,\n", + " number of examples in current reporting: 400, step 100\n", + " out of total 100\n", + "saving through pytorch to ./t5_cache/fine_tuned/abssum_t5-small.pt\n" + ] + } + ], "source": [ "\n", "summarizer.fit(\n", @@ -516,9 +602,91 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 31, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "./t5_cache\n", + "saving through pytorch to ./t5_cache/abssum_modelname_t5-small_steps_100.pt\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9ad8d415da0842e28f59d09928c8e754", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1197.0, style=ProgressStyle(description…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c6101747540d46afa113cda9a4977c41", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "987b798dca0b489cb4fb8aabaace0ba5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242136741.0, style=ProgressStyle(descri…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "# save a finetuned model and load a previous saved model\n", "\"\"\"\n", @@ -541,12 +709,25 @@ "source": [ "### Model Evaluation\n", "\n", - "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization. " + "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization. \n", + "For the settings in this notebook with QUICK_RUN=True, you should get ROUGE scores close to the following numbers:\n", + "\n", + "``\n", + "{'rouge-1': {'f': 0.31741527202292646,\n", + " 'p': 0.3455155118288276,\n", + " 'r': 0.3045104334747269},\n", + " 'rouge-2': {'f': 0.12227435906684982,\n", + " 'p': 0.13407308558314568,\n", + " 'r': 0.11687233771002672},\n", + " 'rouge-l': {'f': 0.23522707640246865,\n", + " 'p': 0.2558803081762467,\n", + " 'r': 0.22589352441506083}}\n", + " `` " ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -559,21 +740,52 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 18, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataset length is 100\n", + "CPU times: user 2min 16s, sys: 432 ms, total: 2min 16s\n", + "Wall time: 2min 16s\n" + ] + } + ], "source": [ "%%time\n", - "# checkpoint=\"/dadendev/nlp-recipes/examples/text_summarization/bk_bart-large_step_10000.pt\"\n", - "prediction = summarizer.predict(abs_sum_test, num_gpus=1, batch_size=32,min_length=24, max_length=48)\n", - "#, checkpoint=checkpoint)" + "prediction = summarizer.predict(abs_sum_test, \n", + " num_gpus=NUM_GPUS, \n", + " batch_size=TEST_PER_GPU_BATCH_SIZE*NUM_GPUS,\n", + " min_length=24, \n", + " max_length=48,\n", + " num_beams=BEAM_SIZE)\n" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Number of candidates: 100\n", + "Number of references: 100\n", + "{'rouge-1': {'f': 0.31741527202292646,\n", + " 'p': 0.3455155118288276,\n", + " 'r': 0.3045104334747269},\n", + " 'rouge-2': {'f': 0.12227435906684982,\n", + " 'p': 0.13407308558314568,\n", + " 'r': 0.11687233771002672},\n", + " 'rouge-l': {'f': 0.23522707640246865,\n", + " 'p': 0.2558803081762467,\n", + " 'r': 0.22589352441506083}}\n" + ] + } + ], "source": [ "rouge_scores = compute_rouge_python(cand=prediction, ref=target)\n", "pprint.pprint(rouge_scores)" @@ -581,27 +793,68 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'french prosecutor brice robin says \" so far no videos were used in the crash investigation. he adds that a person who has such a video needs to immediately give it to investigators.'" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "prediction[0]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "' marseille prosecutor says \" so far no videos were used in the crash investigation \" despite media reports . journalists at bild and paris match are \" very confident \" the video clip is real , an editor says . andreas lubitz had informed his lufthansa training school of an episode of severe depression , airline says . '" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "target[0]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "application/scrapbook.scrap.json+json": { + "data": 0.12227435906684982, + "encoder": "json", + "name": "rouge_2_f_score", + "version": 1 + } + }, + "metadata": { + "scrapbook": { + "data": true, + "display": false, + "name": "rouge_2_f_score" + } + }, + "output_type": "display_data" + } + ], "source": [ "# for testing\n", "sb.glue(\"rouge_2_f_score\", rouge_scores['rouge-2']['f'])" @@ -616,7 +869,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -630,7 +883,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -644,36 +897,74 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 25, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "dict_keys(['src', 'src_txt', 'source_ids', 'source_mask'])" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "preprocessed_dataset[0].keys()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 26, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "dataset length is 1\n" + ] + } + ], "source": [ "prediction = summarizer.predict(preprocessed_dataset, num_gpus=0, batch_size=1, min_length=24, max_length=48,)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 27, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "'the person would not be held for any length of time in an American facility. border defenses are weakened, officials say. the new rule is set to be announced in the next 48 hours.'" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "prediction[0]" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 28, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Total notebook running time 264.78186321258545\n" + ] + } + ], "source": [ "print(\"Total notebook running time {}\".format(time.time() - start_time))" ] @@ -687,17 +978,14 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "if os.path.exists(DATA_PATH):\n", " shutil.rmtree(DATA_PATH, ignore_errors=True)\n", "if os.path.exists(CACHE_DIR):\n", - " shutil.rmtree(CACHE_DIR, ignore_errors=True)\n", - "if USE_PREPROCSSED_DATA:\n", - " if os.path.exists(PROCESSED_DATA_PATH):\n", - " shutil.rmtree(PROCESSED_DATA_PATH, ignore_errors=True)" + " shutil.rmtree(CACHE_DIR, ignore_errors=True)\n" ] } ], diff --git a/examples/text_summarization/test_bartt5.py b/examples/text_summarization/test_bartt5.py deleted file mode 100644 index 65790b5d3..000000000 --- a/examples/text_summarization/test_bartt5.py +++ /dev/null @@ -1,96 +0,0 @@ -import os -import shutil -import sys -from tempfile import TemporaryDirectory -import torch - -nlp_path = os.path.abspath("../../") -if nlp_path not in sys.path: - sys.path.insert(0, nlp_path) - -from utils_nlp.dataset.cnndm import CNNDMBertSumProcessedData, CNNDMSummarizationDataset -from utils_nlp.eval import compute_rouge_python, compute_rouge_perl -from utils_nlp.models.transformers.abstractive_summarization_bartt5 import ( - AbstractiveSummarizer, SummarizationProcessor, validate) - -from utils_nlp.models.transformers.datasets import SummarizationDataset -import nltk -from nltk import tokenize - -import pandas as pd -import scrapbook as sb -import pprint - - -QUICK_RUN = True -MODEL_NAME = "bart-large" -CACHE_DIR = "./bartt5_cache" #TemporaryDirectory().name - -#processor = SummarizationProcessor(MODEL_NAME,cache_dir=CACHE_DIR ) #tokenizer, config.prefix) -DATA_PATH = "./bartt5_cnndm" #TemporaryDirectory().name -# The number of lines at the head of data file used for preprocessing. -1 means all the lines. -TOP_N = -1 -if not QUICK_RUN: - TOP_N = -1 -#abs_sum_train = torch.load(os.path.join(DATA_PATH, "train_full.pt")) -#abs_sum_test = torch.load(os.path.join(DATA_PATH, "test_full.pt")) - -BATCH_SIZE_PER_GPU = 1# batch size, unit is the number of samples -MAX_POS_LENGTH = 512 -# GPU used for training -NUM_GPUS = torch.cuda.device_count() -# Learning rate -LEARNING_RATE=3e-5 -# How often the statistics reports show up in training, unit is step. -REPORT_EVERY=100 -SAVE_EVERY=1000 -# total number of steps for training -MAX_STEPS=20000 -# number of steps for warm up -WARMUP_STEPS=5e2 -if not QUICK_RUN: - MAX_STEPS=2e4 - WARMUP_STEPS=5e2 - - -summarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR) -processor = summarizer.processor -""" -train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True) -abs_sum_train = processor.preprocess(train_dataset) -abs_sum_test = processor.preprocess(test_dataset) -""" -#torch.save(abs_sum_train, os.path.join(DATA_PATH, "train_{0}_{1}.pt".format(MODEL_NAME, TOP_N))) -#torch.save(abs_sum_test, os.path.join(DATA_PATH, "test_{0}_{1}.pt".format(MODEL_NAME, TOP_N))) -abs_sum_train = torch.load(os.path.join(DATA_PATH, "train_{0}_{1}.pt".format(MODEL_NAME, TOP_N))) -abs_sum_test = torch.load(os.path.join(DATA_PATH, "test_{0}_{1}.pt".format(MODEL_NAME, TOP_N))) -def new_validate(summarizer): - validate(summarizer, abs_sum_test, num_gpus=1) -#""" -summarizer.fit( - abs_sum_train, - num_gpus=NUM_GPUS, - batch_size=BATCH_SIZE_PER_GPU*NUM_GPUS, - gradient_accumulation_steps=4, - max_steps=MAX_STEPS, - max_grad_norm=0.1, - learning_rate=LEARNING_RATE, - warmup_steps=WARMUP_STEPS, - verbose=True, - report_every=REPORT_EVERY, - validation_function=new_validate, - checkpoint="./bartt5_cache/bart-large_step_10000.pt" - ) -#""" -#prediction = summarizer.predict(abs_sum_test[0:32], num_gpus=NUM_GPUS, batch_size=BATCH_SIZE_PER_GPU*NUM_GPUS) -#print(prediction) -"""summarizer.save_model(global_step=MAX_STEPS, -full_name = os.path.join( - CACHE_DIR, - "abssum_{0}_steps_{1}.pt".format( - MODEL_NAME, MAX_STEPS - ), - ) -) - -""" diff --git a/tests/conftest.py b/tests/conftest.py index ecc6c2acf..87f5a1a31 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -138,6 +138,12 @@ def notebooks(): "text_summarization", "abstractive_summarization_bertsumabs_cnndm.ipynb", ), + "abstractive_summarization_bartt5_cnndm": os.path.join( + folder_notebooks, + "text_summarization", + "abstractive_summarization_bartt5_cnndm.ipynb", + ), + } return paths diff --git a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py index cf40b9b0a..5cfa702cb 100644 --- a/utils_nlp/models/transformers/abstractive_summarization_bartt5.py +++ b/utils_nlp/models/transformers/abstractive_summarization_bartt5.py @@ -56,53 +56,6 @@ logger = logging.getLogger(__name__) - -def encode_example( - example, - tokenizer, - prefix="", - max_source_length=None, - max_target_length=None, - pad_to_max_length=True, - return_tensors="pt", -): - """ - Encode a single example with the specified tokenizer. - - Args: - example: - tokenizer - prefix - max_source_length - max_target_length: - pad_to_max_length: - return_tensors: - - """ - - tokenized_source = tokenizer.batch_encode_plus( - [prefix + example["src"]], - max_length=max_source_length, - pad_to_max_length=pad_to_max_length, - return_tensors=return_tensors, - ) - - source_ids = tokenized_source["input_ids"].squeeze() - src_mask = tokenized_source["attention_mask"].squeeze() - example["source_ids"] = source_ids - example["source_mask"] = src_mask - if "tgt" in example: - tokenized_target = tokenizer.batch_encode_plus( - [example["tgt"]], - max_length=max_target_length, - pad_to_max_length=pad_to_max_length, - return_tensors=return_tensors, - ) - target_ids = tokenized_target["input_ids"].squeeze() - example["target_ids"] = target_ids - return example - - class Predictor(nn.Module): """ Predictor which can run on multi-GPUs. @@ -212,20 +165,53 @@ def __init__( self.max_source_length = max_source_length self.max_target_length = max_target_length - @staticmethod - def trim_seq2seq_batch(batch, pad_token_id): + def preprocess(self, input_data_list): + """ preprocess the data for abstractive summarization. + + Args: + input_data_list (list of dictionary): input list where each item is + an dictionary with fields "src" and "tgt" and both fields are string. - y = trim_batch(batch["target_ids"], pad_token_id) - source_ids, source_mask = trim_batch( - batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"] - ) - return source_ids, source_mask, y + Returns: + list of dictionary with addtional fields "source_ids", + "source_mask" and "target_ids". + """ + def _encode_example( + example, + tokenizer, + prefix="", + max_source_length=None, + max_target_length=None, + pad_to_max_length=True, + return_tensors="pt", + ): + + tokenized_source = tokenizer.batch_encode_plus( + [prefix + example["src"]], + max_length=max_source_length, + pad_to_max_length=pad_to_max_length, + return_tensors=return_tensors, + ) + + source_ids = tokenized_source["input_ids"].squeeze() + src_mask = tokenized_source["attention_mask"].squeeze() + example["source_ids"] = source_ids + example["source_mask"] = src_mask + if "tgt" in example: + tokenized_target = tokenizer.batch_encode_plus( + [example["tgt"]], + max_length=max_target_length, + pad_to_max_length=pad_to_max_length, + return_tensors=return_tensors, + ) + target_ids = tokenized_target["input_ids"].squeeze() + example["target_ids"] = target_ids + return example - def preprocess(self, input_data_list): result = [] for i in input_data_list: result.append( - encode_example( + _encode_example( i, tokenizer=self.tokenizer, prefix=self.prefix, @@ -237,31 +223,66 @@ def preprocess(self, input_data_list): @staticmethod def get_inputs(batch, device, model_name, tokenizer=None, train_mode=True): + """ + Creates an input dictionary given a model name. + + Args: + batch (object): A Batch containing lists of source ids, source_mask. + If train_mode is True, it also contains the list of target ids. + device (torch.device): A PyTorch device. + model_name (bool, optional): Model name used to format the inputs. + tokenizer (AutoTokenizer, optional): tokenizer whose pad_token_id + will be used for processing. + train_mode (bool, optional): Training mode flag. + Defaults to True. + + Returns: + dict: Dictionary containing source ids, attention masks. + Decoder input ids and LM labels are only returned when + train_mode is True. + """ + + pad_token_id = tokenizer.pad_token_id if not train_mode: source_ids, source_mask = batch["source_ids"], batch["source_mask"] - else: - source_ids, source_mask, y = SummarizationProcessor.trim_seq2seq_batch( - batch, pad_token_id - ) - y_ids = y[:, :-1].contiguous() - lm_labels = y[:, 1:].clone() - lm_labels[y[:, 1:] == pad_token_id] = -100 - - if train_mode: return { "input_ids": source_ids, "attention_mask": source_mask, - "decoder_input_ids": y_ids, - "lm_labels": lm_labels, } + else: + y = trim_batch(batch["target_ids"], pad_token_id) + source_ids, source_mask = trim_batch( + batch["source_ids"], pad_token_id, attention_mask=batch["source_mask"] + ) + y_ids = y[:, :-1].contiguous() + lm_labels = y[:, 1:].clone() + lm_labels[y[:, 1:] == pad_token_id] = -100 + return { "input_ids": source_ids, "attention_mask": source_mask, + "decoder_input_ids": y_ids, + "lm_labels": lm_labels, } - + def collate_fn(self, batch, device, train_mode=False): + """ Collate formats the data passed to the data loader. + In particular we tokenize the data batch after batch to avoid keeping them + all in memory. + + Args: + batch (list of dictionary): input data to be loaded. + device (torch.device): A PyTorch device. + train_mode (bool, optional): Training mode flag. + Defaults to True. + + Returns: + namedtuple: a nametuple containing source ids, source mask. + If train_mode is True, it also contains the target ids. + """ + input_ids = torch.stack([x["source_ids"] for x in batch]) masks = torch.stack([x["source_mask"] for x in batch]) pad_token_id = self.tokenizer.pad_token_id From f99019f3d8e27ddb5a7c3851927b53973c0a951d Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Thu, 21 May 2020 18:21:07 +0000 Subject: [PATCH 12/14] rename variable --- utils_nlp/dataset/cnndm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/utils_nlp/dataset/cnndm.py b/utils_nlp/dataset/cnndm.py index 8890197cb..cdaaeb042 100644 --- a/utils_nlp/dataset/cnndm.py +++ b/utils_nlp/dataset/cnndm.py @@ -78,7 +78,7 @@ def CNNDMSummarizationDataset(*args, **kwargs): URLS = ["https://s3.amazonaws.com/opennmt-models/Summary/cnndm.tar.gz"] def _setup_datasets( - url, top_n=-1, local_cache_path=".data", raw=False, prepare_extractive=True + url, top_n=-1, local_cache_path=".data", sent_split=True, prepare_extractive=True ): FILE_NAME = "cnndm.tar.gz" maybe_download(url, FILE_NAME, local_cache_path) @@ -93,7 +93,7 @@ def _setup_datasets( test_source_file = fname if fname.endswith("test.txt.tgt.tagged"): test_target_file = fname - if raw: + if not sent_split: return ( SummarizationDataset( train_source_file, From 88352d67aa75da4799c049fb3f250ceddb72e9ac Mon Sep 17 00:00:00 2001 From: Daisy Deng Date: Tue, 26 May 2020 20:30:37 +0000 Subject: [PATCH 13/14] add time duration for quick=false --- ...stractive_summarization_bartt5_cnndm.ipynb | 561 ++++++++---------- .../abstractive_summarization_bartt5.py | 121 ++-- 2 files changed, 322 insertions(+), 360 deletions(-) diff --git a/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb index db3584c24..194e2a1d2 100644 --- a/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb +++ b/examples/text_summarization/abstractive_summarization_bartt5_cnndm.ipynb @@ -25,7 +25,7 @@ "\n", "### Before You Start\n", "\n", - "Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of steps. If QUICK_RUN = True, the notebook takes about 5 minutes to run on a VM with 1 Tesla K80 GPUs with 12GB GPU memory.\n", + "Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of steps. If QUICK_RUN = True, the notebook takes about 5 minutes to run on a VM with 1 Tesla K80 GPUs with 12GB GPU memory. If QUICK_RUN = False, it takes around 15 minutes for data preprocessing, 15 minutes for fine-tuning and 3 hours for running evaluation on the whole CNN/DM test dataset.\n", "\n", "### Additional Notes\n", "\n", @@ -50,7 +50,7 @@ "\n", "%autoreload 2\n", "## Set QUICK_RUN = True to run the notebook on a small subset of data and a smaller number of epochs.\n", - "QUICK_RUN = True" + "QUICK_RUN = False" ] }, { @@ -218,71 +218,7 @@ "parameters" ] }, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "cb2a382fa292493f9bae5794a11848f4", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1197.0, style=ProgressStyle(description…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "f8630ee654a9488aaef1bcb41526253a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "7d849fb5a1bf4bc0874708299e5b3712", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242136741.0, style=ProgressStyle(descri…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - } - ], + "outputs": [], "source": [ "# Transformer model being used\n", "# MODEL_NAME = \"bart-large\"\n", @@ -326,17 +262,9 @@ "metadata": { "scrolled": false }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 489k/489k [00:07<00:00, 61.4kKB/s] \n" - ] - } - ], + "outputs": [], "source": [ - "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, raw=True)" + "train_dataset, test_dataset = CNNDMSummarizationDataset(top_n=TOP_N, local_cache_path=DATA_PATH, sent_split=False)" ] }, { @@ -365,18 +293,17 @@ "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "100" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "287227\n", + "11490\n" + ] } ], "source": [ - "len(test_dataset)" + "print(len(train_dataset))\n", + "print(len(test_dataset))" ] }, { @@ -388,7 +315,62 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "from multiprocessing import Pool\n", + "def preprocess(summarizer, input_data_list, num_workers=50, chunk_size=100, internal_batch_size=5e3):\n", + " \"\"\" preprocess the data for abstractive summarization.\n", + "\n", + " Args:\n", + " input_data_list (list of dictionary): input list where each item is\n", + " an dictionary with fields \"src\" and \"tgt\" and both fields are string.\n", + " num_workers (int, optional): The number of workers in the pool.\n", + " Defautls to 50.\n", + " chunk_size (int, optional): The size that a worker processes.\n", + " Defaults to 100.\n", + " internal_batch_size (int, optional): The size that one pool processes.\n", + " Defaults to 5000. Reduce this number if you see segment fault.\n", + "\n", + " Returns:\n", + " list of dictionary with addtional fields \"source_ids\",\n", + " \"source_mask\" and \"target_ids\".\n", + " \"\"\"\n", + " i = 0\n", + " temp_dir = TemporaryDirectory().name\n", + " os.makedirs(temp_dir, mode=0o777, exist_ok=False)\n", + " temp_file = \".temp_preprocess\"\n", + " processed_length = 0\n", + " result = []\n", + " print(len(input_data_list))\n", + " pool = Pool(num_workers, initializer=summarizer.processor.initializer)\n", + " while processed_length < len(input_data_list):\n", + " max_length = int(min(processed_length+internal_batch_size, len(input_data_list)))\n", + " temp = []\n", + " for j in range(processed_length, max_length):\n", + " temp.append(input_data_list[j])\n", + " result_generator = pool.imap(summarizer.processor.encode_example, temp, chunk_size)\n", + " torch.save(list(result_generator), os.path.join(temp_dir, temp_file+str(i)))\n", + " i += 1\n", + " processed_length = max_length\n", + " #print(processed_length)\n", + "\n", + " pool.close()\n", + " pool.join()\n", + " result = []\n", + " total_batch_number = i\n", + " for i in range(total_batch_number):\n", + " result.extend(torch.load(os.path.join(temp_dir, temp_file+str(i))))\n", + " if os.path.exists(temp_dir):\n", + " shutil.rmtree(temp_dir, ignore_errors=True)\n", + " return result\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, "metadata": { "scrolled": false }, @@ -397,40 +379,64 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 485 ms, sys: 7.85 ms, total: 493 ms\n", - "Wall time: 492 ms\n" + "287227\n", + "CPU times: user 1min 43s, sys: 57.7 s, total: 2min 40s\n", + "Wall time: 13min 22s\n" ] } ], "source": [ "%%time\n", - "abs_sum_train = summarizer.processor.preprocess(train_dataset)\n", + "# abs_sum_train = summarizer.processor.preprocess(train_dataset)\n", + "abs_sum_train = preprocess(summarizer, train_dataset)\n", "# torch.save(abs_sum_train, os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))\n", "# abs_sum_train = torch.load(os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ - "abs_sum_test = summarizer.processor.preprocess(test_dataset)\n", + "# abs_sum_train = torch.load(os.path.join(DATA_PATH, \"train_{0}_full.pt\".format(MODEL_NAME)))\n", + "\n", "# torch.save(abs_sum_test, os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))\n", "# abs_sum_test = torch.load(os.path.join(DATA_PATH, \"test_{0}_full.pt\".format(MODEL_NAME)))" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "100\n", - "100\n" + "11490\n", + "CPU times: user 4.37 s, sys: 5.82 s, total: 10.2 s\n", + "Wall time: 37.6 s\n" + ] + } + ], + "source": [ + "%%time\n", + "# abs_sum_test= summarizer.processor.preprocess(test_dataset)\n", + "abs_sum_test= preprocess(summarizer, test_dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "287227\n", + "11490\n" ] } ], @@ -448,16 +454,16 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "dict_keys(['src', 'src_txt', 'tgt', 'tgt_txt', 'source_ids', 'source_mask', 'target_ids'])" + "dict_keys(['source_ids', 'source_mask', 'target_ids'])" ] }, - "execution_count": 12, + "execution_count": 14, "metadata": {}, "output_type": "execute_result" } @@ -468,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -476,11 +482,7 @@ { "data": { "text/plain": [ - "{'src': 'editor \\'s note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o\\'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the \" forgotten floor , \" where many mentally ill inmates are housed in miami before trial . miami , florida ( cnn ) -- the ninth floor of the miami-dade pretrial detention facility is dubbed the \" forgotten floor . \" here , inmates with the most severe mental illnesses are incarcerated until they \\'re ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually \" avoidable felonies . \" he says the arrests often result from confrontations with police . mentally ill people often wo n\\'t do what they \\'re told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they \\'re in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor \\' \" at first , it \\'s hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that \\'s kind of what they look like . they \\'re designed to keep the mentally ill patients from injuring themselves . that \\'s also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it \\'s not supposed to be warm and comforting , but the lights glare , the cells are tiny and it \\'s loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . \" i am the son of the president . you need to get me out of here ! \" one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it \\'s brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered \" lunatics \" and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he \\'s working to change this . starting in 2008 , many inmates who would otherwise have been brought to the \" forgotten floor \" will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it \\'s not the complete answer , but it \\'s a start . leifman says the best part is that it \\'s a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n',\n", - " 'src_txt': \"editor 's note : in our behind the scenes series , cnn correspondents share their experiences in covering news and analyze the stories behind the events . here , soledad o'brien takes users inside a jail where many of the inmates are mentally ill . an inmate housed on the `` forgotten floor , '' where many mentally ill inmates are housed in miami before trial . miami , florida -lrb- cnn -rrb- -- the ninth floor of the miami-dade pretrial detention facility is dubbed the `` forgotten floor . '' here , inmates with the most severe mental illnesses are incarcerated until they 're ready to appear in court . most often , they face drug charges or charges of assaulting an officer -- charges that judge steven leifman says are usually `` avoidable felonies . '' he says the arrests often result from confrontations with police . mentally ill people often wo n't do what they 're told when police arrive on the scene -- confrontation seems to exacerbate their illness and they become more paranoid , delusional , and less likely to follow directions , according to leifman . so , they end up on the ninth floor severely mentally disturbed , but not getting any real help because they 're in jail . we toured the jail with leifman . he is well known in miami as an advocate for justice and the mentally ill . even though we were not exactly welcomed with open arms by the guards , we were given permission to shoot videotape and tour the floor . go inside the ` forgotten floor ' '' at first , it 's hard to determine where the people are . the prisoners are wearing sleeveless robes . imagine cutting holes for arms and feet in a heavy wool sleeping bag -- that 's kind of what they look like . they 're designed to keep the mentally ill patients from injuring themselves . that 's also why they have no shoes , laces or mattresses . leifman says about one-third of all people in miami-dade county jails are mentally ill . so , he says , the sheer volume is overwhelming the system , and the result is what we see on the ninth floor . of course , it is a jail , so it 's not supposed to be warm and comforting , but the lights glare , the cells are tiny and it 's loud . we see two , sometimes three men -- sometimes in the robes , sometimes naked , lying or sitting in their cells . `` i am the son of the president . you need to get me out of here ! '' one man shouts at me . he is absolutely serious , convinced that help is on the way -- if only he could reach the white house . leifman tells me that these prisoner-patients will often circulate through the system , occasionally stabilizing in a mental hospital , only to return to jail to face their charges . it 's brutally unjust , in his mind , and he has become a strong advocate for changing things in miami . over a meal later , we talk about how things got this way for mental patients . leifman says 200 years ago people were considered `` lunatics '' and they were locked up in jails even if they had no charges against them . they were just considered unfit to be in society . over the years , he says , there was some public outcry , and the mentally ill were moved out of jails and into hospitals . but leifman says many of these mental hospitals were so horrible they were shut down . where did the patients go ? nowhere . the streets . they became , in many cases , the homeless , he says . they never got treatment . leifman says in 1955 there were more than half a million people in state mental hospitals , and today that number has been reduced 90 percent , and 40,000 to 50,000 people are in mental hospitals . the judge says he 's working to change this . starting in 2008 , many inmates who would otherwise have been brought to the `` forgotten floor '' will instead be sent to a new mental health facility -- the first step on a journey toward long-term treatment , not just punishment . leifman says it 's not the complete answer , but it 's a start . leifman says the best part is that it 's a win-win solution . the patients win , the families are relieved , and the state saves money by simply not cycling these prisoners through again and again . and , for leifman , justice is served . e-mail to a friend .\\n\",\n", - " 'tgt': ' mentally ill inmates in miami are housed on the \" forgotten floor \" judge steven leifman says most are there as a result of \" avoidable felonies \" while cnn tours facility , patient shouts : \" i am the son of the president \" leifman says the system is unjust and he \\'s fighting for change . \\n',\n", - " 'tgt_txt': \" mentally ill inmates in miami are housed on the `` forgotten floor '' judge steven leifman says most are there as a result of `` avoidable felonies '' while cnn tours facility , patient shouts : `` i am the son of the president '' leifman says the system is unjust and he 's fighting for change . \\n\",\n", - " 'source_ids': tensor([21603, 10, 6005, ..., 3, 31, 7]),\n", + "{'source_ids': tensor([21603, 10, 6005, ..., 3, 31, 7]),\n", " 'source_mask': tensor([1, 1, 1, ..., 1, 1, 1]),\n", " 'target_ids': tensor([19367, 3, 1092, 16, 11171, 16, 1337, 3690, 33, 629,\n", " 26, 30, 8, 96, 11821, 1501, 96, 5191, 3, 849,\n", @@ -508,7 +510,7 @@ " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])}" ] }, - "execution_count": 13, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -527,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 17, "metadata": { "tags": [ "parameters" @@ -561,14 +563,14 @@ " WARMUP_STEPS=5e2\n", " \n", "# inference parameters\n", - "TEST_PER_GPU_BATCH_SIZE = 32\n", + "TEST_PER_GPU_BATCH_SIZE = 96\n", "BEAM_SIZE = 3\n", " " ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 17, "metadata": { "scrolled": true }, @@ -577,9 +579,38 @@ "name": "stdout", "output_type": "stream", "text": [ - "timestamp: 20/05/2020 19:11:57, average loss: 2.734279, time duration: 84.708848,\n", + "timestamp: 21/05/2020 14:55:00, average loss: 2.950966, time duration: 91.118061,\n", " number of examples in current reporting: 400, step 100\n", - " out of total 100\n", + " out of total 1000\n", + "timestamp: 21/05/2020 14:56:28, average loss: 2.496747, time duration: 87.725076,\n", + " number of examples in current reporting: 400, step 200\n", + " out of total 1000\n", + "timestamp: 21/05/2020 14:57:57, average loss: 2.232086, time duration: 88.453045,\n", + " number of examples in current reporting: 400, step 300\n", + " out of total 1000\n", + "timestamp: 21/05/2020 14:59:25, average loss: 2.104675, time duration: 88.361590,\n", + " number of examples in current reporting: 400, step 400\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:00:53, average loss: 1.996439, time duration: 88.524355,\n", + " number of examples in current reporting: 400, step 500\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:02:22, average loss: 1.918175, time duration: 89.008806,\n", + " number of examples in current reporting: 400, step 600\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:03:52, average loss: 1.981885, time duration: 89.146901,\n", + " number of examples in current reporting: 400, step 700\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:05:20, average loss: 1.892776, time duration: 88.711287,\n", + " number of examples in current reporting: 400, step 800\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:06:49, average loss: 1.842356, time duration: 88.697730,\n", + " number of examples in current reporting: 400, step 900\n", + " out of total 1000\n", + "timestamp: 21/05/2020 15:08:17, average loss: 1.920532, time duration: 88.288858,\n", + " number of examples in current reporting: 400, step 1000\n", + " out of total 1000\n", + "./t5_cache\n", + "saving through pytorch to ./t5_cache/t5-small_step_1000.pt\n", "saving through pytorch to ./t5_cache/fine_tuned/abssum_t5-small.pt\n" ] } @@ -602,87 +633,16 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 18, "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "./t5_cache\n", - "saving through pytorch to ./t5_cache/abssum_modelname_t5-small_steps_100.pt\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "9ad8d415da0842e28f59d09928c8e754", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1197.0, style=ProgressStyle(description…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "c6101747540d46afa113cda9a4977c41", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=791656.0, style=ProgressStyle(descripti…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "987b798dca0b489cb4fb8aabaace0ba5", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=242136741.0, style=ProgressStyle(descri…" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n" - ] - }, { "data": { "text/plain": [ - "" + "'\\nimport torch\\nmodel_path = os.path.join(\\n CACHE_DIR,\\n \"abssum_modelname_{0}_steps_{1}.pt\".format(\\n MODEL_NAME, MAX_STEPS\\n ))\\nsummarizer.save_model(global_step=MAX_STEPS, full_name=model_path)\\n\\nsummarizer = AbstractiveSummarizer(MODEL_NAME, cache_dir=CACHE_DIR)\\nsummarizer.model.load_state_dict(torch.load(model_path, map_location=\"cpu\")[\\'model\\'])\\n'" ] }, - "execution_count": 31, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } @@ -710,46 +670,69 @@ "### Model Evaluation\n", "\n", "[ROUGE](https://en.wikipedia.org/wiki/ROUGE_(metric)), or Recall-Oriented Understudy for Gisting Evaluation has been commonly used for evaluating text summarization. \n", - "For the settings in this notebook with QUICK_RUN=True, you should get ROUGE scores close to the following numbers:\n", + "For the settings in this notebook with QUICK_RUN=False, you should get ROUGE scores close to the following numbers:\n", "\n", "``\n", - "{'rouge-1': {'f': 0.31741527202292646,\n", - " 'p': 0.3455155118288276,\n", - " 'r': 0.3045104334747269},\n", - " 'rouge-2': {'f': 0.12227435906684982,\n", - " 'p': 0.13407308558314568,\n", - " 'r': 0.11687233771002672},\n", - " 'rouge-l': {'f': 0.23522707640246865,\n", - " 'p': 0.2558803081762467,\n", - " 'r': 0.22589352441506083}}\n", + "{'rouge-1': {'f': 0.3532833731474843,\n", + " 'p': 0.5062112092750258,\n", + " 'r': 0.2854026986121758},\n", + " 'rouge-2': {'f': 0.1627400891022247,\n", + " 'p': 0.23802173638805246,\n", + " 'r': 0.13034686738843493},\n", + " 'rouge-l': {'f': 0.2587374492685969,\n", + " 'p': 0.3710902340617733,\n", + " 'r': 0.20909466938819835}}\n", " `` " ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "source = []\n", "target = []\n", - "for i in abs_sum_test:\n", + "\n", + " \n", + "for i in test_dataset:\n", " source.append(i[\"src_txt\"]) \n", " target.append(i['tgt'].replace(\"\",\"\").replace(\"\", \"\").replace(\"\\n\", \"\")) " ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 19, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\r", + "Generating summary: 0%| | 0/120 [00:00 Date: Tue, 26 May 2020 20:35:18 +0000 Subject: [PATCH 14/14] update documentation --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d436397d4..7ef16a522 100755 --- a/README.md +++ b/README.md @@ -50,7 +50,7 @@ The following is a summary of the commonly used NLP scenarios covered in the rep |-------------------------| ------------------- |-------|---| |Text Classification |BERT, XLNet, RoBERTa| Text classification is a supervised learning method of learning and predicting the category or the class of a document given its text content. |English, Hindi, Arabic| |Named Entity Recognition |BERT| Named entity recognition (NER) is the task of classifying words or key phrases of a text into predefined entities of interest. |English| -|Text Summarization|BERTSumExt
BERTSumAbs
UniLM (s2s-ft)
MiniLM |Text summarization is a language generation task of summarizing the input text into a shorter paragraph of text.|English +|Text Summarization|BERTSumExt
BERTSumAbs
UniLM (s2s-ft)
MiniLM
T5
BART|Text summarization is a language generation task of summarizing the input text into a shorter paragraph of text.|English |Entailment |BERT, XLNet, RoBERTa| Textual entailment is the task of classifying the binary relation between two natural-language texts, *text* and *hypothesis*, to determine if the *text* agrees with the *hypothesis* or not. |English| |Question Answering |BiDAF, BERT, XLNet| Question answering (QA) is the task of retrieving or generating a valid answer for a given query in natural language, provided with a passage related to the query. |English| |Sentence Similarity |BERT, GenSen| Sentence similarity is the process of computing a similarity score given a pair of text documents. |English|