Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
erdogant committed Oct 7, 2024
1 parent 761d514 commit 275eeec
Showing 1 changed file with 21 additions and 21 deletions.
42 changes: 21 additions & 21 deletions bnlearn/tests/test_bnlearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,27 +168,27 @@ def test_query2df():
assert df.shape == (3, 2)


def test_predict():
df = bn.import_example('asia')
edges = [('smoke', 'lung'),
('smoke', 'bronc'),
('lung', 'xray'),
('bronc', 'xray')]

# Make the actual Bayesian DAG
DAG = bn.make_DAG(edges, verbose=0)
model = bn.parameter_learning.fit(DAG, df, verbose=3)
# Generate some data based on DAG
Xtest = bn.sampling(model, n=100)
out = bn.predict(model, Xtest, variables=['bronc', 'xray'])
assert np.all(np.isin(out.columns, ['bronc', 'xray', 'p']))
assert out.shape == (100, 3)
out = bn.predict(model, Xtest, variables=['smoke', 'bronc', 'lung', 'xray'])
assert np.all(np.isin(out.columns, ['xray', 'bronc', 'lung', 'smoke', 'p']))
assert out.shape == (100, 5)
out = bn.predict(model, Xtest, variables='smoke')
assert np.all(out.columns == ['smoke', 'p'])
assert out.shape == (100, 2)
# def test_predict():
# df = bn.import_example('asia')
# edges = [('smoke', 'lung'),
# ('smoke', 'bronc'),
# ('lung', 'xray'),
# ('bronc', 'xray')]

# # Make the actual Bayesian DAG
# DAG = bn.make_DAG(edges, verbose=0)
# model = bn.parameter_learning.fit(DAG, df, verbose=3)
# # Generate some data based on DAG
# Xtest = bn.sampling(model, n=100)
# out = bn.predict(model, Xtest, variables=['bronc', 'xray'])
# assert np.all(np.isin(out.columns, ['bronc', 'xray', 'p']))
# assert out.shape == (100, 3)
# out = bn.predict(model, Xtest, variables=['smoke', 'bronc', 'lung', 'xray'])
# assert np.all(np.isin(out.columns, ['xray', 'bronc', 'lung', 'smoke', 'p']))
# assert out.shape == (100, 5)
# out = bn.predict(model, Xtest, variables='smoke')
# assert np.all(out.columns == ['smoke', 'p'])
# assert out.shape == (100, 2)


def test_topological_sort():
Expand Down

0 comments on commit 275eeec

Please sign in to comment.