Skip to content

Commit

Permalink
fix iv balance
Browse files Browse the repository at this point in the history
  • Loading branch information
sdaza committed Jan 5, 2025
1 parent 3e0e3ac commit c50cd75
Showing 1 changed file with 9 additions and 4 deletions.
13 changes: 9 additions & 4 deletions experiment_utils/experiment_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,16 @@ def standardize_covariates(self, data: pd.DataFrame, covariates: List[str]) -> p
data[f"z_{covariate}"] = (data[covariate] - data[covariate].mean()) / data[covariate].std()
return data

def calculate_smd(self, data: pd.DataFrame, covariates: Optional[List[str]] = None, weights_col: str = "weights", threshold: float = 0.1) -> pd.DataFrame:
def calculate_smd(self, data: pd.DataFrame, treatment_col: str = None, covariates: Optional[List[str]] = None, weights_col: str = "weights", threshold: float = 0.1) -> pd.DataFrame:
"""
Calculate standardized mean differences (SMDs) between treatment and control groups.
Parameters
----------
data : DataFrame, optional
DataFrame containing the data to calculate SMDs on. If None, uses the data from the class.
treatment_col : str, optional
Name of the column containing the treatment assignment.
covariates : list, optional
List of column names to calculate SMDs for. If None, uses all numeric and binary covariates.
weights_col : str, optional
Expand All @@ -197,8 +199,11 @@ def calculate_smd(self, data: pd.DataFrame, covariates: Optional[List[str]] = No
DataFrame containing the SMDs and balance flags for each covariate.
"""

treated = data[data[self._treatment_col] == 1]
control = data[data[self._treatment_col] == 0]
if treatment_col is None:
treatment_col = self._treatment_col

treated = data[data[treatment_col] == 1]
control = data[data[treatment_col] == 0]

if covariates is None:
covariates = self._final_covariates
Expand Down Expand Up @@ -398,7 +403,7 @@ def get_effects(self, min_binary_count: int = 100, adjustment: Optional[str] = N
if self._instrument_col is None:
log_and_raise_error(self._logger, "Instrument column is required for IV estimation!")
iv_balance = self.calculate_smd(
data=temp_pd, covariates=final_covariates
data=temp_pd, treatment_col=self._instrument_col, covariates=final_covariates
)
self._logger.info('::::: IV Balance: %.2f', np.round(iv_balance["balance_flag"].mean(), 2))

Expand Down

0 comments on commit c50cd75

Please sign in to comment.