Skip to content

Commit

Permalink
Add unit tests to datasets (#202)
Browse files Browse the repository at this point in the history
* Add tests to folktables datasets

* Update the location of Generic dataset tests

* Update BAF tests
  • Loading branch information
sgpjesus authored Aug 27, 2024
1 parent 1e9e723 commit a89641e
Show file tree
Hide file tree
Showing 5 changed files with 236 additions and 60 deletions.
17 changes: 10 additions & 7 deletions src/aequitas/flow/datasets/baf.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,7 @@ def __init__(
):
super().__init__()

self.label_column = (
LABEL_COLUMN if label_column is None else label_column
)
self.label_column = LABEL_COLUMN if label_column is None else label_column

if sensitive_column == "customer_age" or sensitive_column is None:
self.sensitive_column = SENSITIVE_COLUMN
Expand All @@ -107,11 +105,12 @@ def __init__(
else:
self.variant = variant
self.logger.debug(f"Variant: {self.variant}")
if url(path) or path.exists():
self.path = path

self.extension = extension
self.path = path
if url(path) or self._check_paths():
self._download = False
else:
self.path = path
self._download = True
if split_type not in SPLIT_TYPES:
raise ValueError(f"Invalid split_type value. Try one of: {SPLIT_TYPES}")
Expand All @@ -120,7 +119,6 @@ def __init__(
self.splits = splits
self._validate_splits()
self.logger.debug("Splits successfully validated.")
self.extension = extension
self.seed = seed
self.data: pd.DataFrame = None
self.include_month = include_month
Expand Down Expand Up @@ -203,6 +201,11 @@ def create_splits(self) -> None:
),
)

def _check_paths(self) -> bool:
"""Check if the data is already downloaded."""
check_path = Path(self.path) / f"{self.variant}.{self.extension}"
return check_path.exists()

def _download_data(self) -> None:
"""Obtains the data of the sample dataset from Aequitas repository."""
self.logger.info("Downloading sample data from repository.")
Expand Down
47 changes: 19 additions & 28 deletions src/aequitas/flow/datasets/folktables.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,37 +212,28 @@ def load_data(self):
if self._download:
self._download_data()

if self.split_type == "predefined":
path = []
for split in ["train", "validation", "test"]:
if isinstance(self.path, str):
path.append(self.path + f"/{self.variant}.{split}.{self.extension}")
else:
path.append(self.path / f"{self.variant}.{split}.{self.extension}")
else:
path = self.path / f"{self.variant}.{self.extension}"
path = []
for split in ["train", "validation", "test"]:
if isinstance(self.path, str):
path.append(self.path + f"/{self.variant}.{split}.{self.extension}")
else:
path.append(self.path / f"{self.variant}.{split}.{self.extension}")

if self.extension == "parquet":
if self.split_type == "predefined":
datasets = [pd.read_parquet(p) for p in path]
self._indexes = [d.index for d in datasets]
self.data = pd.concat(datasets)
else:
self.data = pd.read_parquet(path)
datasets = [pd.read_parquet(p) for p in path]
self._indexes = [d.index for d in datasets]
self.data = pd.concat(datasets)
else:
if self.split_type == "predefined":
train = pd.read_csv(path[0])
train_index = train.index[-1]
validation = pd.read_csv(path[1])
validation.set_index(validation.index + train_index + 1, inplace=True)
validation_index = validation.index[-1]
test = pd.read_csv(path[2])
test.set_index(test.index + validation_index + 1, inplace=True)
self._indexes = [train.index, validation.index, test.index]

self.data = pd.concat([train, validation, test])
else:
self.data = pd.read_csv(path)
train = pd.read_csv(path[0])
train_index = train.index[-1]
validation = pd.read_csv(path[1])
validation.set_index(validation.index + train_index + 1, inplace=True)
validation_index = validation.index[-1]
test = pd.read_csv(path[2])
test.set_index(test.index + validation_index + 1, inplace=True)
self._indexes = [train.index, validation.index, test.index]

self.data = pd.concat([train, validation, test])

for col in CATEGORICAL_COLUMNS[self.variant]:
self.data[col] = self.data[col].astype("category")
Expand Down
91 changes: 91 additions & 0 deletions tests/flow/datasets/test_baf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import unittest
from aequitas.flow.datasets.baf import BankAccountFraud, VARIANTS, DEFAULT_PATH


# TODO: These tests can be merged with the ones in test_folktables.py

class TestBankAccountFraudDataset(unittest.TestCase):
# Test loading related functionalities.
def test_load_variants(self):
for variant in VARIANTS:
dataset = BankAccountFraud(variant)
dataset.load_data()
self.assertTrue(len(dataset.data) > 0)
self.assertTrue("customer_age_bin" in dataset.data.columns)
self.assertTrue("fraud_bool" in dataset.data.columns)

def test_load_invalid_variant(self):
with self.assertRaises(ValueError):
BankAccountFraud("invalid_variant")

def test_download(self):
# Remove default folder of datasets even if not empty
if DEFAULT_PATH.exists():
for file in DEFAULT_PATH.iterdir():
file.unlink()
DEFAULT_PATH.rmdir()
for variant in VARIANTS:
dataset = BankAccountFraud(variant)
dataset.load_data()
self.assertTrue(dataset.path.exists())

# Test split related functionalities.
def test_invalid_split_type(self):
with self.assertRaises(ValueError):
BankAccountFraud(VARIANTS[0], split_type="invalid_split_type")

def test_default_split(self):
dataset = BankAccountFraud(VARIANTS[0])
dataset.load_data()
dataset.create_splits()
self.assertTrue(len(dataset.train) > 0)
self.assertTrue(len(dataset.test) > 0)
self.assertTrue(len(dataset.validation) > 0)

def test_random_split(self):
dataset = BankAccountFraud(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2, "test": 0.2},
)
dataset.load_data()
dataset.create_splits()
self.assertTrue(len(dataset.train) > 0)
self.assertTrue(len(dataset.test) > 0)
self.assertTrue(len(dataset.validation) > 0)

def test_invalid_random_split_missing_key(self):
with self.assertRaises(ValueError):
BankAccountFraud(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2},
)

def test_invalid_random_split_more_than_1(self):
with self.assertRaises(ValueError):
BankAccountFraud(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2, "test": 0.3},
)

# Test sensitive column related issues.
def test_housing_sensitive_column(self):
dataset = BankAccountFraud(VARIANTS[0], sensitive_column="housing_status")
dataset.load_data()
self.assertTrue("housing_status" in dataset.data.columns)
self.assertTrue(dataset.data.s.name == "housing_status")

def test_invalid_sensitive_column(self):
with self.assertRaises(ValueError):
BankAccountFraud(VARIANTS[0], sensitive_column="invalid_column")

def test_invalid_sensitive_column_type(self):
with self.assertRaises(ValueError):
BankAccountFraud(VARIANTS[0], sensitive_column="name_email_similarity")
# Numerical column


if __name__ == "__main__":
unittest.main()
90 changes: 90 additions & 0 deletions tests/flow/datasets/test_folktables.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import unittest
from aequitas.flow.datasets.folktables import FolkTables, VARIANTS, DEFAULT_PATH


# TODO: Test CSV related functionalities.

class TestFolktablesDataset(unittest.TestCase):
# Test loading related functionalities.
def test_load_variants(self):
for variant in VARIANTS:
dataset = FolkTables(variant)
dataset.load_data()
self.assertTrue(len(dataset.data) > 0)
self.assertTrue("RAC1P" in dataset.data.columns)
self.assertTrue("AGEP" in dataset.data.columns)

def test_load_invalid_variant(self):
with self.assertRaises(ValueError):
FolkTables("invalid_variant")

def test_download(self):
# Remove default folder of datasets even if not empty
if DEFAULT_PATH.exists():
for file in DEFAULT_PATH.iterdir():
file.unlink()
DEFAULT_PATH.rmdir()
for variant in VARIANTS:
dataset = FolkTables(variant)
dataset.load_data()
self.assertTrue(dataset.path.exists())

# Test split related functionalities.
def test_invalid_split_type(self):
with self.assertRaises(ValueError):
FolkTables(VARIANTS[0], split_type="invalid_split_type")

def test_default_split(self):
dataset = FolkTables(VARIANTS[0])
dataset.load_data()
dataset.create_splits()
self.assertTrue(len(dataset.train) > 0)
self.assertTrue(len(dataset.test) > 0)
self.assertTrue(len(dataset.validation) > 0)

def test_random_split(self):
dataset = FolkTables(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2, "test": 0.2},
)
dataset.load_data()
dataset.create_splits()
self.assertTrue(len(dataset.train) > 0)
self.assertTrue(len(dataset.test) > 0)
self.assertTrue(len(dataset.validation) > 0)

def test_invalid_random_split_missing_key(self):
with self.assertRaises(ValueError):
FolkTables(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2},
)

def test_invalid_random_split_more_than_1(self):
with self.assertRaises(ValueError):
FolkTables(
VARIANTS[0],
split_type="random",
splits={"train": 0.6, "validation": 0.2, "test": 0.3},
)

# Test sensitive column related issues.
def test_age_sensitive_column(self):
dataset = FolkTables(VARIANTS[0], sensitive_column="AGEP")
dataset.load_data()
self.assertTrue("AGEP" in dataset.data.columns)
self.assertTrue("AGEP_bin" in dataset.data.columns)

def test_invalid_sensitive_column(self):
with self.assertRaises(ValueError):
FolkTables(VARIANTS[0], sensitive_column="invalid_column")

def test_invalid_sensitive_column_type(self):
with self.assertRaises(ValueError):
FolkTables(VARIANTS[0], sensitive_column="SCHL") # Numerical column


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit a89641e

Please sign in to comment.