Merge pull request #172 from st-tech/feature/mips-slope-with-true-iw
Allowing SLOPE to use the true marginal importance weight for MIPS
usaito authored Jun 15, 2022
2 parents 122743e + 0e94113 commit 9d62615
Showing 1 changed file with 73 additions and 30 deletions.
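
For context: the MIPS estimator can accept the true conditional distribution of action embeddings given actions, p_e_a, and use it to compute exact marginal importance weights. Before this commit, the SLOPE-based embedding-dimension selection (embedding_selection_method set to "exact" or "greedy") ignored a supplied p_e_a and always fell back to weights estimated from the logged data; the diff below forwards p_e_a into both pruning routines. A rough usage sketch follows. The class name MarginalizedInverseProbabilityWeighting, its constructor arguments, and the random placeholder data are assumptions for illustration and are not shown in this diff:

    import numpy as np
    from obp.ope import MarginalizedInverseProbabilityWeighting as MIPS

    rng = np.random.default_rng(12345)
    n_rounds, n_actions, n_cat_dim, n_cat_per_dim = 1000, 10, 3, 5

    # Random placeholder logged data: only the call pattern matters here.
    context = rng.normal(size=(n_rounds, 4))
    action = rng.integers(n_actions, size=n_rounds)
    reward = rng.binomial(1, 0.5, size=n_rounds)
    action_embed = rng.integers(n_cat_per_dim, size=(n_rounds, n_cat_dim))
    # True p(e|a), assumed shape (n_actions, n_cat_per_dim, n_cat_dim).
    p_e_a = rng.dirichlet(
        np.ones(n_cat_per_dim), size=(n_actions, n_cat_dim)
    ).transpose(0, 2, 1)
    # Uniform placeholder policies, shape (n_rounds, n_actions, len_list).
    pi_b = np.full((n_rounds, n_actions, 1), 1.0 / n_actions)
    action_dist = np.full((n_rounds, n_actions, 1), 1.0 / n_actions)

    mips = MIPS(n_actions=n_actions, embedding_selection_method="exact")
    value = mips.estimate_policy_value(
        context=context,
        reward=reward,
        action=action,
        action_embed=action_embed,
        pi_b=pi_b,
        action_dist=action_dist,
        p_e_a=p_e_a,  # with this commit, also used inside SLOPE pruning
    )

The returned value is meaningless for random data; the point is that p_e_a now reaches the embedding selection step rather than only the plain estimation path.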
obp/ope/estimators_embed.py (73 additions, 30 deletions)
@@ -303,6 +303,7 @@ def estimate_policy_value(
                     position=position,
                     pi_b=pi_b,
                     action_dist=action_dist,
+                    p_e_a=p_e_a,
                 )
             elif self.embedding_selection_method == "greedy":
                 return self._estimate_with_greedy_pruning(
@@ -313,6 +314,7 @@ def estimate_policy_value(
                     position=position,
                     pi_b=pi_b,
                     action_dist=action_dist,
+                    p_e_a=p_e_a,
                 )
         else:
             return self._estimate_round_rewards(
@@ -335,6 +337,7 @@ def _estimate_with_exact_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply an exact version of data-drive action embedding selection."""
         n_emb_dim = action_embed.shape[1]
@@ -344,16 +347,29 @@ def _estimate_with_exact_pruning(
             comb_list = list(itertools.combinations(feat_list, i))
             theta_list_, cnf_list_ = [], []
             for comb in comb_list:
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, comb],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, comb],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, comb],
+                        with_dev=True,
+                    )
                 if len(theta_list) > 0:
                     theta_list_.append(theta), cnf_list_.append(cnf)
                 else:
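
The exact variant above enumerates every combination of embedding dimensions, from all n_emb_dim of them down to a single dimension, and now subsets the true distribution consistently with the observed embeddings via p_e_a[:, :, comb]. A standalone sketch of that enumeration-and-slicing pattern, again assuming p_e_a has shape (n_actions, n_cat_per_dim, n_cat_dim):

    import itertools
    import numpy as np

    n_actions, n_cat_per_dim, n_cat_dim = 10, 5, 4
    rng = np.random.default_rng(0)
    p_e_a = rng.dirichlet(
        np.ones(n_cat_per_dim), size=(n_actions, n_cat_dim)
    ).transpose(0, 2, 1)

    feat_list = np.arange(n_cat_dim)
    for i in np.arange(n_cat_dim, 0, -1):        # subset sizes, largest first
        for comb in itertools.combinations(feat_list, i):
            p_e_a_sub = p_e_a[:, :, comb]        # keep candidate dimensions only
            # `_estimate_round_rewards` would receive `p_e_a_sub` here.
            assert p_e_a_sub.shape == (n_actions, n_cat_per_dim, i)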
@@ -380,23 +396,37 @@ def _estimate_with_greedy_pruning(
         pi_b: np.ndarray,
         action_dist: np.ndarray,
         position: np.ndarray,
+        p_e_a: Optional[np.ndarray] = None,
     ) -> float:
         """Apply a greedy version of data-drive action embedding selection."""
         n_emb_dim = action_embed.shape[1]
         theta_list, cnf_list = [], []
         current_feat, C = np.arange(n_emb_dim), np.sqrt(6) - 1

         # init
-        theta, cnf = self._estimate_round_rewards(
-            context=context,
-            reward=reward,
-            action=action,
-            action_embed=action_embed[:, current_feat],
-            pi_b=pi_b,
-            action_dist=action_dist,
-            position=position,
-            with_dev=True,
-        )
+        if p_e_a is None:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                with_dev=True,
+            )
+        else:
+            theta, cnf = self._estimate_round_rewards(
+                context=context,
+                reward=reward,
+                action=action,
+                action_embed=action_embed[:, current_feat],
+                pi_b=pi_b,
+                action_dist=action_dist,
+                position=position,
+                p_e_a=p_e_a[:, :, current_feat],
+                with_dev=True,
+            )
         theta_list.append(theta), cnf_list.append(cnf)

         # iterate
@@ -405,16 +435,29 @@ def _estimate_with_greedy_pruning(
             for d in current_feat:
                 idx_without_d = np.where(current_feat != d, True, False)
                 candidate_feat = current_feat[idx_without_d]
-                theta, cnf = self._estimate_round_rewards(
-                    context=context,
-                    reward=reward,
-                    action=action,
-                    action_embed=action_embed[:, candidate_feat],
-                    pi_b=pi_b,
-                    action_dist=action_dist,
-                    position=position,
-                    with_dev=True,
-                )
+                if p_e_a is None:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        with_dev=True,
+                    )
+                else:
+                    theta, cnf = self._estimate_round_rewards(
+                        context=context,
+                        reward=reward,
+                        action=action,
+                        action_embed=action_embed[:, candidate_feat],
+                        pi_b=pi_b,
+                        action_dist=action_dist,
+                        position=position,
+                        p_e_a=p_e_a[:, :, candidate_feat],
+                        with_dev=True,
+                    )
                 d_list_.append(d)
                 theta_list_.append(theta), cnf_list_.append(cnf)

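What the forwarded argument enables inside _estimate_round_rewards (elided from this diff): in MIPS (Saito and Joachims, 2022), the marginal importance weight of an observed embedding e in context x is w(x, e) = p(e | x, pi_e) / p(e | x, pi_b), where p(e | x, pi) = sum_a pi(a | x) p(e | a). When the true p_e_a is known, this weight can be computed exactly instead of being approximated from the logged data. A rough sketch of that computation, assuming independent embedding dimensions and simplified 2-D policy arrays; this helper is illustrative and is not the function the commit modifies:

    import numpy as np

    def true_marginal_iw(p_e_a, action_embed, pi_b, pi_e):
        """Exact marginal importance weights w(x, e) for MIPS (sketch).

        p_e_a:        (n_actions, n_cat_per_dim, n_cat_dim), true p(e|a)
        action_embed: (n_rounds, n_cat_dim), observed embedding categories
        pi_b, pi_e:   (n_rounds, n_actions), action choice probabilities
        """
        n_rounds, n_cat_dim = action_embed.shape
        # p(e_i | a) for every round i and action a, multiplying the
        # per-dimension probabilities under the independence assumption.
        p_e_given_a = np.ones((n_rounds, p_e_a.shape[0]))
        for d in range(n_cat_dim):
            p_e_given_a *= p_e_a[:, action_embed[:, d], d].T
        # Marginalize over actions under each policy, then take the ratio.
        return (pi_e * p_e_given_a).sum(1) / (pi_b * p_e_given_a).sum(1)

Restricting p_e_a to a candidate subset of dimensions, as the pruning routines above do, simply shortens the product over d before marginalizing.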