Commit 480b3b7

Merge branch 'untested_fixes' into pre-commit-ci-update-config

segsell authored Oct 14, 2024
2 parents ea9cf13 + 3228076

Showing 12 changed files with 218 additions and 111 deletions.
16 changes: 8 additions & 8 deletions src/dcegm/interface.py
@@ -11,7 +11,7 @@ def policy_and_value_for_state_choice_vec(
     state_choice_vec,
     wealth,
     map_state_choice_to_index,
-    state_space_names,
+    discrete_states_names,
     endog_grid_solved,
     policy_solved,
     value_solved,
@@ -30,7 +30,7 @@ def policy_and_value_for_state_choice_vec(
     """
     state_choice_tuple = tuple(
-        state_choice_vec[st] for st in state_space_names + ["choice"]
+        state_choice_vec[st] for st in discrete_states_names + ["choice"]
     )

     state_choice_index = map_state_choice_to_index[state_choice_tuple]
@@ -50,7 +50,7 @@ def value_for_state_choice_vec(
     state_choice_vec,
     wealth,
     map_state_choice_to_index,
-    state_space_names,
+    discrete_states_names,
     endog_grid_solved,
     value_solved,
     compute_utility,
@@ -68,7 +68,7 @@ def value_for_state_choice_vec(
     """
     state_choice_tuple = tuple(
-        state_choice_vec[st] for st in state_space_names + ["choice"]
+        state_choice_vec[st] for st in discrete_states_names + ["choice"]
     )

     state_choice_index = map_state_choice_to_index[state_choice_tuple]
@@ -88,7 +88,7 @@ def policy_for_state_choice_vec(
     state_choice_vec,
     wealth,
     map_state_choice_to_index,
-    state_space_names,
+    discrete_states_names,
     endog_grid_solved,
     policy_solved,
 ):
@@ -104,7 +104,7 @@ def policy_for_state_choice_vec(
     """
     state_choice_tuple = tuple(
-        state_choice_vec[st] for st in state_space_names + ["choice"]
+        state_choice_vec[st] for st in discrete_states_names + ["choice"]
     )

     state_choice_index = map_state_choice_to_index[state_choice_tuple]
@@ -119,10 +119,10 @@ def policy_for_state_choice_vec(


 def get_state_choice_index_per_discrete_state(
-    map_state_choice_to_index, states, state_space_names
+    map_state_choice_to_index, states, discrete_states_names
 ):
     indexes = map_state_choice_to_index[
-        tuple((states[key],) for key in state_space_names)
+        tuple((states[key],) for key in discrete_states_names)
     ]
     # As the code above generates a dummy dimension in the first we eliminate that
     return indexes[0]
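
Downstream callers only need to update the keyword: every call site that passed `state_space_names` now passes `discrete_states_names`. A minimal usage sketch of the last helper above — the state names ("period", "lagged_choice") and the two-agent arrays are hypothetical, and `model` is assumed to be the usual dcegm model dict:

    import numpy as np

    # Hypothetical discrete states for two agents.
    states = {"period": np.array([0, 1]), "lagged_choice": np.array([0, 1])}

    state_choice_idxs = get_state_choice_index_per_discrete_state(
        map_state_choice_to_index=model["model_structure"]["map_state_choice_to_index"],
        states=states,
        # Renamed keyword; was `state_space_names` before this commit.
        discrete_states_names=model["model_structure"]["discrete_states_names"],
    )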
32 changes: 32 additions & 0 deletions src/dcegm/interpolation/interp2d.py
@@ -95,6 +95,38 @@ def interp2d_policy_and_value_on_wealth_and_regular_grid(
     return policy_interp, value_interp


+def interp2d_value_on_wealth_and_regular_grid(
+    regular_grid: jnp.ndarray,
+    wealth_grid: jnp.ndarray,
+    value_grid: jnp.ndarray,
+    regular_point_to_interp: jnp.ndarray | float,
+    wealth_point_to_interp: jnp.ndarray | float,
+    compute_utility: Callable,
+    state_choice_vec: Dict[str, int],
+    params: dict,
+):
+    regular_points, wealth_points, coords_idxs = find_grid_coords_for_interp(
+        regular_grid=regular_grid,
+        wealth_grid=wealth_grid,
+        regular_point_to_interp=regular_point_to_interp,
+        wealth_point_to_interp=wealth_point_to_interp,
+    )
+    value_interp = interp2d_value_and_check_creditconstraint(
+        regular_points=regular_points,
+        wealth_points=wealth_points,
+        value_grid=value_grid,
+        coords_idxs=coords_idxs,
+        regular_point_to_interp=regular_point_to_interp,
+        wealth_point_to_interp=wealth_point_to_interp,
+        compute_utility=compute_utility,
+        wealth_min_unconstrained=wealth_grid[:, 1],
+        value_at_zero_wealth=value_grid[:, 0],
+        state_choice_vec=state_choice_vec,
+        params=params,
+    )
+    return value_interp
+
+
 def interp2d_policy(
     regular_points,
     wealth_points,
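
The new helper mirrors `interp2d_policy_and_value_on_wealth_and_regular_grid` but interpolates the value function only. Both `wealth_grid` and `value_grid` must be 2d, indexed by (regular grid point, wealth grid point): the helper passes `wealth_grid[:, 1]` as the lowest unconstrained wealth point and `value_grid[:, 0]` as the value at zero wealth into the credit-constraint check. A call sketch with synthetic grids — `log_utility` is a placeholder whose signature is assumed, not taken from this diff:

    import jax.numpy as jnp

    def log_utility(consumption=None, params=None, **state_choice):
        # Placeholder; real models supply their own utility function.
        return jnp.log(consumption)

    value = interp2d_value_on_wealth_and_regular_grid(
        regular_grid=jnp.linspace(0.0, 1.0, 5),  # grid of the second continuous state
        wealth_grid=jnp.tile(jnp.linspace(0.0, 10.0, 20), (5, 1)),  # shape (5, 20)
        value_grid=jnp.zeros((5, 20)),  # solved value function, shape (5, 20)
        regular_point_to_interp=0.3,
        wealth_point_to_interp=4.2,
        compute_utility=log_utility,
        state_choice_vec={"choice": 0},
        params={"beta": 0.95},
    )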
1 change: 1 addition & 0 deletions src/dcegm/law_of_motion.py
@@ -105,6 +105,7 @@ def calc_wealth_for_each_continuous_state_and_savings_grid_point(
     params,
     compute_beginning_of_period_wealth,
 ):
+
     out = compute_beginning_of_period_wealth(
         **state_vec,
         continuous_state=continuous_state_beginning_of_period,
142 changes: 94 additions & 48 deletions src/dcegm/likelihood.py
@@ -16,13 +16,13 @@
 )
 from dcegm.interface import get_state_choice_index_per_discrete_state
 from dcegm.interpolation.interp1d import interp_value_on_wealth
+from dcegm.interpolation.interp2d import interp2d_value_on_wealth_and_regular_grid
 from dcegm.solve import get_solve_func_for_model


 def create_individual_likelihood_function_for_model(
     model: Dict[str, Any],
     observed_states: Dict[str, int],
-    observed_wealth: np.array,
     observed_choices: np.array,
     params_all,
     unobserved_state_specs=None,
@@ -35,15 +35,13 @@ def create_individual_likelihood_function_for_model(
         choice_prob_func = create_partial_choice_prob_calculation(
             observed_states=observed_states,
             observed_choices=observed_choices,
-            observed_wealth=observed_wealth,
             model=model,
         )
     else:

         choice_prob_func = create_choice_prob_func_unobserved_states(
             model=model,
             observed_states=observed_states,
-            observed_wealth=observed_wealth,
             observed_choices=observed_choices,
             unobserved_state_specs=unobserved_state_specs,
             weight_full_states=True,
@@ -74,19 +72,27 @@ def individual_likelihood(params):
 def create_choice_prob_func_unobserved_states(
     model: Dict[str, Any],
     observed_states: Dict[str, int],
-    observed_wealth: np.array,
     observed_choices: np.array,
     unobserved_state_specs,
     weight_full_states=True,
 ):
     # First prepare full observed states, choices and pre period states for weighting
     full_mask = unobserved_state_specs["observed_bool"]
+    if len(model["options"]["exog_grids"]) == 2:
+        second_cont_state_name = model["options"]["second_continuous_state_name"]
+        state_space_names = model["model_structure"]["discrete_states_names"] + [
+            "wealth",
+            second_cont_state_name,
+        ]
+    else:
+        state_space_names = model["model_structure"]["discrete_states_names"] + [
+            "wealth"
+        ]
+
     full_observed_states = {
-        name: observed_states[name][full_mask]
-        for name in model["model_structure"]["state_space_names"]
+        name: observed_states[name][full_mask] for name in state_space_names
     }
     full_observed_choices = observed_choices[full_mask]
-    full_observed_wealth = observed_wealth[full_mask]
     # Now the states of last period for weighting and also the unobserved states
     # for this period
     pre_period_full_observed_states = {
@@ -102,7 +108,6 @@ def create_choice_prob_func_unobserved_states(
     partial_choice_probs_full_observed_states = create_partial_choice_prob_calculation(
         observed_states=full_observed_states,
         observed_choices=full_observed_choices,
-        observed_wealth=full_observed_wealth,
         model=model,
     )

@@ -121,8 +126,7 @@

     # Read out the observed states of the unobserved states
     unobserved_states = {
-        name: observed_states[name][~full_mask]
-        for name in model["model_structure"]["state_space_names"]
+        name: observed_states[name][~full_mask] for name in state_space_names
     }
     # Also pre period states
     pre_period_unobserved_states = {
@@ -167,7 +171,6 @@
             create_partial_choice_prob_calculation(
                 observed_states=unobserved_state,
                 observed_choices=observed_choices[~full_mask],
-                observed_wealth=observed_wealth[~full_mask],
                 model=model,
             )
         )
@@ -255,28 +258,23 @@ def choice_prob_func(value_in, endog_grid_in, params_in):
 def create_partial_choice_prob_calculation(
     observed_states,
     observed_choices,
-    observed_wealth,
     model,
 ):
-    observed_state_choice_indexes = get_state_choice_index_per_discrete_state(
+    discrete_observed_state_choice_indexes = get_state_choice_index_per_discrete_state(
         states=observed_states,
         map_state_choice_to_index=model["model_structure"]["map_state_choice_to_index"],
-        state_space_names=model["model_structure"]["state_space_names"],
+        discrete_states_names=model["model_structure"]["discrete_states_names"],
     )

-    options = model["options"]
-
     def partial_choice_prob_func(value_in, endog_grid_in, params_in):
         return calc_choice_prob_for_state_choices(
             value_solved=value_in,
             endog_grid_solved=endog_grid_in,
             params=params_in,
             states=observed_states,
             choices=observed_choices,
-            state_choice_indexes=observed_state_choice_indexes,
-            oberseved_wealth=observed_wealth,
-            choice_range=np.arange(options["model_params"]["n_choices"], dtype=int),
-            compute_utility=model["model_funcs"]["compute_utility"],
+            state_choice_indexes=discrete_observed_state_choice_indexes,
+            model=model,
         )

     return partial_choice_prob_func
@@ -289,9 +287,7 @@ def calc_choice_prob_for_state_choices(
     states,
     choices,
     state_choice_indexes,
-    oberseved_wealth,
-    choice_range,
-    compute_utility,
+    model,
 ):
     """This function interpolates the policy and value function for all agents.
@@ -305,9 +301,7 @@ def calc_choice_prob_for_state_choices(
         params=params,
         observed_states=states,
         state_choice_indexes=state_choice_indexes,
-        oberseved_wealth=oberseved_wealth,
-        choice_range=choice_range,
-        compute_utility=compute_utility,
+        model=model,
     )
     choice_probs = jnp.take_along_axis(
         choice_prob_across_choices, choices[:, None], axis=1
@@ -321,41 +315,93 @@ def calc_choice_probs_for_states(
     params,
     observed_states,
     state_choice_indexes,
-    oberseved_wealth,
-    choice_range,
-    compute_utility,
+    model,
 ):
     value_grid_agent = jnp.take(
         value_solved, state_choice_indexes, axis=0, mode="fill", fill_value=jnp.nan
     )
     endog_grid_agent = jnp.take(endog_grid_solved, state_choice_indexes, axis=0)
-    vectorized_interp = jax.vmap(
-        jax.vmap(
-            interpolate_value_for_state_in_each_choice,
-            in_axes=(None, None, 0, 0, 0, None, None),
-        ),
-        in_axes=(0, 0, 0, 0, None, None, None),
-    )
-
-    value_per_agent_interp = vectorized_interp(
-        observed_states,
-        oberseved_wealth,
-        endog_grid_agent,
-        value_grid_agent,
-        choice_range,
-        params,
-        compute_utility,
-    )
+    # Read out relevant model objects
+    options = model["options"]
+    choice_range = options["state_space"]["choices"]
+    compute_utility = model["model_funcs"]["compute_utility"]
+
+    if len(options["exog_grids"]) == 2:
+        vectorized_interp2d = jax.vmap(
+            jax.vmap(
+                interp2d_value_for_state_in_each_choice,
+                in_axes=(None, None, 0, 0, 0, None, None, None),
+            ),
+            in_axes=(0, 0, 0, 0, None, None, None, None),
+        )
+        # Extract second cont state name
+        second_continuous_state_name = options["second_continuous_state_name"]
+        second_cont_value = observed_states[second_continuous_state_name]

+        value_per_agent_interp = vectorized_interp2d(
+            observed_states,
+            second_cont_value,
+            endog_grid_agent,
+            value_grid_agent,
+            choice_range,
+            params,
+            options["exog_grids"]["second_continuous"],
+            compute_utility,
+        )
+
+    else:
+        vectorized_interp1d = jax.vmap(
+            jax.vmap(
+                interp1d_value_for_state_in_each_choice,
+                in_axes=(None, 0, 0, 0, None, None),
+            ),
+            in_axes=(0, 0, 0, None, None, None),
+        )
+
+        value_per_agent_interp = vectorized_interp1d(
+            observed_states,
+            endog_grid_agent,
+            value_grid_agent,
+            choice_range,
+            params,
+            compute_utility,
+        )
     choice_prob_across_choices, _, _ = calculate_choice_probs_and_unsqueezed_logsum(
         choice_values_per_state=value_per_agent_interp,
         taste_shock_scale=params["lambda"],
     )
     return choice_prob_across_choices


-def interpolate_value_for_state_in_each_choice(
+def interp2d_value_for_state_in_each_choice(
+    state,
+    second_cont_state,
+    endog_grid_agent,
+    value_agent,
+    choice,
+    params,
+    regular_grid,
+    compute_utility,
+):
+    state_choice_vec = {**state, "choice": choice}
+
+    value_interp = interp2d_value_on_wealth_and_regular_grid(
+        regular_grid=regular_grid,
+        wealth_grid=endog_grid_agent,
+        value_grid=value_agent,
+        regular_point_to_interp=second_cont_state,
+        wealth_point_to_interp=state["wealth"],
+        compute_utility=compute_utility,
+        state_choice_vec=state_choice_vec,
+        params=params,
+    )
+
+    return value_interp
+
+
+def interp1d_value_for_state_in_each_choice(
     state,
-    resource_at_beginning_of_period,
     endog_grid_agent,
     value_agent,
     choice,
@@ -365,7 +411,7 @@ def interpolate_value_for_state_in_each_choice(
     state_choice_vec = {**state, "choice": choice}

     value_interp = interp_value_on_wealth(
-        wealth=resource_at_beginning_of_period,
+        wealth=state["wealth"],
         endog_grid=endog_grid_agent,
         value=value_agent,
         compute_utility=compute_utility,
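
The net effect on the likelihood interface: `observed_wealth` is gone, and wealth (plus the second continuous state, if the model has one) is expected inside `observed_states`; the choice range and utility function are now read from `model` internally. A sketch of the updated call — the state names and values are hypothetical, and `model` and `params_all` are assumed to come from the usual dcegm setup:

    import numpy as np

    observed_states = {
        "period": np.array([0, 0, 1]),
        "lagged_choice": np.array([0, 1, 1]),
        # Wealth now lives inside the states dict instead of being passed
        # as a separate `observed_wealth` argument.
        "wealth": np.array([10.5, 3.2, 7.8]),
    }

    individual_likelihood = create_individual_likelihood_function_for_model(
        model=model,
        observed_states=observed_states,
        observed_choices=np.array([0, 1, 1]),
        params_all=params_all,
    )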