Skip to content

Commit

Permalink
Add smoking test
Browse files Browse the repository at this point in the history
  • Loading branch information
lixfz committed Jan 4, 2023
1 parent d6fefbc commit 2b69ec8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/why_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ def _validate_it(why, test_data, check_score=True):
print("rloss:", score)


def test_smoke():
from ylearn.api.smoke import smoke
smoke()


def test_basis():
data, test_data, outcome, treatment, adjustment, covariate = _dgp.generate_data_x1b_y1()
why = Why()
Expand Down
35 changes: 35 additions & 0 deletions ylearn/api/smoke.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from ylearn.api import Why
from ylearn.exp_dataset.exp_data import single_binary_treatment


def smoke(estimator='auto'):
print('-' * 20, 'smoke with estimator', estimator, '-' * 20)

train, val, _ = single_binary_treatment()
te = train.pop('TE')
te = val.pop('TE')
adjustment = [c for c in train.columns.tolist() if c.startswith('w')]
covariate = [c for c in train.columns.tolist() if c.startswith('c')]

if estimator == 'grf':
covariate.extend(adjustment)
adjustment = None

why = Why(estimator=estimator)
why.fit(train, outcome='outcome', treatment='treatment', adjustment=adjustment, covariate=covariate)

cate = why.causal_effect(val)
print('CATE:\n', cate)

auuc = why.score(val, scorer='auuc')
print('AUUC', auuc)


if __name__ == '__main__':
from ylearn.utils import logging

logging.set_level('info')
for est in ['slearner', 'tlearner', 'xlearner', 'dr', 'dml', 'tree', 'grf']:
smoke(est)

print('\n<done>')

0 comments on commit 2b69ec8

Please sign in to comment.