diff --git a/hanlp/common/torch_component.py b/hanlp/common/torch_component.py
index 82ccf1be0..8765deeab 100644
--- a/hanlp/common/torch_component.py
+++ b/hanlp/common/torch_component.py
@@ -246,12 +246,13 @@ def fit(self,
             first_device = -1
         if _device_placeholder and first_device >= 0:
             _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
-        if finetune:
-            if isinstance(finetune, str):
-                self.load(finetune, devices=devices, **self.config)
-            else:
-                self.load(save_dir, devices=devices, **self.config)
-            self.config.finetune = finetune
+        if finetune or self.model:
+            if not self.model:
+                if isinstance(finetune, str):
+                    self.load(finetune, devices=devices, **self.config)
+                else:
+                    self.load(save_dir, devices=devices, **self.config)
+            self.config.finetune = finetune or True
             self.vocabs.unlock()  # For extending vocabs
             logger.info(
                 f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
diff --git a/plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py b/plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py
index 9086c338b..6a1a2703b 100644
--- a/plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py
+++ b/plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py
@@ -9,6 +9,7 @@
 
 cdroot()
 
+# 0. Prepare your dataset for finetuning
 your_training_corpus = 'data/ner/finetune/word_to_iobes.tsv'
 your_development_corpus = your_training_corpus  # Use a different one in reality
 save_dir = 'data/ner/finetune/model'
@@ -25,18 +26,22 @@
 '''
           )
 
+# 1. Load a pretrained model for finetuning
 ner = TransformerNamedEntityRecognizer()
+ner.load(hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH)
+
+# 2. Override hyper-parameters
+ner.config['epochs'] = 50  # Since the corpus is small, overfit it
+
+# 3. Fit on your dataset
 ner.fit(
     trn_data=your_training_corpus,
     dev_data=your_development_corpus,
     save_dir=save_dir,
-    epochs=50,  # Since the corpus is small, overfit it
-    finetune=hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH,
-    # You MUST set the same parameters with the fine-tuning model:
-    average_subwords=True,
-    transformer='hfl/chinese-electra-180g-small-discriminator',
+    **ner.config
 )
 
+# 4. Test it out on your data points
 HanLP = hanlp.pipeline()\
     .append(hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH), output_key='tok')\
     .append(ner, output_key='ner')
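
Note: after this change, fit() treats an already-loaded model as an implicit finetune request (finetune or self.model), so the demo no longer has to repeat the pretrained model's hyper-parameters (transformer, average_subwords, ...) by hand: load() restores them into ner.config, and **ner.config forwards them to fit(). A minimal end-to-end sketch of the new flow, assuming the demo's usual import path for TransformerNamedEntityRecognizer and reusing the placeholder corpus paths from the diff:

# Sketch of the simplified finetuning flow enabled by this change.
# Assumptions: corpus paths are the demo's placeholders, and
# TransformerNamedEntityRecognizer is importable from its usual module.
import hanlp
from hanlp.components.ner.transformer_ner import TransformerNamedEntityRecognizer

ner = TransformerNamedEntityRecognizer()
# load() restores pretrained weights and hyper-parameters into ner.config,
# so transformer / average_subwords no longer need to be spelled out.
ner.load(hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH)
ner.config['epochs'] = 50  # override any hyper-parameter before fitting
# fit() now notices self.model is already loaded and finetunes it in place.
ner.fit(
    trn_data='data/ner/finetune/word_to_iobes.tsv',
    dev_data='data/ner/finetune/word_to_iobes.tsv',  # use a real dev set in practice
    save_dir='data/ner/finetune/model',
    **ner.config,
)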