Skip to content

Commit

Permalink
ensuring boolean columns get cast to float for statsmodels
Browse files Browse the repository at this point in the history
  • Loading branch information
jonathan-taylor committed Jan 18, 2024
1 parent 18e95bc commit 3ebd239
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ISLP/models/model_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ def fit(self, X, y=None):
cats = self.encoder_.categories_[0]
column_names = [str(n) for n in cats]


if isinstance(X, pd.DataFrame): # expecting a column, we take .iloc[:,0]
X = X.iloc[:,0]

Expand Down Expand Up @@ -635,18 +634,23 @@ def build_model(column_info,
if isinstance(X, (pd.Series, pd.DataFrame)):
df = pd.concat(dfs, axis=1)
df.index = X.index
return df
else:
return np.column_stack(dfs)
return np.column_stack(dfs).astype(float)
else: # return a 0 design
zero = np.zeros(X.shape[0])
if isinstance(X, (pd.Series, pd.DataFrame)):
df = pd.DataFrame({'zero': zero})
df.index = X.index
return df
else:
return zero

# if we reach here, we will be returning a DataFrame

for col in df.columns:
if df[col].dtype == bool:
df[col] = df[col].astype(float)
return df

def derived_feature(variables, encoder=None, name=None, use_transform=True):
"""
Create a Feature, optionally
Expand Down
23 changes: 23 additions & 0 deletions tests/models/test_boolean_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pandas as pd
import statsmodels.api as sm
import numpy as np
from itertools import combinations

from ISLP.models import ModelSpec as MS

rng = np.random.default_rng(0)

df = pd.DataFrame({'A':rng.standard_normal(10),
'B':np.array([1,2,3,2,1,1,1,3,2,1], int),
'C':np.array([True,False,False,True,True]*2, bool),
'D':rng.standard_normal(10)})
Y = rng.standard_normal(10)

def test_all():

for i in range(1, 5):
for comb in combinations(['A','B','C','D'], i):

X = MS(comb).fit_transform(df)
sm.OLS(Y, X).fit()

0 comments on commit 3ebd239

Please sign in to comment.