Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
sdaza committed Dec 10, 2024
1 parent 185b11b commit 9a1a290
Showing 1 changed file with 16 additions and 20 deletions.
36 changes: 16 additions & 20 deletions experiment_utils/experiment_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -458,7 +458,7 @@ def get_effects(self, min_binary_count=100, adjustment=None):

key_experiments = self.data.select(*self.experiment_identifier).distinct().collect()

results = []
temp_results = []

if adjustment is None:
adjustment = self.adjustment
Expand Down Expand Up @@ -486,7 +486,7 @@ def get_effects(self, min_binary_count=100, adjustment=None):

treatvalues = set(temp_pd[self.treatment_col].unique())
if len(treatvalues) != 2:
self.logger.warning('Skipping as it is not a valid treatment-control group!')
self.logger.warning('Skipping as there are no valid treatment-control groups!')
continue
if not (0 in treatvalues and 1 in treatvalues):
log_and_raise_error(self.logger, f'The treatment column {self.treatment_col} must be 0 and 1')
Expand Down Expand Up @@ -517,10 +517,7 @@ def get_effects(self, min_binary_count=100, adjustment=None):

if len(final_covariates) > 0:
temp_pd["weights"] = 1
temp_pd = self.standardize_covariates(
temp_pd, final_covariates
)

temp_pd = self.standardize_covariates(temp_pd, final_covariates)
balance = self.calculate_smd(
data=temp_pd, covariates=final_covariates
)
Expand All @@ -529,7 +526,6 @@ def get_effects(self, min_binary_count=100, adjustment=None):
self._balance.append(balance)
self.logger.info('::::: Balance: %.2f', np.round(balance["balance_flag"].mean(), 2))
if adjustment == "IPW":
temp_pd = self.standardize_covariates(temp_pd, final_covariates)
temp_pd = self.estimate_ipw(
data=temp_pd,
covariates=[f"z_{cov}" for cov in final_covariates],
Expand Down Expand Up @@ -561,11 +557,11 @@ def get_effects(self, min_binary_count=100, adjustment=None):
output['adjustment'] = 'No adjustment' if adjustment is None else adjustment
if adjustment == 'IPW':
output['balance'] = np.round(adjusted_balance['balance_flag'].mean(), 2)
elif len(final_covariates) > 0:
elif (len(final_covariates) > 0):
output['balance'] = np.round(balance['balance_flag'].mean(), 2)
output['experiment'] = experiment_tuple

results.append(output)
temp_results.append(output)

result_columns = ['experiment', 'outcome', 'adjustment',
'treated_units', 'control_units', 'control_value',
Expand All @@ -577,15 +573,15 @@ def get_effects(self, min_binary_count=100, adjustment=None):
index_to_insert = result_columns.index('adjustment') + 1
result_columns.insert(index_to_insert, 'balance')

clean_results = pd.DataFrame(results)
clean_results = clean_results[result_columns]
clean_temp_results = pd.DataFrame(temp_results)
clean_temp_results = clean_temp_results[result_columns]

if len(self._balance) > 0:
self._balance = pd.concat(self._balance)
if len(self._adjusted_balance) > 0:
self._adjusted_balance = pd.concat(self._adjusted_balance)

self._results = self.__transform_tuple_column(clean_results, 'experiment', self.experiment_identifier)
self._results = self.__transform_tuple_column(clean_temp_results, 'experiment', self.experiment_identifier)

def combine_effects(self, data: pd.DataFrame = None, grouping_cols: List = None):
"""
Expand Down Expand Up @@ -647,7 +643,7 @@ def __get_fixed_meta_analysis_estimate(self, data):
except FloatingPointError:
pvalue = np.nan

results = {
meta_results = {
'experiments': int(data.shape[0]),
'treated_units': int(data['treated_units'].sum()),
'control_units': int(data['control_units'].sum()),
Expand All @@ -658,9 +654,9 @@ def __get_fixed_meta_analysis_estimate(self, data):
}

if 'balance' in data.columns:
results['balance'] = data['balance'].mean()
results['stat_significance'] = 1 if results['pvalue'] < self.alpha else 0
return results
meta_results['balance'] = data['balance'].mean()
meta_results['stat_significance'] = 1 if meta_results['pvalue'] < self.alpha else 0
return meta_results

def aggregate_effects(self, data: pd.DataFrame = None, grouping_cols: List = None):
"""
Expand Down Expand Up @@ -689,17 +685,17 @@ def aggregate_effects(self, data: pd.DataFrame = None, grouping_cols: List = Non
if 'outcome' not in grouping_cols:
grouping_cols.append('outcome')

results = data.groupby(grouping_cols).apply(self.__compute_weighted_effect).reset_index()
aggregate_results = data.groupby(grouping_cols).apply(self.__compute_weighted_effect).reset_index()

self.logger.info('Aggregating effects using weighted averages!')
self.logger.info('For a better standard error estimation, use meta-analysis or `combine_effects`')

# keep initial order
result_columns = grouping_cols + ['experiments', 'balance']
existing_columns = [col for col in result_columns if col in results.columns]
remaining_columns = [col for col in results.columns if col not in existing_columns]
existing_columns = [col for col in result_columns if col in aggregate_results.columns]
remaining_columns = [col for col in aggregate_results.columns if col not in existing_columns]
final_columns = existing_columns + remaining_columns
return results[final_columns]
return aggregate_results[final_columns]

def __compute_weighted_effect(self, group):

Expand Down

0 comments on commit 9a1a290

Please sign in to comment.