Skip to content

Commit

Permalink
Implement Chi Square engine and add different types of contingency ta…
Browse files Browse the repository at this point in the history
…bles.
  • Loading branch information
RobbieMcKinstry committed Oct 28, 2024
1 parent 4e3b550 commit b9a5431
Show file tree
Hide file tree
Showing 9 changed files with 316 additions and 81 deletions.
45 changes: 45 additions & 0 deletions src/adapters/engines/chi.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
use crate::{
metrics::ResponseStatusCode,
stats::{EmpiricalTable, ExpectationTable, Group, Observation},
};

use super::DecisionEngine;

/// The [ChiSquareEngine] uses the Chi Square statistical
/// significance test to determine whether the canary should be promoted or not.
#[derive(Default)]
pub struct ChiSquareEngine {
control_data: ExpectationTable<ResponseStatusCode>,
experimental_data: EmpiricalTable<ResponseStatusCode>,
}

impl DecisionEngine<ResponseStatusCode> for ChiSquareEngine {
// TODO: From writing this method, it's apparent there should be a Vec implementation
// that adds Vec::len() to the total and concats the vectors together, because
// otherwise we're wasting a ton of cycles just incrementing counters.
fn add_observation(&mut self, observation: Observation<ResponseStatusCode>) {
match observation.group {
Group::Control => {
// • Increment the number of observations for this category.
self.control_data.increment(observation.outcome);
}
Group::Experimental => {
// • Increment the number of observations in the canary contingency table.
self.experimental_data.increment(observation.outcome);
// • Then, let the control contingency table know that there was
// another experimental observation.
self.control_data.increment_experimental_total();
}
}
}

fn compute(&mut self) -> Option<super::Action> {
todo!()
}
}

impl ChiSquareEngine {
pub fn new() -> Self {
Self::default()
}
}
2 changes: 2 additions & 0 deletions src/adapters/engines/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::stats::{EnumerableCategory, Observation};
use std::hash::Hash;

pub use action::Action;
pub use chi::ChiSquareEngine;

/// Helper trait, since these requirements are often used by
/// our implementation of `ContingencyTables`.
Expand All @@ -23,6 +24,7 @@ pub trait DecisionEngine<T: HashableCategory> {
}

mod action;
mod chi;
mod controller;

/// The AlwaysPromote decision engine will always return the Promote
Expand Down
87 changes: 9 additions & 78 deletions src/stats/chi.rs
Original file line number Diff line number Diff line change
@@ -1,28 +1,10 @@
use std::collections::{HashMap, HashSet};
use std::collections::HashSet;
use std::hash::Hash;
use std::num::NonZeroU64;

use crate::stats::ContingencyTable;
use statrs::distribution::{ChiSquared, ContinuousCDF};

/// A ContingencyTable expresses the frequency with which a group was observed.
/// Usually, it tracks the number of observations in ecah group, but when the
/// number is already known (i.e. its fixed, like a fair dice or coin), it can
/// expose just the frequencies for each group.
pub trait ContingencyTable<Group> {
/// return the number of observations of the in the provided group.
fn group_count(&self, cat: &Group) -> u64;

/// Return the set of groups that serve as columns of the contingency table.
fn groups(&self) -> Box<dyn Iterator<Item = Group>>;

// returns the total number of observations made. This should be the sum
// of the group count for every group.
fn total_count(&self) -> u64 {
self.groups()
.fold(0, |sum, group| sum + self.group_count(&group))
}
}

/// returns the number of degrees of freedom for this table.
/// This is typically the number of groups minus one.
/// # Panics
Expand Down Expand Up @@ -66,60 +48,6 @@ pub trait EnumerableCategory {
fn groups() -> Box<dyn Iterator<Item = Self>>;
}

/// A [FixedContingencyTable] is used to model scenarios where the
/// frequencies are fixed (i.e. known ahead of time), like fair dice.
/// It is mostly used for testing. The category must be hashable
/// because a hashmap is used internally to store the frequencies.
/// If you'd like us to add a B-Tree based alternative, please open an issue.
pub struct FixedContingencyTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
counts: HashMap<C, u64>,
}

impl<C> FixedContingencyTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
/// Construct a new, empty contingency table. All frequencies are
/// initialized to zero.
pub fn new() -> Self {
let mut counts = HashMap::new();
for group in C::groups() {
counts.entry(group).or_insert(0);
}

Self { counts }
}

/// Sets the expected count of the category to the value provided.
pub fn set_group_count(&mut self, cat: C, count: u64) {
self.counts.insert(cat, count);
}

/// Returns the number of observations that were classified as
/// having this group/category.
pub fn group_count(&self, cat: &C) -> u64 {
self.counts[cat]
}
}

impl<C> ContingencyTable<C> for FixedContingencyTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
fn group_count(&self, cat: &C) -> u64 {
// delegate to the method on the base class.
Self::group_count(self, cat)
}

fn groups(&self) -> Box<dyn Iterator<Item = C>> {
// Delegate to the fixed list provided by the EnumerableCategory.
C::groups()
}
}

/// Alpha represents the alpha cutoff, expressed as a floating point from [0, 1] inclusive.
/// For example, 0.95 is the standard 5% confidency interval.
pub fn chi_square_test<Cat>(
Expand Down Expand Up @@ -175,7 +103,10 @@ mod tests {

use std::{collections::HashSet, num::NonZeroU64};

use crate::stats::chi::{degrees_of_freedom, p_value, FixedContingencyTable};
use crate::stats::{
chi::{degrees_of_freedom, p_value},
FixedTable,
};

use super::{test_statistic, ContingencyTable, EnumerableCategory};
use pretty_assertions::assert_eq;
Expand Down Expand Up @@ -203,7 +134,7 @@ mod tests {
/// can have its frequencies set and accessed.
#[test]
fn enumerable_table() {
let mut table = FixedContingencyTable::new();
let mut table = FixedTable::new();
let groups = [(true, 30u64), (false, 70u64)];
// Put the values into the table.
for (group, freq) in groups {
Expand All @@ -224,10 +155,10 @@ mod tests {
/// Let True represent Heads and False represent Tails.
#[test]
fn calc_test_statistic() {
let mut control_group = FixedContingencyTable::new();
let mut control_group = FixedTable::new();
control_group.set_group_count(true, 25);
control_group.set_group_count(false, 25);
let mut experimental_group = FixedContingencyTable::new();
let mut experimental_group = FixedTable::new();
experimental_group.set_group_count(true, 21);
experimental_group.set_group_count(false, 29);
assert_eq!(
Expand Down
12 changes: 9 additions & 3 deletions src/stats/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,13 @@ use std::collections::HashMap;
pub use chi::EnumerableCategory;
pub use group::Group;
pub use observation::Observation;
pub use tables::{ContingencyTable, EmpiricalTable, ExpectationTable, FixedTable};

use crate::metrics::ResponseStatusCode;

// TODO: Before long, we can delete this file since this is an
// old and mostly incorrect implement of X2.

/// The alpha cutoff is the amount of confidence must have in the result
/// to feel comfortable that the result is not due to chance, but instead
/// do to the independent variable. The valu is expressed as a confidence
Expand All @@ -17,8 +21,8 @@ const DEFAULT_ALPHA_CUTOFF: f64 = 0.05;
/// The [ChiSquareEngine] calculates the Chi Square test statistic
/// based on the data stored in its contingency tables.
pub struct ChiSquareEngine {
control: ContingencyTable,
experimental: ContingencyTable,
control: Table,
experimental: Table,
total_control_count: usize,
total_experimental_count: usize,
alpha_cutoff: f64,
Expand Down Expand Up @@ -93,11 +97,13 @@ impl ChiSquareEngine {
}

/// This type maps the dependent variable to its count.
pub type ContingencyTable = HashMap<ResponseStatusCode, usize>;
type Table = HashMap<ResponseStatusCode, usize>;

/// contains the engine to calculate the chi square test statistic.
mod chi;
/// `group` defines the two groups.
mod group;
/// An observation represents a group and the observed category.
mod observation;
/// Different kinds of contingency tables.
mod tables;
18 changes: 18 additions & 0 deletions src/stats/tables/contingency.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
/// A ContingencyTable expresses the frequency with which a group was observed.
/// Usually, it tracks the number of observations in ecah group, but when the
/// number is already known (i.e. its fixed, like a fair dice or coin), it can
/// expose just the frequencies for each group.
pub trait ContingencyTable<Group> {
/// return the number of observations of the in the provided group.
fn group_count(&self, cat: &Group) -> u64;

/// Return the set of groups that serve as columns of the contingency table.
fn groups(&self) -> Box<dyn Iterator<Item = Group>>;

// returns the total number of observations made. This should be the sum
// of the group count for every group.
fn total_count(&self) -> u64 {
self.groups()
.fold(0, |sum, group| sum + self.group_count(&group))
}
}
75 changes: 75 additions & 0 deletions src/stats/tables/empirical.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
use super::ContingencyTable;
use crate::stats::EnumerableCategory;
use std::{collections::HashMap, hash::Hash};

/// An [EmpiricalTable] is used to track observed data. It keeps
/// a talley of each observed category. When queried, it uses
/// the empirical values to emit an observation count.
/// This is in contrast to a ExpectationTable, which also keeps a
/// talley of observations made, but uses the count of observations
/// from an EmpiricalTable to determine the expected ratios.
///
/// The category must be hashable
/// because a hashmap is used internally to store the frequencies.
/// If you'd like us to add a B-Tree based alternative, please open an issue.
pub struct EmpiricalTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
counts: HashMap<C, u64>,
}

impl<C> Default for EmpiricalTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
fn default() -> Self {
Self::new()
}
}

impl<C> EmpiricalTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
/// Construct a new, empty contingency table. All frequencies are
/// initialized to zero.
pub fn new() -> Self {
let mut counts = HashMap::new();
for group in C::groups() {
counts.entry(group).or_insert(0);
}

Self { counts }
}

pub fn increment(&mut self, cat: C) {
self.counts.entry(cat).and_modify(|c| *c += 1);
}

/// Sets the expected count of the category to the value provided.
pub fn set_group_count(&mut self, cat: C, count: u64) {
self.counts.insert(cat, count);
}

/// Returns the number of observations that were classified as
/// having this group/category.
pub fn group_count(&self, cat: &C) -> u64 {
self.counts[cat]
}
}

impl<C> ContingencyTable<C> for EmpiricalTable<C>
where
C: EnumerableCategory + Hash + Eq,
{
fn group_count(&self, cat: &C) -> u64 {
// delegate to the method on the base class.
Self::group_count(self, cat)
}

fn groups(&self) -> Box<dyn Iterator<Item = C>> {
// Delegate to the fixed list provided by the EnumerableCategory.
C::groups()
}
}
Loading

0 comments on commit b9a5431

Please sign in to comment.