From fb83a607a68312befb1be784b35d8b09f43a4482 Mon Sep 17 00:00:00 2001 From: Yuting Xu <12775874+xuyuting@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:42:18 -0400 Subject: [PATCH 1/2] Fix issue of Param_Discrete_Numeric: search_categories init and creation of X_search_t --- obsidian/parameters/discrete.py | 24 ++++++++++++++++-------- obsidian/parameters/param_space.py | 4 ++-- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/obsidian/parameters/discrete.py b/obsidian/parameters/discrete.py index cf4730f..8a6e9b9 100644 --- a/obsidian/parameters/discrete.py +++ b/obsidian/parameters/discrete.py @@ -214,21 +214,29 @@ def range(self): def __init__(self, name, - categories: list[int | float]): + categories: list[int | float], + search_categories: list[int | float] = None): if not isinstance(categories, list): raise TypeError('Categories must be a number or list of numbers') - - self.categories = categories - for c in self.categories: - self._validate_value(c) - + self.name = name self.categories = categories if isinstance(categories, list) else [categories] self.categories.sort() - for c in categories: + for c in self.categories: self._validate_value(c) - + + # Set the search space to the parameter space by default + if not search_categories: + search_categories = self.categories + else: + if not isinstance(search_categories, list): + search_categories = [search_categories] + for c in search_categories: + self._validate_value(c) + search_categories.sort() + self.set_search(search_categories) + def _validate_value(self, value: int | float): """ diff --git a/obsidian/parameters/param_space.py b/obsidian/parameters/param_space.py index e68070b..6fdb0d6 100644 --- a/obsidian/parameters/param_space.py +++ b/obsidian/parameters/param_space.py @@ -242,13 +242,13 @@ def search_space(self) -> pd.DataFrame: X_search_t = pd.concat([X_search_t, cont_bounds], axis=1) # For discrete, encode the available categories, then log the min-max of encoded columns - elif isinstance(param, Param_Discrete) and (not isinstance(param, Param_Ordinal)): + elif isinstance(param, Param_Discrete) and (not isinstance(param, (Param_Ordinal,Param_Discrete_Numeric))): # Discrete parameter bounds aren't actually handled here; they are handled in optimizer._fixed_features()) cat_e = param.encode(param.search_categories) disc_bounds = pd.DataFrame(np.vstack([[0]*cat_e.shape[-1], cat_e.max().values]), columns=cat_e.columns) X_search_t = pd.concat([X_search_t, disc_bounds], axis=1) - elif isinstance(param, Param_Ordinal): + elif isinstance(param, (Param_Ordinal,Param_Discrete_Numeric)): cat_e = param.encode(param.search_categories) cont_bounds = pd.DataFrame([min(cat_e), max(cat_e)], columns=[param.name]) X_search_t = pd.concat([X_search_t, cont_bounds], axis=1) From 22f9b466bae46acb85c11712332a990e0592dcd8 Mon Sep 17 00:00:00 2001 From: Yuting Xu <12775874+xuyuting@users.noreply.github.com> Date: Mon, 23 Sep 2024 17:44:56 -0400 Subject: [PATCH 2/2] fix code style --- obsidian/parameters/param_space.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/obsidian/parameters/param_space.py b/obsidian/parameters/param_space.py index 6fdb0d6..a8e86f0 100644 --- a/obsidian/parameters/param_space.py +++ b/obsidian/parameters/param_space.py @@ -242,13 +242,13 @@ def search_space(self) -> pd.DataFrame: X_search_t = pd.concat([X_search_t, cont_bounds], axis=1) # For discrete, encode the available categories, then log the min-max of encoded columns - elif isinstance(param, Param_Discrete) and (not isinstance(param, (Param_Ordinal,Param_Discrete_Numeric))): + elif isinstance(param, Param_Discrete) and (not isinstance(param, (Param_Ordinal, Param_Discrete_Numeric))): # Discrete parameter bounds aren't actually handled here; they are handled in optimizer._fixed_features()) cat_e = param.encode(param.search_categories) disc_bounds = pd.DataFrame(np.vstack([[0]*cat_e.shape[-1], cat_e.max().values]), columns=cat_e.columns) X_search_t = pd.concat([X_search_t, disc_bounds], axis=1) - elif isinstance(param, (Param_Ordinal,Param_Discrete_Numeric)): + elif isinstance(param, (Param_Ordinal, Param_Discrete_Numeric)): cat_e = param.encode(param.search_categories) cont_bounds = pd.DataFrame([min(cat_e), max(cat_e)], columns=[param.name]) X_search_t = pd.concat([X_search_t, cont_bounds], axis=1)