Multiple Positives and Negatives Ranking Loss

On AllNLI, training with MPNRL has higher training throughput and better memory utilization than training with MNRL. The two are on par in terms of task performance, but I need to run more experiments.

Why

No-duplicates sampling causes batch sizes to decay when the number of positives per anchor is highly skewed.
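The decay can be reproduced with a toy simulation. The sketch below is a simplified stand-in for a no-duplicates sampler, not the sentence-transformers implementation: it only de-duplicates anchors, it does not shuffle, and the function name `no_duplicates_batches` is made up for illustration.

```python
def no_duplicates_batches(anchors, batch_size):
    """Greedy batching in which no anchor may repeat within a batch.

    Simplified illustration: rows skipped for one batch are retried in the next.
    """
    remaining = list(range(len(anchors)))
    batches = []
    while remaining:
        seen, batch, leftover = set(), [], []
        for idx in remaining:
            if len(batch) < batch_size and anchors[idx] not in seen:
                seen.add(anchors[idx])
                batch.append(idx)
            else:
                # Either the batch is full or this anchor is already in it.
                leftover.append(idx)
        batches.append(batch)
        remaining = leftover
    return batches


# Skewed toy data: one anchor has 6 positives, six anchors have 1 positive each.
anchors = ["hot"] * 6 + ["a", "b", "c", "d", "e", "f"]
batches = no_duplicates_batches(anchors, batch_size=4)
sizes = [len(batch) for batch in batches]
print(sizes)  # [4, 4, 1, 1, 1, 1]
```

Once the rare anchors are used up, only rows of the frequent anchor remain, and each of them needs a batch of its own.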

Plot for AllNLI, no-duplicates sampling

Reproduce by running:

python compare_dataloaders.py \
    --dataset_name "sentence-transformers/all-nli" \
    --dataset_config "triplet" \
    --dataset_split "train" \
    --batch_size 128 \
    --dataset_size_train 10000 \
    --seed 42

Here are CUDA memory snapshots over time for MNRL on AllNLI (first 10k triplets, input batch size of 200):

The drops in memory are caused by drops in the batch size. There is a long tail of under-utilization. Peak usage is determined by the first few batches, which account for only a small portion of the training time.

It's simpler to use a loss which seamlessly handles multiple positives. As a result, training throughput is higher, and GPU utilization (in terms of % memory and % time) is more stable. Data loading itself is also 15x faster, as there's no de-duplication.
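As a rough illustration of what "handling multiple positives" can mean, here is one common formulation in plain Python: a cross-entropy term per (anchor, positive) pair, where the anchor's other positives are excluded from the denominator. This is an assumption made for illustration and not necessarily how the loss in this repo is computed; the function name and its inputs are hypothetical.

```python
import math


def multi_positive_ranking_loss(scores, positive_mask):
    """Mean cross-entropy over every (anchor, positive) pair.

    scores[i][j]: similarity between anchor i and candidate j.
    positive_mask[i][j]: True if candidate j is a positive for anchor i.
    An anchor's other positives are never treated as negatives.
    """
    losses = []
    for row, mask in zip(scores, positive_mask):
        negatives = [s for s, is_pos in zip(row, mask) if not is_pos]
        for s_pos, is_pos in zip(row, mask):
            if not is_pos:
                continue
            # Log-sum-exp over this positive and the anchor's negatives,
            # shifted by the max for numerical stability.
            logits = [s_pos] + negatives
            top = max(logits)
            log_norm = top + math.log(sum(math.exp(x - top) for x in logits))
            losses.append(log_norm - s_pos)
    return sum(losses) / len(losses)


scores = [
    [5.0, 4.0, 0.0, 0.0],  # anchor 0: candidates 0 and 1 are both positives
    [0.0, 0.0, 5.0, 0.0],  # anchor 1: candidate 2 is its only positive
]
positive_mask = [
    [True, True, False, False],
    [False, False, True, False],
]
loss = multi_positive_ranking_loss(scores, positive_mask)
```

With exactly one positive per anchor, this reduces to the usual in-batch softmax cross-entropy, so duplicated anchors no longer have to be filtered out of a batch.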

Plot for AllNLI, plain sampling with grouping in the collator

Here are CUDA memory snapshots across time for MPNRL:

Here's a comparison of time-based GPU utilization:

[Figure: time-based GPU utilization, MNRL vs. MPNRL]

The small experiment in ./demos/train_allnli.ipynb demonstrates that task/statistical performance is on par with MNRL.

In an experiment on the first 100k triplets in AllNLI with an input batch size of 200, MNRL took ~33 minutes while MPNRL took ~20 minutes. Statistical performance was similar.

Setup

python -m pip install git+https://github.com/kddubey/mpnrl.git

To run ./run.py, clone the repo and then:

python -m pip install ".[demos]"

NOTE: this isn't meant to be a stable Python package. There are many TODOs.

Usage

Make sure not to use the no-duplicates sampler with MPNRL.

from sentence_transformers import (
    SentenceTransformer,
    SentenceTransformerTrainer,
    SentenceTransformerTrainingArguments,
)
from sentence_transformers.training_args import BatchSamplers

import mpnrl

model = SentenceTransformer("your-model")

train_dataset = ...
# records of {"anchor": ..., "positive": ..., "negative"(s): ...}

trainer = SentenceTransformerTrainer(
    model=model,
    train_dataset=train_dataset,
    args=SentenceTransformerTrainingArguments(
        ...,  # placeholder for your other training arguments
        # Plain batch sampling; do not use BatchSamplers.NO_DUPLICATES here.
        batch_sampler=BatchSamplers.BATCH_SAMPLER,
    ),
    loss=mpnrl.losses.MultiplePositivesNegativesRankingLoss(model),
    # The collator groups records within each batch before tokenizing.
    data_collator=mpnrl.data_collator.GroupingDataCollator(
        train_dataset, tokenize_fn=model.tokenize
    ),
)

trainer.train()

There's a small demo in ./demos/train_allnli.ipynb.
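To make the idea of grouping concrete, here is a minimal sketch of how a plainly sampled batch with repeated anchors could be collapsed into unique anchors, unique candidates, and a mask of positives. It is an assumption about what grouping involves, not the actual `GroupingDataCollator`: `group_by_anchor` is a hypothetical helper, and negatives and tokenization are left out for brevity.

```python
def group_by_anchor(records):
    """Collapse records into unique anchors, unique candidates, and a positives mask."""
    anchors, candidates = [], []
    anchor_index, candidate_index = {}, {}
    pairs = set()
    for record in records:
        # setdefault returns the existing index, or assigns the next free one.
        i = anchor_index.setdefault(record["anchor"], len(anchor_index))
        if i == len(anchors):
            anchors.append(record["anchor"])
        j = candidate_index.setdefault(record["positive"], len(candidate_index))
        if j == len(candidates):
            candidates.append(record["positive"])
        pairs.add((i, j))
    # mask[i][j] is True when candidate j is a positive for anchor i.
    mask = [
        [(i, j) in pairs for j in range(len(candidates))]
        for i in range(len(anchors))
    ]
    return anchors, candidates, mask


records = [
    {"anchor": "A man eats.", "positive": "Someone is eating."},
    {"anchor": "A man eats.", "positive": "A person has a meal."},
    {"anchor": "A dog runs.", "positive": "An animal is moving."},
]
anchors, candidates, mask = group_by_anchor(records)
print(anchors)  # ['A man eats.', 'A dog runs.']
print(mask)     # [[True, True, False], [False, False, True]]
```

Because each unique anchor is kept once, it only needs to be embedded once per batch, and the mask tells the loss which candidates count as its positives.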

TODOs

  • mpnrl.collator TODOs.
  • mpnrl.loss TODOs.
  • Measure how long it takes for MNRL vs MPNRL to get to a good model (Pearson/Spearman correlation on validation data).
  • Repeat for a few datasets and study how the level of data duplication affects these outcomes.