diff --git a/obsidian/parameters/discrete.py b/obsidian/parameters/discrete.py index cf4730f..8a6e9b9 100644 --- a/obsidian/parameters/discrete.py +++ b/obsidian/parameters/discrete.py @@ -214,21 +214,29 @@ def range(self): def __init__(self, name, - categories: list[int | float]): + categories: list[int | float], + search_categories: list[int | float] = None): if not isinstance(categories, list): raise TypeError('Categories must be a number or list of numbers') - - self.categories = categories - for c in self.categories: - self._validate_value(c) - + self.name = name self.categories = categories if isinstance(categories, list) else [categories] self.categories.sort() - for c in categories: + for c in self.categories: self._validate_value(c) - + + # Set the search space to the parameter space by default + if not search_categories: + search_categories = self.categories + else: + if not isinstance(search_categories, list): + search_categories = [search_categories] + for c in search_categories: + self._validate_value(c) + search_categories.sort() + self.set_search(search_categories) + def _validate_value(self, value: int | float): """ diff --git a/obsidian/parameters/param_space.py b/obsidian/parameters/param_space.py index e68070b..a8e86f0 100644 --- a/obsidian/parameters/param_space.py +++ b/obsidian/parameters/param_space.py @@ -242,13 +242,13 @@ def search_space(self) -> pd.DataFrame: X_search_t = pd.concat([X_search_t, cont_bounds], axis=1) # For discrete, encode the available categories, then log the min-max of encoded columns - elif isinstance(param, Param_Discrete) and (not isinstance(param, Param_Ordinal)): + elif isinstance(param, Param_Discrete) and (not isinstance(param, (Param_Ordinal, Param_Discrete_Numeric))): # Discrete parameter bounds aren't actually handled here; they are handled in optimizer._fixed_features()) cat_e = param.encode(param.search_categories) disc_bounds = pd.DataFrame(np.vstack([[0]*cat_e.shape[-1], cat_e.max().values]), columns=cat_e.columns) X_search_t = pd.concat([X_search_t, disc_bounds], axis=1) - elif isinstance(param, Param_Ordinal): + elif isinstance(param, (Param_Ordinal, Param_Discrete_Numeric)): cat_e = param.encode(param.search_categories) cont_bounds = pd.DataFrame([min(cat_e), max(cat_e)], columns=[param.name]) X_search_t = pd.concat([X_search_t, cont_bounds], axis=1)