Skip to content

Commit

Permalink
updated the req files. added more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mauvais2 committed Feb 13, 2025
1 parent 5b284c0 commit 30c4064
Show file tree
Hide file tree
Showing 6 changed files with 119 additions and 21 deletions.
111 changes: 109 additions & 2 deletions atomsci/ddm/test/unit/test_compare_splits_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd

from matplotlib import pyplot as plt
from atomsci.ddm.utils.compare_splits_plots import SplitStats
from atomsci.ddm.utils.compare_splits_plots import SplitStats, split, parse_args
from matplotcheck.base import PlotTester

# --- Fixtures ---
Expand Down Expand Up @@ -157,4 +157,111 @@ def test_dist_hist_plot_train_v_valid(mock_data):
assert ax is not None
pt.assert_num_bins(1) # Check if histogram has been plotted
pt.assert_axis_label_contains("x", "Tanimoto distance")
pt.assert_axis_label_contains("y", "Proportion of compounds")
pt.assert_axis_label_contains("y", "Proportion of compounds")

def test_split_function(mock_data):
    """
    Verify that `split` partitions the mock dataset into training, test, and validation subsets.

    Args:
        mock_data (tuple): Pair of (total dataframe, split dataframe) supplied by the fixture.

    Asserts:
        - Each returned subset has the expected number of rows.
        - Each returned subset holds the expected 'cmpd_id' values.
    """
    full_df, assignment_df = mock_data
    train_part, test_part, valid_part = split(full_df, assignment_df, id_col='cmpd_id')
    parts = (train_part, test_part, valid_part)

    # Row counts first, in train / test / valid order.
    for part, expected_rows in zip(parts, (2, 2, 1)):
        assert len(part) == expected_rows

    # Then the exact compound ids assigned to each subset.
    expected_ids = (
        {'cmpd1', 'cmpd3'},
        {'cmpd2', 'cmpd5'},
        {'cmpd4'},
    )
    for part, ids in zip(parts, expected_ids):
        assert set(part['cmpd_id']) == ids

def test_split_stats_initialization(mock_data):
    """
    Verify that constructing a `SplitStats` object stores its inputs and derives the subsets.

    Args:
        mock_data (tuple): Pair of (total dataframe, split dataframe) supplied by the fixture.

    Asserts:
        - Column-name attributes match the constructor arguments.
        - The stored total and split dataframes equal the inputs.
        - The train, test, and valid dataframes have the expected row counts.
    """
    full_df, assignment_df = mock_data
    stats = SplitStats(
        full_df,
        assignment_df,
        smiles_col='smiles',
        id_col='cmpd_id',
        response_cols=['response'],
    )

    # Constructor arguments are kept as attributes.
    assert stats.smiles_col == 'smiles'
    assert stats.id_col == 'cmpd_id'
    assert stats.response_cols == ['response']

    # The input dataframes are retained.
    assert stats.total_df.equals(full_df)
    assert stats.split_df.equals(assignment_df)

    # Subsets derived during initialization have the expected sizes.
    assert len(stats.train_df) == 2
    assert len(stats.test_df) == 2
    assert len(stats.valid_df) == 1

def test_print_stats(mock_data, capsys):
    """
    Verify that `SplitStats.print_stats` writes the expected statistic labels to stdout.

    Args:
        mock_data (tuple): Pair of (total dataframe, split dataframe) supplied by the fixture.
        capsys: Pytest fixture that captures stdout and stderr.

    Asserts:
        - Every expected statistic label appears in the captured stdout.
    """
    full_df, assignment_df = mock_data
    stats = SplitStats(
        full_df,
        assignment_df,
        smiles_col='smiles',
        id_col='cmpd_id',
        response_cols=['response'],
    )
    stats.print_stats()

    printed = capsys.readouterr().out
    expected_labels = (
        "dist tvt mean",
        "dist tvv mean",
        "train frac mean",
        "test frac mean",
        "valid frac mean",
    )
    for label in expected_labels:
        assert label in printed

def test_split(mock_data):
    """
    Test the `split` function to ensure it correctly splits the dataset into training, test, and validation sets.

    Args:
        mock_data (tuple): A tuple containing two DataFrames, `total_df` and `split_df`, which represent the
            complete dataset and the split dataset respectively.

    Asserts:
        - The training set contains the correct compounds.
        - The test set contains the correct compounds.
        - The validation set contains the correct compounds.
        - No compound id appears in more than one subset (mutually exclusive).
        - Every compound id in `total_df` appears in some subset (collectively exhaustive).
    """
    total_df, split_df = mock_data
    train_df, test_df, valid_df = split(total_df, split_df, id_col='cmpd_id')

    train_ids = set(train_df['cmpd_id'])
    test_ids = set(test_df['cmpd_id'])
    valid_ids = set(valid_df['cmpd_id'])

    # Check training set
    assert train_ids == {'cmpd1', 'cmpd3'}
    # Check test set
    assert test_ids == {'cmpd2', 'cmpd5'}
    # Check validation set
    assert valid_ids == {'cmpd4'}

    # Check that the splits are mutually exclusive: previously only the union
    # was compared, so overlap between subsets was never explicitly asserted.
    assert train_ids.isdisjoint(test_ids)
    assert train_ids.isdisjoint(valid_ids)
    assert test_ids.isdisjoint(valid_ids)

    # Check that the splits are collectively exhaustive
    all_ids = set(total_df['cmpd_id'])
    split_ids = train_ids | test_ids | valid_ids
    assert all_ids == split_ids

def test_parse_args(mocker):
    """
    Verify that `parse_args` maps the command-line arguments onto the expected attributes.

    Args:
        mocker: Pytest mocker fixture, used here to replace `sys.argv`.

    Asserts:
        - Each attribute of the parsed arguments equals the value supplied on the mocked command line.
    """
    fake_argv = [
        'compare_splits_plots.py',
        'dataset.csv',
        'id',
        'smiles',
        'split_a.csv',
        'split_b.csv',
        'output_dir',
    ]
    mocker.patch('sys.argv', fake_argv)
    parsed = parse_args()

    # Attribute name -> value expected from the mocked command line.
    expected = {
        'csv': 'dataset.csv',
        'id_col': 'id',
        'smiles_col': 'smiles',
        'split_a': 'split_a.csv',
        'split_b': 'split_b.csv',
        'output_dir': 'output_dir',
    }
    for attr_name, value in expected.items():
        assert getattr(parsed, attr_name) == value
6 changes: 2 additions & 4 deletions pip/cpu_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# for LC developers: comment out pypi index url and use t$
# --extra-index-url https://wci-repo.llnl.gov/repository/$

-r dev_requirements.txt

tensorflow-cpu~=2.14.0

torch==2.0.1
Expand Down Expand Up @@ -38,10 +40,6 @@ maestrowf
MolVS
mordred

pytest
matplotcheck
ipykernel

deepchem==2.7.1
rdkit==2024.3.5
pyyaml==5.4.1
6 changes: 2 additions & 4 deletions pip/cuda_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#--extra-index-url https://wci-repo.llnl.gov/repository/pypi-group/simple
-f https://data.dgl.ai/wheels/cu118/repo.html

-r dev_requirements.txt

tensorflow[and-cuda]~=2.14.0

tensorrt
Expand Down Expand Up @@ -40,10 +42,6 @@ maestrowf
MolVS
mordred

pytest
matplotcheck
ipykernel

deepchem==2.7.1
rdkit==2024.3.5

Expand Down
2 changes: 2 additions & 0 deletions pip/dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@ jupyterlab
maestrowf
matplotlib
matplotlib-venn
notebook
pytest
pytest-cov
pytest-mock
pytest-xdist
matplotcheck
ruff
Expand Down
7 changes: 2 additions & 5 deletions pip/docker_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# for LC developers: comment out pypi index url and use t$
# --extra-index-url https://wci-repo.llnl.gov/repository/$

-r dev_requirements.txt

tensorflow-cpu~=2.14.0

torch==2.0.1
Expand Down Expand Up @@ -38,11 +40,6 @@ bravado
maestrowf
MolVS
mordred
pytest
matplotcheck
ipykernel
jupyter
jupyterlab

deepchem==2.7.1
rdkit==2024.3.5
Expand Down
8 changes: 2 additions & 6 deletions pip/mchip_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# for LC developers: comment out pypi index url and use t$
# --extra-index-url https://wci-repo.llnl.gov/repository/$

-r dev_requirements.txt

# use tensorflow not tensorflow-cpu for arm. see https://www.tensorflow.org/install/pip#linux
tensorflow~=2.14.0

Expand Down Expand Up @@ -39,12 +41,6 @@ maestrowf
MolVS
mordred

pytest
matplotcheck
ipykernel
notebook
jupyterlab

deepchem==2.7.1
rdkit==2024.3.5
pyyaml==5.4.1

0 comments on commit 30c4064

Please sign in to comment.