diff --git a/dowhy/gcm/anomaly.py b/dowhy/gcm/anomaly.py index 46f67e6de..2026787eb 100644 --- a/dowhy/gcm/anomaly.py +++ b/dowhy/gcm/anomaly.py @@ -94,7 +94,7 @@ def attribute_anomalies( anomaly_samples: pd.DataFrame, anomaly_scorer: Optional[AnomalyScorer] = None, attribute_mean_deviation: bool = False, - num_distribution_samples: int = 1500, + num_distribution_samples: int = 3000, shapley_config: Optional[ShapleyConfig] = None, ) -> Dict[Any, np.ndarray]: """Estimates the contributions of upstream nodes to the anomaly score of the target_node for each sample in diff --git a/dowhy/gcm/shapley.py b/dowhy/gcm/shapley.py index 29c660245..0670a1e98 100644 --- a/dowhy/gcm/shapley.py +++ b/dowhy/gcm/shapley.py @@ -44,7 +44,7 @@ class ShapleyConfig: def __init__( self, approximation_method: ShapleyApproximationMethods = ShapleyApproximationMethods.AUTO, - num_permutations: int = 2000, + num_permutations: int = 25, num_subset_samples: int = 5000, min_percentage_change_threshold: float = 0.05, n_jobs: Optional[int] = None, @@ -110,7 +110,7 @@ def estimate_shapley_values( if num_players <= 5: approximation_method = ShapleyApproximationMethods.EXACT else: - approximation_method = ShapleyApproximationMethods.EARLY_STOPPING + approximation_method = ShapleyApproximationMethods.PERMUTATION if approximation_method == ShapleyApproximationMethods.EXACT: return _estimate_shapley_values_exact(set_func=set_func, num_players=num_players, n_jobs=shapley_config.n_jobs) @@ -369,7 +369,7 @@ def _approximate_shapley_values_via_early_stopping( permutation, evaluated_subsets, full_subset_result, empty_subset_result ) - if run_counter > max_num_permutations: + if run_counter * num_permutations_per_run > max_num_permutations: break new_shap_proxy = np.array(shapley_values) @@ -579,12 +579,15 @@ def parallel_job(input_subset: Tuple[int], parallel_random_seed: int) -> Union[f return set_func(np.array(input_subset)) + if isinstance(evaluation_subsets, set): + evaluation_subsets = list(evaluation_subsets) + random_seeds = np.random.randint(np.iinfo(np.int32).max, size=len(evaluation_subsets)) subset_results = parallel_context( - delayed(parallel_job)(subset_to_evaluate, int(random_seed)) - for subset_to_evaluate, random_seed in tqdm( - zip(evaluation_subsets, random_seeds), - desc="Evaluate set function", + delayed(parallel_job)(evaluation_subsets[i], int(random_seeds[i])) + for i in tqdm( + range(len(evaluation_subsets)), + desc="Evaluating set functions...", position=0, leave=True, disable=not config.show_progress_bars or not show_progressbar,