Skip to content

Commit

Permalink
use hgf ModelOutput class
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewyates committed Aug 16, 2022
1 parent a0f89dd commit f7bfaaa
Showing 1 changed file with 1 addition and 4 deletions.
5 changes: 1 addition & 4 deletions capreolus/reranker/TFCEDRKNRM.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,11 +137,8 @@ def call(self, x, **kwargs):
doc_seg = tf.reshape(doc_seg, [batch_size * self.num_passages, self.maxseqlen])

# get BERT embeddings (including CLS) for each passage
# TODO switch to hgf's ModelOutput after bumping tranformers version
outputs = self.bert(doc_input, attention_mask=doc_mask, token_type_ids=doc_seg)
if self.config["pretrained"].startswith("bert-"):
outputs = (outputs[0], outputs[2])
bert_output, all_layer_output = outputs
bert_output, all_layer_output = outputs.last_hidden_state, outputs.hidden_states

# embeddings to create the CLS feature
cls = bert_output[:, 0, :]
Expand Down

0 comments on commit f7bfaaa

Please sign in to comment.