Skip to content

Commit

Permalink
test option dic checks
Browse files Browse the repository at this point in the history
  • Loading branch information
gregor-schueler committed Jan 10, 2025
1 parent 9c471e9 commit 46717ff
Showing 1 changed file with 129 additions and 7 deletions.
136 changes: 129 additions & 7 deletions tests/test_options_tests.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,133 @@
from dcegm.pre_processing.setup_model import setup_model
import numpy as np
import pytest

from dcegm.pre_processing.check_options import check_options_and_set_defaults

def test_n_periods():

test_model = setup_model(
options={"state_space": {"n_periods": 1}},
utility_functions={},
utility_functions_final_period={},
budget_constraint=lambda x: x,
@pytest.fixture
def valid_options():
"""Fixture providing a valid options dictionary."""
return {
"state_space": {
"n_periods": 5,
"choices": [1, 2, 3],
"endogenous_states": {
"education": np.arange(2, dtype=int),
},
"continuous_states": {
"wealth": np.linspace(0, 10, 11),
"experience": np.linspace(0, 5, 6),
},
"exogenous_processes": {
"health": {
"transition": lambda x: x,
"states": np.arange(3, dtype=int),
},
},
},
"model_params": {},
"tuning_params": {},
}


def test_invalid_options_type():
with pytest.raises(ValueError, match="Options must be a dictionary."):
check_options_and_set_defaults([])


def test_missing_state_space(valid_options):
del valid_options["state_space"]
with pytest.raises(
ValueError, match="Options must contain a state space dictionary."
):
check_options_and_set_defaults(valid_options)


def test_invalid_state_space_type():
with pytest.raises(ValueError, match="State space must be a dictionary."):
check_options_and_set_defaults({"state_space": "not_a_dict"})


def test_missing_n_periods(valid_options):
del valid_options["state_space"]["n_periods"]
with pytest.raises(
ValueError, match="State space must contain the number of periods."
):
check_options_and_set_defaults(valid_options)


def test_invalid_n_periods_type(valid_options):
valid_options["state_space"]["n_periods"] = "not_an_int"
with pytest.raises(ValueError, match="Number of periods must be an integer."):
check_options_and_set_defaults(valid_options)


def test_invalid_n_periods_value(valid_options):
valid_options["state_space"]["n_periods"] = 1
with pytest.raises(ValueError, match="Number of periods must be greater than 1."):
check_options_and_set_defaults(valid_options)


@pytest.mark.parametrize(
"choices, expected_array",
[
([1, 2, 3], np.array([1, 2, 3], dtype=np.uint8)),
(5, np.array([5], dtype=np.uint8)),
(np.array([1, 2, 3]), np.array([1, 2, 3], dtype=np.uint8)),
],
)
def test_valid_choices_conversion(valid_options, choices, expected_array):
valid_options["state_space"]["choices"] = choices
options = check_options_and_set_defaults(valid_options)
np.testing.assert_array_equal(options["state_space"]["choices"], expected_array)


def test_invalid_choices_type(valid_options):
valid_options["state_space"]["choices"] = "not_a_valid_type"
with pytest.raises(ValueError, match="Choices must be a list or an integer."):
check_options_and_set_defaults(valid_options)


def test_missing_model_params(valid_options):
del valid_options["model_params"]
with pytest.raises(
ValueError, match="Options must contain a model parameters dictionary."
):
check_options_and_set_defaults(valid_options)


def test_invalid_model_params_type(valid_options):
valid_options["model_params"] = "not_a_dict"
with pytest.raises(ValueError, match="Model parameters must be a dictionary."):
check_options_and_set_defaults(valid_options)


# maybe also check this in the check_options_and_set_defaults function
def test_missing_continuous_states(valid_options):
del valid_options["state_space"]["continuous_states"]
with pytest.raises(KeyError):
check_options_and_set_defaults(valid_options)


def test_tuning_params_defaults(valid_options):
del valid_options["tuning_params"]
options = check_options_and_set_defaults(valid_options)
assert options["tuning_params"]["extra_wealth_grid_factor"] == 0.2
assert options["tuning_params"]["n_constrained_points_to_add"] == 1


def test_tuning_params_invalid_grid_factors(valid_options):
valid_options["tuning_params"]["extra_wealth_grid_factor"] = 0.01
valid_options["tuning_params"]["n_constrained_points_to_add"] = 100
with pytest.raises(
ValueError, match="The extra wealth grid factor .* is too small"
):
check_options_and_set_defaults(valid_options)


def test_second_continuous_state_handling(valid_options):
options = check_options_and_set_defaults(valid_options)
assert options["second_continuous_state_name"] == "experience"
assert options["tuning_params"]["n_second_continuous_grid"] == len(
valid_options["state_space"]["continuous_states"]["experience"]
)

0 comments on commit 46717ff

Please sign in to comment.