Skip to content

Commit

Permalink
Change default parameters for Shapley estimator
Browse files Browse the repository at this point in the history
This should significantly speed up some calculations for larger graphs. To compensate for the reduced number of Shapley runs, the default number of samples for the anomaly attribution has also been increased.

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Jul 29, 2024
1 parent 9b9d609 commit 6a9a229
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
2 changes: 1 addition & 1 deletion dowhy/gcm/anomaly.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def attribute_anomalies(
anomaly_samples: pd.DataFrame,
anomaly_scorer: Optional[AnomalyScorer] = None,
attribute_mean_deviation: bool = False,
num_distribution_samples: int = 1500,
num_distribution_samples: int = 3000,
shapley_config: Optional[ShapleyConfig] = None,
) -> Dict[Any, np.ndarray]:
"""Estimates the contributions of upstream nodes to the anomaly score of the target_node for each sample in
Expand Down
17 changes: 10 additions & 7 deletions dowhy/gcm/shapley.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class ShapleyConfig:
def __init__(
self,
approximation_method: ShapleyApproximationMethods = ShapleyApproximationMethods.AUTO,
num_permutations: int = 2000,
num_permutations: int = 25,
num_subset_samples: int = 5000,
min_percentage_change_threshold: float = 0.05,
n_jobs: Optional[int] = None,
Expand Down Expand Up @@ -110,7 +110,7 @@ def estimate_shapley_values(
if num_players <= 5:
approximation_method = ShapleyApproximationMethods.EXACT
else:
approximation_method = ShapleyApproximationMethods.EARLY_STOPPING
approximation_method = ShapleyApproximationMethods.PERMUTATION

if approximation_method == ShapleyApproximationMethods.EXACT:
return _estimate_shapley_values_exact(set_func=set_func, num_players=num_players, n_jobs=shapley_config.n_jobs)
Expand Down Expand Up @@ -369,7 +369,7 @@ def _approximate_shapley_values_via_early_stopping(
permutation, evaluated_subsets, full_subset_result, empty_subset_result
)

if run_counter > max_num_permutations:
if run_counter * num_permutations_per_run > max_num_permutations:
break

new_shap_proxy = np.array(shapley_values)
Expand Down Expand Up @@ -579,12 +579,15 @@ def parallel_job(input_subset: Tuple[int], parallel_random_seed: int) -> Union[f

return set_func(np.array(input_subset))

if isinstance(evaluation_subsets, set):
evaluation_subsets = list(evaluation_subsets)

random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(evaluation_subsets))
subset_results = parallel_context(
delayed(parallel_job)(subset_to_evaluate, int(random_seed))
for subset_to_evaluate, random_seed in tqdm(
zip(evaluation_subsets, random_seeds),
desc="Evaluate set function",
delayed(parallel_job)(evaluation_subsets[i], int(random_seeds[i]))
for i in tqdm(
range(len(evaluation_subsets)),
desc="Evaluating set functions...",
position=0,
leave=True,
disable=not config.show_progress_bars or not show_progressbar,
Expand Down

0 comments on commit 6a9a229

Please sign in to comment.