From 78868d8108725e7bd108fa19d6103ab99a3fb93e Mon Sep 17 00:00:00 2001 From: jimme0421 Date: Tue, 17 Aug 2021 11:19:41 +0800 Subject: [PATCH 1/2] :bug: --- ark_nlp/factory/optimizer/__init__.py | 2 +- .../model/re/casrel_bert/casrel_relation_extraction_task.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ark_nlp/factory/optimizer/__init__.py b/ark_nlp/factory/optimizer/__init__.py index cbaeb95..cbbf854 100644 --- a/ark_nlp/factory/optimizer/__init__.py +++ b/ark_nlp/factory/optimizer/__init__.py @@ -80,7 +80,7 @@ def get_default_optimizer(module, module_name='bert', **kwargs): if module_name == 'bert': return get_default_bert_optimizer(module, **kwargs) elif module_name == 'crf_bert': - return get_default_bert_crf_optimizer(module, **kwargs) + return get_default_crf_bert_optimizer(module, **kwargs) else: raise ValueError("The default optimizer does not exist") diff --git a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py index 183c898..63abc00 100644 --- a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py +++ b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py @@ -176,11 +176,15 @@ def _compute_loss( loss = self.loss_function(logits, inputs) + if self.logs: + self._compute_loss_record(inputs, inputs['label_ids'], logits, loss, verbose, **kwargs) + return loss def _compute_loss_record( self, - inputs, + inputs, + lables, logits, loss, verbose, From b14b29c3bd701ef482779728785d67b2ddee0cd6 Mon Sep 17 00:00:00 2001 From: xiangking Date: Tue, 17 Aug 2021 13:37:21 +0800 Subject: [PATCH 2/2] fix (casrel bert) : Fixing loss log bug --- .../casrel_relation_extraction_task.py | 20 ++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py index 63abc00..361cd52 100644 --- a/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py +++ b/ark_nlp/model/re/casrel_bert/casrel_relation_extraction_task.py @@ -165,6 +165,20 @@ def _on_epoch_begin_record(self, **kwargs): self.logs['epoch_loss'] = 0 self.logs['epoch_example'] = 0 self.logs['epoch_step'] = 0 + + def _get_train_loss( + self, + inputs, + logits, + verbose=True, + **kwargs + ): + # 计算损失 + loss = self._compute_loss(inputs, logits, **kwargs) + + self._compute_loss_record(inputs, logits, loss, verbose, **kwargs) + + return loss def _compute_loss( self, @@ -176,15 +190,11 @@ def _compute_loss( loss = self.loss_function(logits, inputs) - if self.logs: - self._compute_loss_record(inputs, inputs['label_ids'], logits, loss, verbose, **kwargs) - return loss def _compute_loss_record( self, inputs, - lables, logits, loss, verbose, @@ -283,7 +293,7 @@ def fit( logits = self.module(**inputs) # 计算损失 - loss = self._compute_loss(inputs, logits, **kwargs) + loss = self._get_train_loss(inputs, logits, **kwargs) loss = self._on_backward(inputs, logits, loss, **kwargs)