Skip to content

Commit

Permalink
Relax number of points required wrt regression spec (#244)
Browse files Browse the repository at this point in the history
* Relax required nb of points wrt regression spec

* Refactor type to specify the number of clusters

* Linting

* Fix docstring
  • Loading branch information
relf authored Mar 6, 2025
1 parent 1389542 commit df91c07
Show file tree
Hide file tree
Showing 17 changed files with 187 additions and 98 deletions.
4 changes: 2 additions & 2 deletions crates/ego/examples/mopta08.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use clap::Parser;
use egobox_ego::{EgorBuilder, GroupFunc, InfillOptimizer};
use egobox_moe::{CorrelationSpec, RegressionSpec};
use egobox_moe::{CorrelationSpec, NbClusters, RegressionSpec};
use ndarray::{s, Array1, Array2, ArrayView1, ArrayView2};
use std::fs::{remove_file, File};
use std::io::prelude::*;
Expand Down Expand Up @@ -267,7 +267,7 @@ fn main() -> anyhow::Result<()> {
config
.n_cstr(N_CSTR)
.cstr_tol(cstr_tol.clone())
.n_clusters(1)
.n_clusters(NbClusters::fixed(1))
.n_start(50)
.n_doe(n_doe)
.max_iters(max_iters)
Expand Down
3 changes: 2 additions & 1 deletion crates/ego/src/egor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,7 @@ mod tests {
use approx::assert_abs_diff_eq;
use argmin_testfunctions::rosenbrock;
use egobox_doe::{Lhs, SamplingMethod};
use egobox_moe::NbClusters;
use ndarray::{array, s, Array1, Array2, ArrayView2, Ix1, Zip};

use ndarray_npy::read_npy;
Expand Down Expand Up @@ -513,7 +514,7 @@ mod tests {
#[serial]
fn test_xsinx_auto_clustering_egor_builder() {
let res = EgorBuilder::optimize(xsinx)
.configure(|config| config.n_clusters(0).max_iters(20))
.configure(|config| config.n_clusters(NbClusters::auto()).max_iters(20))
.min_within(&array![[0.0, 25.0]])
.run()
.expect("Egor with auto clustering should minimize xsinx");
Expand Down
7 changes: 3 additions & 4 deletions crates/ego/src/gpmix/mixint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use egobox_gp::metrics::CrossValScore;
use egobox_gp::ThetaTuning;
use egobox_moe::{
Clustered, Clustering, CorrelationSpec, FullGpSurrogate, GpMixture, GpMixtureParams,
GpSurrogate, GpSurrogateExt, MixtureGpSurrogate, RegressionSpec,
GpSurrogate, GpSurrogateExt, MixtureGpSurrogate, NbClusters, RegressionSpec,
};
use linfa::traits::{Fit, PredictInplace};
use linfa::{DatasetBase, Float, ParamGuard};
Expand Down Expand Up @@ -358,8 +358,7 @@ impl MixintGpMixtureValidParams {
.surrogate_builder
.clone()
.check()?
.train(&xcast, &yt.to_owned())
.unwrap(),
.train(&xcast, &yt.to_owned())?,
xtypes: self.xtypes.clone(),
work_in_folded_space: self.work_in_folded_space,
training_data: (xt.to_owned(), yt.to_owned()),
Expand Down Expand Up @@ -437,7 +436,7 @@ impl SurrogateBuilder for MixintGpMixtureParams {
}

/// Sets the number of clusters used by the mixture of surrogate experts.
fn set_n_clusters(&mut self, n_clusters: usize) {
fn set_n_clusters(&mut self, n_clusters: NbClusters) {
self.0 = MixintGpMixtureValidParams {
surrogate_builder: self.0.surrogate_builder.clone().n_clusters(n_clusters),
xtypes: self.0.xtypes.clone(),
Expand Down
4 changes: 2 additions & 2 deletions crates/ego/src/gpmix/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub mod spec;

use egobox_gp::ThetaTuning;
use egobox_moe::{
Clustering, CorrelationSpec, GpMixtureParams, MixtureGpSurrogate, RegressionSpec,
Clustering, CorrelationSpec, GpMixtureParams, MixtureGpSurrogate, NbClusters, RegressionSpec,
};
use ndarray::{ArrayView1, ArrayView2};

Expand Down Expand Up @@ -40,7 +40,7 @@ impl SurrogateBuilder for GpMixtureParams<f64> {
}

/// Sets the number of clusters used by the mixture of surrogate experts.
fn set_n_clusters(&mut self, n_clusters: usize) {
fn set_n_clusters(&mut self, n_clusters: NbClusters) {
*self = self.clone().n_clusters(n_clusters);
}

Expand Down
12 changes: 6 additions & 6 deletions crates/ego/src/solver/egor_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::criteria::*;
use crate::types::*;
use crate::HotStartMode;
use egobox_moe::NbClusters;
use egobox_moe::{CorrelationSpec, RegressionSpec};
use ndarray::Array1;
use ndarray::Array2;
Expand Down Expand Up @@ -72,9 +73,9 @@ pub struct EgorConfig {
/// Optional dimension reduction (see [egobox_moe])
pub(crate) kpls_dim: Option<usize>,
/// Number of clusters used by mixture of experts (see [egobox_moe])
/// When set to 0 the clusters are computes automatically and refreshed
/// When set to Auto the clusters are computed automatically and refreshed
/// every 10 point additions (tentative)
pub(crate) n_clusters: usize,
pub(crate) n_clusters: NbClusters,
/// Specification of a target objective value which is used to stop the algorithm once reached
pub(crate) target: f64,
/// Directory to save intermediate results: initial doe + evaluations at each iteration
Expand Down Expand Up @@ -108,7 +109,7 @@ impl Default for EgorConfig {
regression_spec: RegressionSpec::CONSTANT,
correlation_spec: CorrelationSpec::SQUAREDEXPONENTIAL,
kpls_dim: None,
n_clusters: 1,
n_clusters: NbClusters::default(),
target: f64::NEG_INFINITY,
outdir: None,
warm_start: false,
Expand Down Expand Up @@ -239,9 +240,8 @@ impl EgorConfig {

/// Sets the number of clusters used by the mixture of surrogate experts.
///
/// When set to Some(0), the number of clusters is determined automatically
/// When set None, default to 1
pub fn n_clusters(mut self, n_clusters: usize) -> Self {
/// When set to Auto, the number of clusters is determined automatically
pub fn n_clusters(mut self, n_clusters: NbClusters) -> Self {
self.n_clusters = n_clusters;
self
}
Expand Down
5 changes: 3 additions & 2 deletions crates/ego/src/solver/egor_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ where
C: CstrFn,
{
pub fn have_to_recluster(&self, added: usize, prev_added: usize) -> bool {
self.config.n_clusters == 0 && (added != 0 && added % 10 == 0 && added - prev_added > 0)
self.config.n_clusters.is_auto()
&& (added != 0 && added % 10 == 0 && added - prev_added > 0)
}

/// Build surrogate given training data and surrogate builder
Expand All @@ -124,7 +125,7 @@ where
builder.set_kpls_dim(self.config.kpls_dim);
builder.set_regression_spec(self.config.regression_spec);
builder.set_correlation_spec(self.config.correlation_spec);
builder.set_n_clusters(self.config.n_clusters);
builder.set_n_clusters(self.config.n_clusters.clone());

if make_clustering
/* init || recluster */
Expand Down
4 changes: 2 additions & 2 deletions crates/ego/src/types.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::gpmix::spec::*;
use crate::{errors::Result, EgorState};
use argmin::core::CostFunction;
use egobox_moe::{Clustering, MixtureGpSurrogate, ThetaTuning};
use egobox_moe::{Clustering, MixtureGpSurrogate, NbClusters, ThetaTuning};
use linfa::Float;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -140,7 +140,7 @@ pub trait SurrogateBuilder: Clone + Serialize + Sync {
fn set_kpls_dim(&mut self, kpls_dim: Option<usize>);

/// Sets the number of clusters used by the mixture of surrogate experts.
fn set_n_clusters(&mut self, n_clusters: usize);
fn set_n_clusters(&mut self, n_clusters: NbClusters);

/// Sets the hyperparameters tuning strategy
fn set_theta_tunings(&mut self, theta_tunings: &[ThetaTuning<f64>]);
Expand Down
4 changes: 2 additions & 2 deletions crates/moe/examples/clustering.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use egobox_doe::{Lhs, SamplingMethod};
use egobox_moe::{GpMixture, Recombination};
use egobox_moe::{GpMixture, NbClusters, Recombination};
use linfa::prelude::{Dataset, Fit};
use ndarray::{arr2, Array, Array2, Axis, Zip};
use std::error::Error;
Expand All @@ -24,7 +24,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let ds = Dataset::new(xtrain, ytrain.remove_axis(Axis(1)));
let moe1 = GpMixture::params().fit(&ds)?;
let moe3 = GpMixture::params()
.n_clusters(3)
.n_clusters(NbClusters::fixed(3))
.recombination(Recombination::Hard)
.fit(&ds)?;

Expand Down
6 changes: 4 additions & 2 deletions crates/moe/examples/moe_norm1.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use csv::ReaderBuilder;
use egobox_doe::{FullFactorial, SamplingMethod};
use egobox_moe::GpMixture;
use egobox_moe::{GpMixture, NbClusters};
use linfa::{traits::Fit, Dataset};
use ndarray::{arr2, s, Array2, Axis};
use ndarray_csv::Array2Reader;
Expand All @@ -24,7 +24,9 @@ fn main() -> Result<(), Box<dyn Error>> {
let xtrain = data_train.slice(s![.., ..2_usize]).to_owned();
let ytrain = data_train.slice(s![.., 2_usize..]).to_owned();
let ds = Dataset::new(xtrain, ytrain.remove_axis(Axis(1)));
let moe = GpMixture::params().n_clusters(4).fit(&ds)?;
let moe = GpMixture::params()
.n_clusters(NbClusters::fixed(4))
.fit(&ds)?;

let xlimits = arr2(&[[-1., 1.], [-1., 1.]]);
let xtest = FullFactorial::new(&xlimits).sample(100);
Expand Down
4 changes: 2 additions & 2 deletions crates/moe/examples/norm1.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use egobox_doe::{Lhs, SamplingMethod};
use egobox_moe::{GpMixture, Recombination};
use egobox_moe::{GpMixture, NbClusters, Recombination};
use linfa::{traits::Fit, Dataset};
use ndarray::{arr2, Array2, Axis};
use std::error::Error;
Expand All @@ -14,7 +14,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let ds = Dataset::new(xtrain, ytrain.remove_axis(Axis(1)));
let moe1 = GpMixture::params().fit(&ds)?;
let moe5 = GpMixture::params()
.n_clusters(6)
.n_clusters(NbClusters::fixed(6))
.recombination(Recombination::Hard)
.fit(&ds)?;

Expand Down
Loading

0 comments on commit df91c07

Please sign in to comment.