Support reusing existing model and config for finetuning, fix: #1942
hankcs committed Jan 15, 2025
1 parent 81983df commit 199f3f3
Showing 2 changed files with 17 additions and 11 deletions.
hanlp/common/torch_component.py (7 additions, 6 deletions)
@@ -246,12 +246,13 @@ def fit(self,
         first_device = -1
         if _device_placeholder and first_device >= 0:
             _dummy_placeholder = self._create_dummy_placeholder_on(first_device)
-        if finetune:
-            if isinstance(finetune, str):
-                self.load(finetune, devices=devices, **self.config)
-            else:
-                self.load(save_dir, devices=devices, **self.config)
-            self.config.finetune = finetune
+        if finetune or self.model:
+            if not self.model:
+                if isinstance(finetune, str):
+                    self.load(finetune, devices=devices, **self.config)
+                else:
+                    self.load(save_dir, devices=devices, **self.config)
+            self.config.finetune = finetune or True
             self.vocabs.unlock()  # For extending vocabs
             logger.info(
                 f'Finetune model loaded with {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}'
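In effect, fit() gains a second finetuning entry point: when the component's model is already loaded in memory, it is reused together with its stored config, and config.finetune is coerced to True if no checkpoint path was passed. A minimal sketch of the new call pattern (the corpus paths and save_dir below are placeholders, not from the repository):

import hanlp
from hanlp.components.ner.transformer_ner import TransformerNamedEntityRecognizer

ner = TransformerNamedEntityRecognizer()
ner.load(hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH)  # sets ner.model
ner.config['epochs'] = 50  # override any hyper-parameter before training

# Because ner.model is already set, fit() skips its internal load() and
# records config.finetune = True; **ner.config carries the architecture
# hyper-parameters (transformer, average_subwords, ...) so they no longer
# have to be retyped at the call site.
ner.fit(
    trn_data='data/your_trn.tsv',  # placeholder
    dev_data='data/your_dev.tsv',  # placeholder
    save_dir='data/your_model',    # placeholder
    **ner.config
)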
plugins/hanlp_demo/hanlp_demo/zh/train/finetune_ner.py (10 additions, 5 deletions)
@@ -9,6 +9,7 @@
 
 cdroot()
 
+# 0. Prepare your dataset for finetuning
 your_training_corpus = 'data/ner/finetune/word_to_iobes.tsv'
 your_development_corpus = your_training_corpus  # Use a different one in reality
 save_dir = 'data/ner/finetune/model'
@@ -25,18 +25,22 @@
 '''
 )
 
+# 1. Load a pretrained model for finetuning
 ner = TransformerNamedEntityRecognizer()
+ner.load(hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH)
+
+# 2. Override hyper-parameters
+ner.config['epochs'] = 50  # Since the corpus is small, overfit it
+
+# 3. Fit on your dataset
 ner.fit(
     trn_data=your_training_corpus,
     dev_data=your_development_corpus,
     save_dir=save_dir,
-    epochs=50,  # Since the corpus is small, overfit it
-    finetune=hanlp.pretrained.ner.MSRA_NER_ELECTRA_SMALL_ZH,
-    # You MUST set the same parameters with the fine-tuning model:
-    average_subwords=True,
-    transformer='hfl/chinese-electra-180g-small-discriminator',
+    **ner.config
 )
 
+# 4. Test it out on your data points
 HanLP = hanlp.pipeline()\
     .append(hanlp.load(hanlp.pretrained.tok.FINE_ELECTRA_SMALL_ZH), output_key='tok')\
     .append(ner, output_key='ner')
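For completeness, a hypothetical invocation of the pipeline assembled above; the sample sentence is illustrative only and not part of the diff:

# Pipelines built with hanlp.pipeline() are callable on raw text and
# return a Document keyed by each component's output_key.
doc = HanLP('晓美焰来到北京立方庭参观自然语义科技公司')
print(doc['ner'])  # entity spans predicted by the finetuned recognizer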
