Skip to content

Commit

Permalink
Fix issue of Param_Discrete_Numeric: search_categories init and creat…
Browse files Browse the repository at this point in the history
…ion of X_search_t
  • Loading branch information
xuyuting committed Sep 23, 2024
1 parent 74b7f32 commit fb83a60
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
24 changes: 16 additions & 8 deletions obsidian/parameters/discrete.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,21 +214,29 @@ def range(self):

def __init__(self,
name,
categories: list[int | float]):
categories: list[int | float],
search_categories: list[int | float] = None):

if not isinstance(categories, list):
raise TypeError('Categories must be a number or list of numbers')

self.categories = categories
for c in self.categories:
self._validate_value(c)


self.name = name
self.categories = categories if isinstance(categories, list) else [categories]
self.categories.sort()
for c in categories:
for c in self.categories:
self._validate_value(c)


# Set the search space to the parameter space by default
if not search_categories:
search_categories = self.categories
else:
if not isinstance(search_categories, list):
search_categories = [search_categories]
for c in search_categories:
self._validate_value(c)
search_categories.sort()
self.set_search(search_categories)

def _validate_value(self,
value: int | float):
"""
Expand Down
4 changes: 2 additions & 2 deletions obsidian/parameters/param_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,13 @@ def search_space(self) -> pd.DataFrame:
X_search_t = pd.concat([X_search_t, cont_bounds], axis=1)

# For discrete, encode the available categories, then log the min-max of encoded columns
elif isinstance(param, Param_Discrete) and (not isinstance(param, Param_Ordinal)):
elif isinstance(param, Param_Discrete) and (not isinstance(param, (Param_Ordinal,Param_Discrete_Numeric))):
# Discrete parameter bounds aren't actually handled here; they are handled in optimizer._fixed_features())
cat_e = param.encode(param.search_categories)
disc_bounds = pd.DataFrame(np.vstack([[0]*cat_e.shape[-1], cat_e.max().values]),
columns=cat_e.columns)
X_search_t = pd.concat([X_search_t, disc_bounds], axis=1)
elif isinstance(param, Param_Ordinal):
elif isinstance(param, (Param_Ordinal,Param_Discrete_Numeric)):
cat_e = param.encode(param.search_categories)
cont_bounds = pd.DataFrame([min(cat_e), max(cat_e)], columns=[param.name])
X_search_t = pd.concat([X_search_t, cont_bounds], axis=1)
Expand Down

0 comments on commit fb83a60

Please sign in to comment.