
Commit 9c17ecd
fix bug: max_seq_length mismatch between student and teacher; keep full precision of pseudo labels
kwang2049 committed Jan 11, 2022
1 parent 59526e3 commit 9c17ecd
Showing 8 changed files with 153 additions and 124 deletions.
33 changes: 31 additions & 2 deletions gpl/toolkit/pl.py
@@ -1,6 +1,8 @@
 from sentence_transformers import CrossEncoder
 from .dataset import HardNegativeDataset
 from torch.utils.data import DataLoader
+from sentence_transformers import SentenceTransformer
+from transformers import AutoTokenizer
 import tqdm
 import os
 import logging
@@ -15,16 +17,41 @@ def hard_negative_collate_fn(batch):

 class PseudoLabeler(object):
 
-    def __init__(self, generated_path, gen_queries, corpus, total_steps, batch_size, cross_encoder, sep=' '):
+    def __init__(self, generated_path, gen_queries, corpus, total_steps, batch_size, cross_encoder, max_seq_length):
         assert 'hard-negatives.jsonl' in os.listdir(generated_path)
         fpath_hard_negatives = os.path.join(generated_path, 'hard-negatives.jsonl')
         self.cross_encoder = CrossEncoder(cross_encoder)
-        hard_negative_dataset = HardNegativeDataset(fpath_hard_negatives, gen_queries, corpus, sep)
+        hard_negative_dataset = HardNegativeDataset(fpath_hard_negatives, gen_queries, corpus)
         self.hard_negative_dataloader = DataLoader(hard_negative_dataset, shuffle=True, batch_size=batch_size, drop_last=True)
         self.hard_negative_dataloader.collate_fn = hard_negative_collate_fn
         self.output_path = os.path.join(generated_path, 'gpl-training-data.tsv')
         self.total_steps = total_steps
+        #### Retokenization
+        self.retokenizer = AutoTokenizer.from_pretrained(cross_encoder)
+        self.max_seq_length = max_seq_length
+
+    def retokenize(self, texts):
+        ## We do this retokenization for two reasons:
+        ### (1) to set the max_seq_length;
+        ### (2) we cannot simply use CrossEncoder(cross_encoder, max_length=max_seq_length),
+        #####     since max_seq_length would then apply to the concatenated (query, passage) sequence,
+        #####     rather than to the two sequences independently
+        texts = list(map(lambda text: text.strip(), texts))
+        features = self.retokenizer(
+            texts,
+            padding=True,
+            truncation='longest_first',
+            return_tensors="pt",
+            max_length=self.max_seq_length
+        )
+        decoded = self.retokenizer.batch_decode(
+            features['input_ids'],
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=True
+        )
+        return decoded

     def run(self):
         # header: 'query_id', 'positive_id', 'negative_id', 'pseudo_label_margin'
         data = []
@@ -39,11 +66,13 @@ def run(self):
                 batch = next(hard_negative_iterator)
 
             (query_id, pos_id, neg_id), (query, pos, neg) = batch
+            query, pos, neg = [self.retokenize(texts) for texts in [query, pos, neg]]
             scores = self.cross_encoder.predict(
                 list(zip(query, pos)) + list(zip(query, neg)),
                 show_progress_bar=False
             )
             labels = scores[:len(query)] - scores[len(query):]
+            labels = labels.tolist()  # `tolist` converts to Python floats, keeping full precision
 
             batch_gpl = map(lambda quad: '\t'.join((*quad[:3], str(quad[3]))) + '\n', zip(query_id, pos_id, neg_id, labels))
             data.extend(batch_gpl)
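Why the retokenization round-trip matters, as a minimal self-contained sketch (the model name and texts below are illustrative, not taken from this repo): truncating each text to max_seq_length before pairing, as retokenize() does, guarantees each side its full token budget, whereas letting the cross-encoder truncate the concatenated (query, passage) pair allows a long passage to crowd out the query.

from transformers import AutoTokenizer

# Illustrative checkpoint; any cross-encoder tokenizer behaves the same way.
tokenizer = AutoTokenizer.from_pretrained('cross-encoder/ms-marco-MiniLM-L-6-v2')
max_seq_length = 8  # tiny limit, just to make the truncation visible

query = 'what is a hard negative in dense retrieval'
passage = 'a hard negative is a highly ranked but irrelevant passage ' * 10

# (a) Truncate each text independently, then pair -- what retokenize() achieves.
def truncate_independently(text):
    ids = tokenizer(text, truncation=True, max_length=max_seq_length)['input_ids']
    return tokenizer.decode(ids, skip_special_tokens=True)

pair_a = (truncate_independently(query), truncate_independently(passage))

# (b) Truncate the concatenated pair -- what CrossEncoder(..., max_length=...) would do.
ids_b = tokenizer(query, passage, truncation='longest_first',
                  max_length=2 * max_seq_length)['input_ids']
pair_b = tokenizer.decode(ids_b, skip_special_tokens=True)

# In (a) each side keeps at most max_seq_length tokens; in (b) the budget is
# shared across the pair, so the two sides are no longer truncated independently.
print(pair_a)
print(pair_b)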
2 changes: 1 addition & 1 deletion gpl/train.py
@@ -91,7 +91,7 @@ def train(
         logger.info('Using existing GPL-training data')
     else:
         logger.info('No GPL-training data found. Now generating it via pseudo labeling')
-        pseudo_labeler = PseudoLabeler(path_to_generated_data, gen_queries, corpus, gpl_steps, batch_size_gpl, cross_encoder, sep)
+        pseudo_labeler = PseudoLabeler(path_to_generated_data, gen_queries, corpus, gpl_steps, batch_size_gpl, cross_encoder, max_seq_length)
         pseudo_labeler.run()


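After this change, one max_seq_length value governs both the student's truncation and the teacher's retokenization. A hedged usage sketch of the updated constructor (the path and hyper-parameter values below are placeholders, not repo defaults):

from gpl.toolkit.pl import PseudoLabeler

# gen_queries and corpus are the id->text mappings produced by the earlier
# query-generation and corpus-loading steps of the GPL pipeline.
pseudo_labeler = PseudoLabeler(
    'path/to/generated',    # placeholder; must contain hard-negatives.jsonl
    gen_queries,
    corpus,
    total_steps=140_000,    # placeholder value
    batch_size=32,          # placeholder value
    cross_encoder='cross-encoder/ms-marco-MiniLM-L-6-v2',
    max_seq_length=350,     # now also applied to the teacher's inputs
)
pseudo_labeler.run()        # writes gpl-training-data.tsv under the generated path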
20 changes: 10 additions & 10 deletions sample-data/generated/fiqa/corpus.jsonl

Large diffs are not rendered by default.

40 changes: 20 additions & 20 deletions sample-data/generated/fiqa/gpl-training-data.tsv
@@ -1,20 +1,20 @@
-genQ5 257122 460230 18.465141
-genQ19 511432 257122 15.421183
-genQ26 182744 214079 19.373644
-genQ3 460230 511432 18.806072
-genQ15 214079 460230 12.678501
-genQ10 35856 257122 17.834743
-genQ28 281423 120306 17.63448
-genQ29 281423 257122 17.699568
-genQ21 511432 35856 14.623292
-genQ7 120306 257122 13.6840515
-genQ24 78297 281423 19.405273
-genQ23 78297 120306 17.731052
-genQ1 460230 316535 20.700878
-genQ4 257122 182744 7.277392
-genQ16 316535 120306 10.933559
-genQ6 257122 316535 1.4065981
-genQ11 35856 257122 11.795601
-genQ20 511432 316535 14.831335
-genQ2 460230 78297 5.676446
-genQ30 281423 78297 17.271576
+genQ28 585889 269846 9.46237564086914
+genQ7 375658 122794 15.471757888793945
+genQ12 93353 454072 19.189910888671875
+genQ3 454072 375658 18.5987491607666
+genQ29 585889 269846 11.639785766601562
+genQ2 454072 523755 21.046403884887695
+genQ11 93353 269846 17.41644287109375
+genQ22 122794 411906 17.80188751220703
+genQ24 122794 93353 16.20846939086914
+genQ17 269846 479420 14.049872398376465
+genQ20 523755 269846 9.419671058654785
+genQ25 479420 269846 15.41370677947998
+genQ19 523755 269846 2.478337287902832
+genQ16 269846 585889 16.2314453125
+genQ5 150252 523755 15.29780387878418
+genQ21 523755 375658 11.556497573852539
+genQ27 479420 523755 11.202600479125977
+genQ18 269846 122794 17.46649742126465
+genQ4 150252 122794 18.039897918701172
+genQ6 150252 523755 15.562084197998047
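The fourth column above shows the precision fix from pl.py in action: the old rows were written from numpy float32 scalars via str(), which prints only about 8 significant digits, while the new rows go through .tolist() and are printed as Python floats. A minimal sketch of the difference (the score value is illustrative):

import numpy as np

scores = np.asarray([19.373644, 0.0], dtype=np.float32)
margin = scores[:1] - scores[1:]   # float32 margin, as computed in pl.py

print(str(margin[0]))           # '19.373644'         -- old TSV format (~8 digits)
print(str(margin.tolist()[0]))  # '19.37364387512207' -- full double precision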
