Skip to content

Commit

Permalink
fix polynomial
Browse files Browse the repository at this point in the history
  • Loading branch information
sdaza committed Jan 2, 2025
1 parent b7484eb commit 0ecbd17
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
8 changes: 4 additions & 4 deletions experiment_utils/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ class Estimators:
def __init__(self, treatment_col: str, instrument_col: Optional[str] = None,
target_ipw_effect: str = 'ATT', alpha: float = 0.05,
min_ps_score: float = 0.05, max_ps_score: float = 0.95,
interaction_logistic_ipw: bool = False) -> None:
polynomial_ipw: bool = False) -> None:

self.logger = get_logger('Estimators')
self.treatment_col = treatment_col
Expand All @@ -28,7 +28,7 @@ def __init__(self, treatment_col: str, instrument_col: Optional[str] = None,
self.alpha = alpha
self.max_ps_score = max_ps_score
self.min_ps_score = min_ps_score
self.interaction_logistic_ipw = interaction_logistic_ipw
self.polynomial_ipw = polynomial_ipw

def __create_formula(self, outcome_variable: str, covariates: Optional[List[str]], model_type: str = 'regression') -> str:
"""
Expand Down Expand Up @@ -224,8 +224,8 @@ def ipw_logistic(self, data: pd.DataFrame, covariates: List[str], penalty: str =

logistic_model = LogisticRegression(penalty=penalty, C=C, max_iter=max_iter)

if self.interaction_logistic_ipw:
poly = PolynomialFeatures(interaction_only=True, include_bias=False)
if self.polynomial_ipw:
poly = PolynomialFeatures()
X = poly.fit_transform(data[covariates])
feature_names = poly.get_feature_names_out(covariates)
X = pd.DataFrame(X, columns=feature_names)
Expand Down
4 changes: 2 additions & 2 deletions experiment_utils/experiment_analyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
propensity_score_method: str = 'logistic',
min_ps_score: float = 0.05,
max_ps_score: float = 0.95,
interaction_logistic_ipw: bool = True,
polynomial_ipw: bool = True,
instrument_col: Optional[str] = None,
alpha: float = 0.05,
regression_covariates: Optional[List[str]] = None,
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
"""

super().__init__(treatment_col, instrument_col, target_ipw_effect,
alpha, min_ps_score, max_ps_score, interaction_logistic_ipw)
alpha, min_ps_score, max_ps_score, polynomial_ipw)

self.logger = get_logger('Experiment Analyzer')
self.data = self.__ensure_spark_df(data)
Expand Down

0 comments on commit 0ecbd17

Please sign in to comment.