We have trained a chatbot with rich knowledge and strong interaction ability, which can chat and give advice in various fields such as health, daily life, and study. The base framework is GPT-2, and the corpus is Chinese Baike Q&A data.
This repo supports two pretrained models: the GPT2-chitchat version and the CDial-GPT version. Please prepare a separate environment for each.
conda create -n chitchat
- transformers==4.4.2
- torch==1.8.0
conda create -n CDial
- transformers==2.1.1
- torch==1.8.0
Please download our pretrained models (GPT2-chitchat, CDial) and replace the checkpoint paths in interact.py and interact_CDial_GPT.py accordingly.
conda activate chitchat
python interact.py
conda activate CDial
python interact_CDial_GPT.py
Download the corpus baike2018qa [3], then process and encode the data. Here we keep 520k samples: 500k for training and 20k for validation. You can change vocab.txt, but always keep the same vocabulary for data processing, training, and interaction.
python pre_json.py
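The 500k/20k split described above can be sketched as a seeded shuffle followed by slicing. This is a minimal sketch that assumes the samples are already loaded into a list; the function name `split_train_valid` and the seed value are hypothetical and not taken from pre_json.py.

```python
import random
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_train_valid(
    samples: Sequence[T], n_valid: int, seed: int = 42
) -> Tuple[List[T], List[T]]:
    """Shuffle a copy of `samples` with a fixed seed and hold out `n_valid` for validation."""
    if not 0 <= n_valid <= len(samples):
        raise ValueError("n_valid must be between 0 and len(samples)")
    shuffled = list(samples)  # copy so the caller's list is left untouched
    random.Random(seed).shuffle(shuffled)  # private seeded RNG makes the split reproducible
    # The first n_valid items form the validation set; the rest are for training.
    return shuffled[n_valid:], shuffled[:n_valid]


if __name__ == "__main__":
    data = list(range(520))  # stand-in for the 520k samples, scaled down by 1000
    train, valid = split_train_valid(data, n_valid=20)
    print(len(train), len(valid))  # 500 20
```

Fixing the seed keeps the validation set identical across reruns of preprocessing, so validation losses from different training runs stay comparable.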
Thanks to the great repo GPT2-chitchat [1]: after downloading their shared pretrained model, we fine-tune it on the selected corpus.
python finetune.py
Thanks to the great model hub CDial-GPT [2]: we download the Large version and fine-tune it on the selected corpus.
python finetune_CDial_GPT.py
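Both fine-tuning scripts train GPT-2 on next-token prediction: the score vector at position t is compared with the token at position t+1. In transformers, `GPT2LMHeadModel` performs this shift internally when `labels` is passed. The pure-Python sketch below only illustrates that objective on one sequence; the function name `causal_lm_loss` and the `ignore_id` convention for padding are assumptions, not code from finetune.py or finetune_CDial_GPT.py.

```python
import math
from typing import Sequence


def causal_lm_loss(
    logits: Sequence[Sequence[float]], input_ids: Sequence[int], ignore_id: int
) -> float:
    """Mean next-token cross-entropy for a single sequence.

    logits[t] is the model's score vector at position t and is scored against
    input_ids[t + 1]; targets equal to ignore_id (e.g. padding) are skipped.
    """
    total, count = 0.0, 0
    for t in range(len(input_ids) - 1):
        target = input_ids[t + 1]  # shift by one: position t predicts token t+1
        if target == ignore_id:
            continue
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[target]  # -log softmax(row)[target]
        count += 1
    return total / count if count else 0.0


if __name__ == "__main__":
    # Uniform scores over a 4-token vocabulary give a loss of log(4) per target.
    print(causal_lm_loss([[0.0] * 4] * 3, [1, 2, 3], ignore_id=0))
```

Excluding padding targets matters: if they are counted, the reported loss also rewards predicting pad tokens and can look better than the model actually is.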
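- Download the Baike Q&A corpus; the dataset chosen here is the Baike Q&A JSON version (baike2018qa) [3]. Then preprocess the data.
Step 1: Parse the JSON data of the Baike Q&A corpus, extract the dialogue content, and concatenate each pair as [CLS]question[SEP]answer[SEP].
Step 2: Choose a suitable tokenizer to encode the text, then serialize the encoded data and save it as a pkl file.
Step 3: Load the dataset and truncate each sequence to max_len so that samples can be packed into batches.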
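The three preprocessing steps can be sketched end to end as follows. This is a minimal sketch, not the repo's pre_json.py: it assumes each corpus line is a JSON object with `title` (question) and `answer` fields, as we understand the baike2018qa layout in [3], and it swaps in a hypothetical character-level vocabulary for the real vocab.txt tokenizer. All function names are illustrative.

```python
import json
import pickle
from typing import Dict, Iterable, List, Tuple

# Special tokens used in the [CLS]question[SEP]answer[SEP] format.
CLS, SEP, PAD, UNK = "[CLS]", "[SEP]", "[PAD]", "[UNK]"


def parse_qa_lines(lines: Iterable[str]) -> List[Tuple[str, str]]:
    """Step 1: parse JSON lines into (question, answer) pairs.

    Assumes "title" holds the question and "answer" holds the answer.
    """
    pairs = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        record = json.loads(line)
        question = record.get("title", "").strip()
        answer = record.get("answer", "").strip()
        if question and answer:  # drop records missing either side
            pairs.append((question, answer))
    return pairs


def encode_pair(question: str, answer: str, vocab: Dict[str, int]) -> List[int]:
    """Step 2: map [CLS]question[SEP]answer[SEP] to token ids.

    Character-level lookup stands in for the real tokenizer; unknown
    characters map to [UNK].
    """
    tokens = [CLS, *question, SEP, *answer, SEP]
    return [vocab.get(tok, vocab[UNK]) for tok in tokens]


def truncate_and_batch(
    seqs: List[List[int]], max_len: int, batch_size: int, pad_id: int
) -> List[List[List[int]]]:
    """Step 3: truncate each sequence to max_len, then pad each batch to its longest member."""
    truncated = [seq[:max_len] for seq in seqs]
    batches = []
    for start in range(0, len(truncated), batch_size):
        chunk = truncated[start:start + batch_size]
        width = max(len(seq) for seq in chunk)
        batches.append([seq + [pad_id] * (width - len(seq)) for seq in chunk])
    return batches


if __name__ == "__main__":
    raw = [
        json.dumps({"title": "什么是GPT", "answer": "一种语言模型"}, ensure_ascii=False),
        json.dumps({"title": "你好", "answer": ""}, ensure_ascii=False),  # dropped: empty answer
    ]
    pairs = parse_qa_lines(raw)
    # Toy vocabulary built from the data; the real pipeline loads vocab.txt instead.
    vocab = {PAD: 0, UNK: 1, CLS: 2, SEP: 3}
    for q, a in pairs:
        for ch in q + a:
            vocab.setdefault(ch, len(vocab))
    encoded = [encode_pair(q, a, vocab) for q, a in pairs]
    blob = pickle.dumps(encoded)  # Step 2 writes this to a .pkl file; kept in memory here
    batches = truncate_and_batch(pickle.loads(blob), max_len=8, batch_size=2, pad_id=vocab[PAD])
    print(len(pairs), len(batches[0][0]))  # 1 8
```

Padding each batch only to its longest member, rather than always to max_len, is one possible design choice that saves compute on short samples. Note that plain truncation can cut off the final [SEP] of long answers.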
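- Read through the basic GPT-2 framework.
- Modify part of the front-end for the demo.
- TODO: Reload the pretrained GPT2-Chinese [4] model and complete inference. Deadline: May 5, 2022.
- Here we use the general Chinese base model from GPT2-Chinese. The general Chinese small model can also be used to speed up inference.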
- Problems: how to control the response length (the generated replies are short).
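For the short-reply problem noted above, one common approach is to bound the reply length during decoding: mask the end-of-reply token until a minimum length is reached, and stop at a maximum length. Hugging Face `generate()` exposes `min_length` and `max_length` for a similar purpose. The greedy sketch below assumes a hypothetical `next_token_logits` callable standing in for a GPT-2 forward pass and treats [SEP] as the end-of-reply token; it is not the repo's interact.py logic.

```python
import math
from typing import Callable, List, Sequence


def generate_reply(
    next_token_logits: Callable[[List[int]], Sequence[float]],
    prompt_ids: List[int],
    sep_id: int,
    min_len: int,
    max_len: int,
) -> List[int]:
    """Greedy decoding with explicit bounds on reply length.

    Before `min_len` tokens are generated, the [SEP] logit is masked to -inf so
    the reply cannot end early. Decoding stops at [SEP] or after `max_len` tokens.
    """
    ids = list(prompt_ids)
    reply: List[int] = []
    while len(reply) < max_len:
        logits = list(next_token_logits(ids))
        if len(reply) < min_len:
            logits[sep_id] = -math.inf  # forbid ending the reply too early
        token = max(range(len(logits)), key=logits.__getitem__)  # greedy argmax
        if token == sep_id:
            break
        reply.append(token)
        ids.append(token)  # feed the new token back as context
    return reply


if __name__ == "__main__":
    # Hypothetical toy "model" that always scores [SEP] (id 0) highest, then token 1.
    def eager_to_stop(ids: List[int]) -> List[float]:
        return [5.0, 1.0, 0.5]

    print(generate_reply(eager_to_stop, [2], sep_id=0, min_len=3, max_len=10))  # [1, 1, 1]
```

In practice the same masking is usually combined with top-k/top-p sampling instead of greedy argmax; greedy is used here only to keep the example deterministic.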
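Issue: the language model performs poorly as a dialogue model, and its replies lack logical coherence. Attempt: would a model pretrained on a chit-chat corpus, rather than on a general corpus, behave differently?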
TODO: Get the fine-tuning code running. Deadline: May 7, 2022.
Organized the training code on Kaggle (Done!)
? The fine-tuning loss does not converge; we need to tune lr, batch size, lr_scheduler, etc.
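The learning-rate schedule is one of the knobs listed above. The sketch below reproduces the shape of a linear warmup followed by linear decay, the schedule that `get_linear_schedule_with_warmup` implements in transformers; whether finetune.py uses this schedule is an assumption, and the helper names and numbers are hypothetical, not tuned values.

```python
import math


def linear_warmup_decay(step: int, base_lr: float, warmup_steps: int, total_steps: int) -> float:
    """Learning rate at `step`: linear warmup from 0 to base_lr, then linear decay to 0."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = total_steps - step
    return base_lr * max(0.0, remaining / max(1, total_steps - warmup_steps))


def total_training_steps(num_samples: int, batch_size: int, epochs: int, grad_accum: int = 1) -> int:
    """Optimizer steps for a run; the schedule above needs this as `total_steps`."""
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return math.ceil(steps_per_epoch / grad_accum) * epochs


if __name__ == "__main__":
    total = total_training_steps(500_000, batch_size=8, epochs=1)
    print(total)  # 62500
    for step in (0, 1_000, 6_250, 30_000, total):
        print(step, linear_warmup_decay(step, base_lr=1e-4, warmup_steps=6_250, total_steps=total))
```

A common first check when the loss does not converge is that `total_steps` matches the real number of optimizer steps; if it is too small, the learning rate decays to zero early and training stalls.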
Fine-tuned chit-chat for 1 epoch
Fine-tuned chit-chat for 2 epochs
Fine-tuned the GPT2-chitchat model on 320k samples.
- Ran inference with the 1-epoch fine-tuned checkpoint and compared it with the original chit-chat model.
- Ran inference with the 4-epoch fine-tuned checkpoint and compared it with the original chit-chat model.
- TODO: Run inference with the fine-tuned model.
- TODO: Port the code from PyTorch to PaddlePaddle.
[1] https://github.com/yangjianxin1/GPT2-chitchat
[2] https://github.com/thu-coai/CDial-GPT
[3] https://github.com/brightmart/nlp_chinese_corpus
[4] https://github.com/Morizeyao/GPT2-Chinese