Skip to content

Commit

Permalink
Dev (#28)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbussemaker authored Jan 6, 2025
2 parents 3bd3889 + df35d57 commit 0a3c2d9
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/tests_slow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ on:
branches: [ "main" ]
pull_request:
branches: [ "main", "dev" ]
schedule:
- cron: "21 3 * * 1"

jobs:
test:
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,13 @@ The library provides:

First, create a conda environment (skip if you already have one):
```
conda create --name opt python=3.9
conda create --name opt python=3.11
conda activate opt
```

Then install the package:
```
conda install numpy
conda install "numpy<2.0"
pip install sb-arch-opt
```

Expand Down
4 changes: 2 additions & 2 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,13 @@ The library provides:

First, create a conda environment (skip if you already have one):
```
conda create --name opt python=3.9
conda create --name opt python=3.11
conda activate opt
```

Then install the package:
```
conda install numpy
conda install "numpy<2.0"
pip install sb-arch-opt
```

Expand Down
2 changes: 1 addition & 1 deletion sb_arch_opt/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.5.3'
__version__ = '1.5.4'
25 changes: 20 additions & 5 deletions sb_arch_opt/algo/arch_sbo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,14 @@
from smt.surrogate_models.krg import KRG, KrgBased
from smt.surrogate_models.kpls import KPLS
from smt.surrogate_models.krg_based import MixIntKernelType, MixHrcKernelType
from smt.surrogate_models.rbf import RBF

from smt.utils.design_space import BaseDesignSpace
import smt.utils.design_space as ds
try:
from smt.utils.design_space import BaseDesignSpace
import smt.utils.design_space as ds
except ModuleNotFoundError: # Module moved from SMT 2.8
from smt.design_space import BaseDesignSpace
import smt.design_space as ds

from smt import __version__

Expand All @@ -59,8 +64,11 @@ class BaseDesignSpace:
class SurrogateModel:
pass

class RBF:
pass

__all__ = ['check_dependencies', 'HAS_ARCH_SBO', 'HAS_SMT', 'ModelFactory', 'MixedDiscreteNormalization', 'SBArchOptDesignSpace',
'MultiSurrogateModel']
'MultiSurrogateModel', 'FixedRBF']


def check_dependencies():
Expand Down Expand Up @@ -150,8 +158,7 @@ def get_md_normalization(self):
@staticmethod
def get_rbf_model():
check_dependencies()
from smt.surrogate_models.rbf import RBF
return RBF(print_global=False, d0=1., poly_degree=-1, reg=1e-10)
return FixedRBF(print_global=False, d0=1., poly_degree=-1, reg=1e-10)

@staticmethod
def get_kriging_model(multi=True, kpls_n_comp: int = None, **kwargs):
Expand Down Expand Up @@ -425,3 +432,11 @@ def predict_variances(self, x: np.ndarray, is_acting=None) -> np.ndarray:

def _predict_values(self, x: np.ndarray, is_acting=None) -> np.ndarray:
raise RuntimeError


class FixedRBF(RBF):
"""RBF model that can be deep-copied or pickled before initialization"""

def __setstate__(self, state):
"""Override to remove the call to _setup"""
self.__dict__.update(state)
2 changes: 1 addition & 1 deletion sb_arch_opt/tests/algo/test_arch_sbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def _get_xy_train(self, x: np.ndarray, y: np.ndarray) -> Tuple[np.ndarray, np.nd
@check_dependency()
def test_invalid_training_set(problem: ArchOptProblemBase):
assert HAS_SMT
sbo = FailedXYRemovingSBO(RBF(print_global=False), FunctionEstimateInfill(), pop_size=100, termination=100,
sbo = FailedXYRemovingSBO(FixedRBF(print_global=False), FunctionEstimateInfill(), pop_size=100, termination=100,
repair=ArchOptRepair()).algorithm(infill_size=1, init_size=10)
sbo.setup(problem)

Expand Down

0 comments on commit 0a3c2d9

Please sign in to comment.