Skip to content
This repository has been archived by the owner on Dec 16, 2022. It is now read-only.

Commit

Permalink
Contextualized bias mitigation (#5176)
Browse files Browse the repository at this point in the history
* added linear and hard debiasers

* worked on documentation

* committing changes before branch switch

* committing changes before switching branch

* finished bias direction, linear and hard debiasers, need to write tests

* finished bias direction test

* Commiting changes before switching branch

* finished hard and linear debiasers

* finished OSCaR

* bias mitigators tests and bias metrics remaining

* added bias mitigator tests

* added bias mitigator tests

* finished tests for bias mitigation methods

* fixed gpu issues

* fixed gpu issues

* fixed gpu issues

* resolve issue with count_nonzero not being differentiable

* added more references

* fairness during finetuning

* finished bias mitigator wrapper

* added reference

* updated CHANGELOG and fixed minor docs issues

* move id tensors to embedding device

* fixed to use predetermined bias direction

* fixed minor doc errors

* snli reader registration issue

* fixed _pretrained from params issue

* fixed device issues

* evaluate bias mitigation initial commit

* finished evaluate bias mitigation

* handles multiline prediction files

* fixed minor bugs

* fixed minor bugs

* improved prediction diff JSON format

* forgot to resolve a conflict

* Refactored evaluate bias mitigation to use NLI metric

* Added SNLIPredictionsDiff class

* ensured dataloader is same for bias mitigated and baseline models

* finished evaluate bias mitigation

* Update CHANGELOG.md

* Replaced local data files with github raw content links

* Update allennlp/fairness/bias_mitigator_applicator.py

Co-authored-by: Pete <petew@allenai.org>

* deleted evaluate_bias_mitigation from git tracking

* removed evaluate-bias-mitigation instances from rest of repo

* addressed Akshita's comments

* moved bias mitigator applicator test to allennlp-models

* removed unnecessary files

Co-authored-by: Arjun Subramonian <arjuns@Arjuns-MacBook-Pro.local>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-106.us-west-2.compute.internal>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-0-108.us-west-2.compute.internal>
Co-authored-by: Arjun Subramonian <arjuns@ip-192-168-1-108.us-west-2.compute.internal>
Co-authored-by: Akshita Bhagia <akshita23bhagia@gmail.com>
Co-authored-by: Pete <petew@allenai.org>
  • Loading branch information
7 people authored Jun 2, 2021
1 parent aa52a9a commit b92fd9a
Show file tree
Hide file tree
Showing 12 changed files with 2,557 additions and 4 deletions.
7 changes: 4 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
### Added

- Added `TaskSuite` base class and command line functionality for running [`checklist`](https://github.com/marcotcr/checklist) test suites, along with implementations for `SentimentAnalysisSuite`, `QuestionAnsweringSuite`, and `TextualEntailmentSuite`. These can be found in the `allennlp.confidence_checks.task_checklists` module.
- Added `BiasMitigatorApplicator`, which wraps any Model and mitigates biases by finetuning
on a downstream task.
- Added `allennlp diff` command to compute a diff on model checkpoints, analogous to what `git diff` does on two files.
- Meta data defined by the class `allennlp.common.meta.Meta` is now saved in the serialization directory and archive file
when training models from the command line. This is also now part of the `Archive` named tuple that's returned from `load_archive()`.
Expand Down Expand Up @@ -54,7 +56,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `wandb` callback to work in distributed training.
- Fixed `tqdm` logging into multiple files with `allennlp-optuna`.


## [v2.4.0](https://github.com/allenai/allennlp/releases/tag/v2.4.0) - 2021-04-22

### Added
Expand All @@ -80,8 +81,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Add new dimension to the `interpret` module: influence functions via the `InfluenceInterpreter` base class, along with a concrete implementation: `SimpleInfluence`.
- Added a `quiet` parameter to the `MultiProcessDataLoading` that disables `Tqdm` progress bars.
- The test for distributed metrics now takes a parameter specifying how often you want to run it.
- Created the fairness module and added four fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
- Added three bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
- Created the fairness module and added three fairness metrics: `Independence`, `Separation`, and `Sufficiency`.
- Added four bias metrics to the fairness module: `WordEmbeddingAssociationTest`, `EmbeddingCoherenceTest`, `NaturalLanguageInference`, and `AssociationWithoutGroundTruth`.
- Added four bias direction methods (`PCABiasDirection`, `PairedPCABiasDirection`, `TwoMeansBiasDirection`, `ClassificationNormalBiasDirection`) and four bias mitigation methods (`LinearBiasMitigator`, `HardBiasMitigator`, `INLPBiasMitigator`, `OSCaRBiasMitigator`).

### Changed
Expand Down
17 changes: 16 additions & 1 deletion allennlp/fairness/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
1. measure the fairness of models according to multiple definitions of fairness
2. measure bias amplification
3. debias embeddings during training time and post-processing
3. mitigate bias in static and contextualized embeddings during training time and
post-processing
"""

from allennlp.fairness.fairness_metrics import Independence, Separation, Sufficiency
Expand All @@ -25,3 +26,17 @@
INLPBiasMitigator,
OSCaRBiasMitigator,
)
from allennlp.fairness.bias_utils import load_words, load_word_pairs
from allennlp.fairness.bias_mitigator_applicator import BiasMitigatorApplicator
from allennlp.fairness.bias_mitigator_wrappers import (
HardBiasMitigatorWrapper,
LinearBiasMitigatorWrapper,
INLPBiasMitigatorWrapper,
OSCaRBiasMitigatorWrapper,
)
from allennlp.fairness.bias_direction_wrappers import (
PCABiasDirectionWrapper,
PairedPCABiasDirectionWrapper,
TwoMeansBiasDirectionWrapper,
ClassificationNormalBiasDirectionWrapper,
)
269 changes: 269 additions & 0 deletions allennlp/fairness/bias_direction_wrappers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
import torch
from typing import Union, Optional
from os import PathLike

from allennlp.fairness.bias_direction import (
BiasDirection,
PCABiasDirection,
PairedPCABiasDirection,
TwoMeansBiasDirection,
ClassificationNormalBiasDirection,
)
from allennlp.fairness.bias_utils import load_word_pairs, load_words

from allennlp.common import Registrable
from allennlp.data.tokenizers.tokenizer import Tokenizer
from allennlp.data import Vocabulary


class BiasDirectionWrapper(Registrable):
"""
Parent class for bias direction wrappers.
"""

def __init__(self):
self.direction: BiasDirection = None
self.noise: float = None

def __call__(self, module):
raise NotImplementedError

def train(self, mode: bool = True):
"""
# Parameters
mode : `bool`, optional (default=`True`)
Sets `requires_grad` to value of `mode` for bias direction.
"""
self.direction.requires_grad = mode

def add_noise(self, t: torch.Tensor):
"""
# Parameters
t : `torch.Tensor`
Tensor to which to add small amount of Gaussian noise.
"""
return t + self.noise * torch.randn(t.size(), device=t.device)


@BiasDirectionWrapper.register("pca")
class PCABiasDirectionWrapper(BiasDirectionWrapper):
"""
# Parameters
seed_words_file : `Union[PathLike, str]`
Path of file containing seed words.
tokenizer : `Tokenizer`
Tokenizer used to tokenize seed words.
direction_vocab : `Vocabulary`, optional (default=`None`)
Vocabulary of tokenizer. If `None`, assumes tokenizer is of
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
namespace : `str`, optional (default=`"tokens"`)
Namespace of direction_vocab to use when tokenizing.
Disregarded when direction_vocab is `None`.
requires_grad : `bool`, optional (default=`False`)
Option to enable gradient calculation for bias direction.
noise : `float`, optional (default=`1e-10`)
To avoid numerical instability if embeddings are initialized uniformly.
"""

def __init__(
self,
seed_words_file: Union[PathLike, str],
tokenizer: Tokenizer,
direction_vocab: Optional[Vocabulary] = None,
namespace: str = "tokens",
requires_grad: bool = False,
noise: float = 1e-10,
):
self.ids = load_words(seed_words_file, tokenizer, direction_vocab, namespace)
self.direction = PCABiasDirection(requires_grad=requires_grad)
self.noise = noise

def __call__(self, module):
# embed subword token IDs and mean pool to get
# embedding of original word
ids_embeddings = []
for i in self.ids:
i = i.to(module.weight.device)
ids_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids_embeddings = torch.cat(ids_embeddings)

# adding trivial amount of noise
# to eliminate linear dependence amongst all embeddings
# when training first starts
ids_embeddings = self.add_noise(ids_embeddings)

return self.direction(ids_embeddings)


@BiasDirectionWrapper.register("paired_pca")
class PairedPCABiasDirectionWrapper(BiasDirectionWrapper):
"""
# Parameters
seed_word_pairs_file : `Union[PathLike, str]`
Path of file containing seed word pairs.
tokenizer : `Tokenizer`
Tokenizer used to tokenize seed words.
direction_vocab : `Vocabulary`, optional (default=`None`)
Vocabulary of tokenizer. If `None`, assumes tokenizer is of
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
namespace : `str`, optional (default=`"tokens"`)
Namespace of direction_vocab to use when tokenizing.
Disregarded when direction_vocab is `None`.
requires_grad : `bool`, optional (default=`False`)
Option to enable gradient calculation for bias direction.
noise : `float`, optional (default=`1e-10`)
To avoid numerical instability if embeddings are initialized uniformly.
"""

def __init__(
self,
seed_word_pairs_file: Union[PathLike, str],
tokenizer: Tokenizer,
direction_vocab: Optional[Vocabulary] = None,
namespace: str = "tokens",
requires_grad: bool = False,
noise: float = 1e-10,
):
self.ids1, self.ids2 = load_word_pairs(
seed_word_pairs_file, tokenizer, direction_vocab, namespace
)
self.direction = PairedPCABiasDirection(requires_grad=requires_grad)
self.noise = noise

def __call__(self, module):
# embed subword token IDs and mean pool to get
# embedding of original word
ids1_embeddings = []
for i in self.ids1:
i = i.to(module.weight.device)
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids2_embeddings = []
for i in self.ids2:
i = i.to(module.weight.device)
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids1_embeddings = torch.cat(ids1_embeddings)
ids2_embeddings = torch.cat(ids2_embeddings)

ids1_embeddings = self.add_noise(ids1_embeddings)
ids2_embeddings = self.add_noise(ids2_embeddings)

return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("two_means")
class TwoMeansBiasDirectionWrapper(BiasDirectionWrapper):
"""
# Parameters
seed_word_pairs_file : `Union[PathLike, str]`
Path of file containing seed word pairs.
tokenizer : `Tokenizer`
Tokenizer used to tokenize seed words.
direction_vocab : `Vocabulary`, optional (default=`None`)
Vocabulary of tokenizer. If `None`, assumes tokenizer is of
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
namespace : `str`, optional (default=`"tokens"`)
Namespace of direction_vocab to use when tokenizing.
Disregarded when direction_vocab is `None`.
requires_grad : `bool`, optional (default=`False`)
Option to enable gradient calculation for bias direction.
noise : `float`, optional (default=`1e-10`)
To avoid numerical instability if embeddings are initialized uniformly.
"""

def __init__(
self,
seed_word_pairs_file: Union[PathLike, str],
tokenizer: Tokenizer,
direction_vocab: Optional[Vocabulary] = None,
namespace: str = "tokens",
requires_grad: bool = False,
noise: float = 1e-10,
):
self.ids1, self.ids2 = load_word_pairs(
seed_word_pairs_file, tokenizer, direction_vocab, namespace
)
self.direction = TwoMeansBiasDirection(requires_grad=requires_grad)
self.noise = noise

def __call__(self, module):
# embed subword token IDs and mean pool to get
# embedding of original word
ids1_embeddings = []
for i in self.ids1:
i = i.to(module.weight.device)
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids2_embeddings = []
for i in self.ids2:
i = i.to(module.weight.device)
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids1_embeddings = torch.cat(ids1_embeddings)
ids2_embeddings = torch.cat(ids2_embeddings)

ids1_embeddings = self.add_noise(ids1_embeddings)
ids2_embeddings = self.add_noise(ids2_embeddings)

return self.direction(ids1_embeddings, ids2_embeddings)


@BiasDirectionWrapper.register("classification_normal")
class ClassificationNormalBiasDirectionWrapper(BiasDirectionWrapper):
"""
# Parameters
seed_word_pairs_file : `Union[PathLike, str]`
Path of file containing seed word pairs.
tokenizer : `Tokenizer`
Tokenizer used to tokenize seed words.
direction_vocab : `Vocabulary`, optional (default=`None`)
Vocabulary of tokenizer. If `None`, assumes tokenizer is of
type `PreTrainedTokenizer` and uses tokenizer's `vocab` attribute.
namespace : `str`, optional (default=`"tokens"`)
Namespace of direction_vocab to use when tokenizing.
Disregarded when direction_vocab is `None`.
noise : `float`, optional (default=`1e-10`)
To avoid numerical instability if embeddings are initialized uniformly.
"""

def __init__(
self,
seed_word_pairs_file: Union[PathLike, str],
tokenizer: Tokenizer,
direction_vocab: Optional[Vocabulary] = None,
namespace: str = "tokens",
noise: float = 1e-10,
):
self.ids1, self.ids2 = load_word_pairs(
seed_word_pairs_file, tokenizer, direction_vocab, namespace
)
self.direction = ClassificationNormalBiasDirection()
self.noise = noise

def __call__(self, module):
# embed subword token IDs and mean pool to get
# embedding of original word
ids1_embeddings = []
for i in self.ids1:
i = i.to(module.weight.device)
ids1_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids2_embeddings = []
for i in self.ids2:
i = i.to(module.weight.device)
ids2_embeddings.append(torch.mean(module.forward(i), dim=0, keepdim=True))
ids1_embeddings = torch.cat(ids1_embeddings)
ids2_embeddings = torch.cat(ids2_embeddings)

ids1_embeddings = self.add_noise(ids1_embeddings)
ids2_embeddings = self.add_noise(ids2_embeddings)

return self.direction(ids1_embeddings, ids2_embeddings)
2 changes: 2 additions & 0 deletions allennlp/fairness/bias_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,8 @@ class NaturalLanguageInference(Metric):
3. Threshold:tau (T:tau): A parameterized measure that reports the fraction
of examples whose probability of neutral is above tau.
# Parameters
neutral_label : `int`, optional (default=`2`)
The discrete integer label corresponding to a neutral entailment prediction.
taus : `List[float]`, optional (default=`[0.5, 0.7]`)
Expand Down
Loading

0 comments on commit b92fd9a

Please sign in to comment.