This PyTorch package implements the Multi-Task Deep Neural Networks (MT-DNN) for Natural Language Understanding, as described in:
Xiaodong Liu*, Pengcheng He*, Weizhu Chen and Jianfeng Gao
Multi-Task Deep Neural Networks for Natural Language Understanding
arXiv version
*: Equal contribution
Xiaodong Liu, Pengcheng He, Weizhu Chen and Jianfeng Gao
Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding
arXiv version
- Python 3.6
- Install requirements
> pip install -r requirements.txt
- Pull docker
> docker pull allenlao/pytorch-mt-dnn:v0.1
- Run docker
> docker run -it --rm --runtime nvidia allenlao/pytorch-mt-dnn:v0.1 bash
If you are using Docker for the first time, please refer to https://docs.docker.com/
- Download data
> sh download.sh
To download the GLUE dataset, please refer to https://gluebenchmark.com/
- Preprocess data
> python prepro.py
- Training
> python train.py
Note that we ran experiments on 4 V100 GPUs for the base MT-DNN models. You may need to reduce the batch size on other GPUs.
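If a smaller per-GPU batch size is needed, one common way to keep the same effective batch size is gradient accumulation: run several smaller micro-batches and average their gradients before each parameter update. This README does not say whether train.py supports this, so the sketch below is a framework-free illustration under that assumption. `accumulation_steps`, `grad_mse` and `accumulated_grad` are hypothetical helpers operating on a toy one-parameter least-squares loss, not part of this repository.

```python
import math
from typing import Sequence


def accumulation_steps(target_batch_size: int, micro_batch_size: int) -> int:
    """Micro-batches needed so that micro_batch_size * steps >= target_batch_size."""
    if target_batch_size <= 0 or micro_batch_size <= 0:
        raise ValueError("batch sizes must be positive")
    return math.ceil(target_batch_size / micro_batch_size)


def grad_mse(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to the scalar weight w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w: float, xs: Sequence[float], ys: Sequence[float],
                     micro_batch_size: int) -> float:
    """Average the per-micro-batch gradients, as gradient accumulation would."""
    steps = accumulation_steps(len(xs), micro_batch_size)
    total = 0.0
    for step in range(steps):
        start = step * micro_batch_size
        end = start + micro_batch_size
        total += grad_mse(w, xs[start:end], ys[start:end])
    return total / steps


if __name__ == "__main__":
    xs = [float(i) for i in range(1, 9)]
    ys = [2.0 * x for x in xs]
    print(accumulation_steps(32, 8))         # 4 micro-batches of 8 stand in for one batch of 32
    print(grad_mse(0.5, xs, ys))             # full-batch gradient
    print(accumulated_grad(0.5, xs, ys, 2))  # same value from 4 micro-batches of 2
```

The two gradients match exactly only when every micro-batch has the same size; if the target is not a multiple of the micro-batch size, the last, smaller micro-batch should be weighted by its size instead of counted equally.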
- MTL refinement: refine MT-DNN (shared layers), initialized with the pre-trained BERT model, via MTL on all GLUE tasks except WNLI to learn a new shared representation.
Note that we ran this experiment on 8 V100 GPUs (32 GB) with a batch size of 32.
  - Preprocess the GLUE data via the aforementioned script.
  - Training:
> scripts/run_mt_dnn.sh
- Finetuning: fine-tune MT-DNN on each of the GLUE tasks to get task-specific models.
Here, we provide two examples, STS-B and RTE. You can use similar scripts to fine-tune on all the GLUE tasks.
  - Fine-tune on the STS-B task:
> scripts/run_stsb.sh
You should get about 90.5/90.4 on the STS-B dev set in terms of Pearson/Spearman correlation.
  - Fine-tune on the RTE task:
> scripts/run_rte.sh
You should get about 83.8 on the RTE dev set in terms of accuracy.
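The dev numbers above are reported on a 0-100 scale. As a dependency-free sketch of what they measure (not the repository's own evaluation code): Pearson correlation compares predicted and gold STS-B scores linearly, Spearman correlation is the Pearson correlation of their ranks (tied values receive the average rank), and RTE accuracy is the fraction of correctly predicted labels. If SciPy is available, `scipy.stats.pearsonr` and `scipy.stats.spearmanr` compute the same two correlations.

```python
import math
from typing import List, Sequence


def pearson(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient of two equally long sequences."""
    n = len(xs)
    mx, my = sum(xs) / n, sum(ys) / n
    cov = sum((x - mx) * (y - my) for x, y in zip(xs, ys))
    sx = math.sqrt(sum((x - mx) ** 2 for x in xs))
    sy = math.sqrt(sum((y - my) ** 2 for y in ys))
    if sx == 0 or sy == 0:
        raise ValueError("correlation is undefined for constant input")
    return cov / (sx * sy)


def average_ranks(values: Sequence[float]) -> List[float]:
    """1-based ranks; tied values share the average of the positions they occupy."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of the 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Spearman correlation: Pearson correlation of the average ranks."""
    return pearson(average_ranks(xs), average_ranks(ys))


def accuracy(preds: Sequence[int], labels: Sequence[int]) -> float:
    """Fraction of predictions equal to the gold label."""
    return sum(p == g for p, g in zip(preds, labels)) / len(labels)


if __name__ == "__main__":
    gold, pred = [1.0, 2.0, 3.0, 4.0], [1.0, 4.0, 9.0, 16.0]
    # Scale by 100 to compare with the numbers reported above.
    print(100 * pearson(pred, gold), 100 * spearman(pred, gold))
    print(100 * accuracy([1, 0, 1, 1], [1, 1, 1, 0]))
```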
- Domain adaptation on SciTail
> scripts/scitail_domain_adaptation_bash.sh
- Domain adaptation on SNLI
> scripts/snli_domain_adaptation_bash.sh
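The MTL refinement step (run_mt_dnn.sh) trains on all GLUE tasks except WNLI at once. The MT-DNN paper describes its training loop as packing each task's training data into mini-batches, merging the mini-batches of all tasks, and shuffling them, so that every update uses a single-task mini-batch and that task's output layer. The sketch below illustrates only this scheduling step in plain Python; the task names and toy data are hypothetical placeholders, and the actual logic in train.py may differ.

```python
import random
from typing import Dict, List, Sequence, Tuple


def pack_minibatches(examples: Sequence, batch_size: int) -> List[list]:
    """Split one task's examples into consecutive mini-batches (the last may be smaller)."""
    return [list(examples[i:i + batch_size]) for i in range(0, len(examples), batch_size)]


def mtl_epoch_schedule(task_data: Dict[str, Sequence], batch_size: int,
                       seed: int = 0) -> List[Tuple[str, list]]:
    """Pack every task into mini-batches, merge them, and shuffle the union.

    Each entry is (task_name, mini_batch). A mini-batch never mixes tasks, so a
    trainer can route it through the shared layers plus that task's output layer.
    """
    schedule: List[Tuple[str, list]] = []
    for task, examples in task_data.items():
        schedule.extend((task, batch) for batch in pack_minibatches(examples, batch_size))
    random.Random(seed).shuffle(schedule)
    return schedule


if __name__ == "__main__":
    # Hypothetical toy data standing in for preprocessed GLUE tasks (WNLI excluded).
    toy_tasks = {"mnli": list(range(10)), "rte": list(range(4)), "stsb": list(range(5))}
    for task, batch in mtl_epoch_schedule(toy_tasks, batch_size=4):
        print(task, batch)
```

Because larger datasets contribute more mini-batches, they are visited proportionally more often within one pass over the merged schedule.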
[ ] Release code/models for MT-DNN with Knowledge Distillation.
[ ] Publish pretrained TensorFlow checkpoints.
We released the pretrained shared embeddings learned via MTL, which are aligned to the BERT base/large models: mt_dnn_base.pt and mt_dnn_large.pt.
To obtain similar models:
- Run the following script, and then pick the best checkpoint based on the average dev performance on MNLI/RTE:
> sh scripts/run_mt_dnn.sh
- Strip the task-specific layers via scripts/strip_model.py.
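The two steps above (choosing a checkpoint, then dropping the task-specific layers) can be sketched in plain Python as follows. This is not the contents of strip_model.py: the checkpoint file names, the score layout, and the `scoring_list.` prefix used to recognize task-specific parameters are hypothetical placeholders, so check the actual parameter names in a saved checkpoint before filtering.

```python
from typing import Dict, Mapping, Sequence


def pick_best_checkpoint(dev_scores: Mapping[str, Mapping[str, float]],
                         tasks: Sequence[str] = ("mnli", "rte")) -> str:
    """Return the checkpoint whose average dev score over `tasks` is highest."""
    return max(dev_scores, key=lambda ckpt: sum(dev_scores[ckpt][t] for t in tasks) / len(tasks))


def strip_task_layers(state_dict: Mapping[str, object],
                      task_prefix: str = "scoring_list.") -> Dict[str, object]:
    """Keep only parameters whose names do not start with the (hypothetical) task-head prefix."""
    return {name: value for name, value in state_dict.items() if not name.startswith(task_prefix)}


if __name__ == "__main__":
    # Hypothetical dev scores per checkpoint.
    scores = {
        "epoch_1.pt": {"mnli": 84.0, "rte": 70.0},
        "epoch_2.pt": {"mnli": 85.0, "rte": 68.0},
        "epoch_3.pt": {"mnli": 84.5, "rte": 72.0},
    }
    print(pick_best_checkpoint(scores))
    # Hypothetical parameter names; values stand in for tensors.
    toy_state = {"encoder.layer.0.weight": 1, "scoring_list.0.weight": 2, "scoring_list.0.bias": 3}
    print(strip_task_layers(toy_state))
```

With PyTorch, a state dict would typically be read with `torch.load` and written back with `torch.save`; how this repository nests the state dict inside its checkpoint files is not described here.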
For the SciTail/SNLI tasks, the purpose is to test how well the learned embedding generalizes and how easily it adapts to a new domain, rather than to use complicated model structures, so that the results can be compared directly with BERT. Thus, we use a linear projection in all domain adaptation settings.
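Here, a linear projection means a single affine layer from the shared encoder's output vector to class logits, followed by a softmax. The sketch below assumes the input is one pooled sentence-pair vector (for example, the [CLS] representation) and uses toy weights; it illustrates the idea and is not the repository's model code. SciTail has two labels (entails, neutral) and SNLI three (entailment, neutral, contradiction), so only the number of weight rows changes between them.

```python
import math
from typing import List, Sequence


def linear_projection(x: Sequence[float], weight: Sequence[Sequence[float]],
                      bias: Sequence[float]) -> List[float]:
    """logits[c] = sum_i weight[c][i] * x[i] + bias[c], one weight row per class."""
    if any(len(row) != len(x) for row in weight):
        raise ValueError("every weight row must match the input dimension")
    return [sum(w_i * x_i for w_i, x_i in zip(row, x)) + b for row, b in zip(weight, bias)]


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    pooled = [1.0, 2.0]                       # toy 2-d stand-in for the encoder output
    w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # 3 rows -> 3 classes, as for SNLI
    b = [0.0, 0.0, 0.0]
    probs = softmax(linear_projection(pooled, w, b))
    print(probs.index(max(probs)))            # predicted class index
```

Keeping the head to one weight matrix and bias means that differences between settings mostly reflect the shared representation rather than the task-specific layers.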
The difference is in the QNLI dataset. Please refer to the official GLUE homepage for more details.
We can use the multi-task refinement model to run predictions and produce reasonable results. However, achieving better results requires fine-tuning on each task. It is worth noting that the paper on arXiv is a little outdated and is based on the old GLUE dataset. We will update the paper, as mentioned below.
The PyTorch implementation of BERT is from: https://github.com/huggingface/pytorch-pretrained-BERT
BERT: https://github.com/google-research/bert
We also used some code from: https://github.com/kevinduh/san_mrc
For now, please cite the arXiv version:
@article{liu2019mt-dnn,
title={Multi-Task Deep Neural Networks for Natural Language Understanding},
author={Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
journal={arXiv preprint arXiv:1901.11504},
year={2019}
}
A new version of the paper will be shared later.
@article{liu2019mt-dnn-kd,
title={Improving Multi-Task Deep Neural Networks via Knowledge Distillation for Natural Language Understanding},
author={Liu, Xiaodong and He, Pengcheng and Chen, Weizhu and Gao, Jianfeng},
journal={arXiv preprint arXiv:1904.09482},
year={2019}
}
Typo: there is no activation function in Equation 2.
For help or issues using MT-DNN, please submit a GitHub issue.
For personal communication related to MT-DNN, please contact Xiaodong Liu (xiaodl@microsoft.com), Pengcheng He (penhe@microsoft.com), Weizhu Chen (wzchen@microsoft.com) or Jianfeng Gao (jfgao@microsoft.com).