Skip to content

Commit

Permalink
Update, fix and expand tests
Browse files Browse the repository at this point in the history
This commit updates old tests, fix broken ones, and adds new ones as well.

NOTE: test_acceptance.py is commented out since it takes a while to run. However,
it is arguably the most important test of all.
  • Loading branch information
fmfn committed Nov 25, 2018
1 parent 09158d7 commit 39d1d42
Show file tree
Hide file tree
Showing 12 changed files with 659 additions and 346 deletions.
15 changes: 9 additions & 6 deletions bayes_opt/bayesian_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,15 @@ def __next__(self):
self._queue = self._queue[1:]
return obj

def next(self):
return self.__next__()

def add(self, obj):
"""Add object to end of queue."""
self._queue.append(obj)


class Observable:
class Observable(object):
"""
Inspired/Taken from
Expand Down Expand Up @@ -147,11 +150,11 @@ def _prime_subscriptions(self):
self.subscribe(Events.OPTMIZATION_END, _logger)

def maximize(self,
init_points: int=5,
n_iter: int=25,
acq: str='ucb',
kappa: float=2.576,
xi: float=0.0,
init_points=5,
n_iter=25,
acq='ucb',
kappa=2.576,
xi=0.0,
**gp_params):
"""Mazimize your function"""
self._prime_subscriptions()
Expand Down
3 changes: 2 additions & 1 deletion bayes_opt/observer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
observers...
"""
from __future__ import print_function
import os
import json
from datetime import datetime
Expand All @@ -14,7 +15,7 @@ def update(self, event, instance):
raise NotImplementedError


class _Tracker:
class _Tracker(object):
def __init__(self):
self._iterations = 0

Expand Down
18 changes: 12 additions & 6 deletions bayes_opt/target_space.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numpy as np
from .util import ensure_rng, unique_rows
from .util import ensure_rng


def _hashable(x):
Expand All @@ -22,7 +22,7 @@ class TargetSpace(object):
>>> y = space.register_point(x)
>>> assert self.max_point()['max_val'] == y
"""
def __init__(self, target_func, pbounds: dict, random_state=None):
def __init__(self, target_func, pbounds, random_state=None):
"""
Parameters
----------
Expand Down Expand Up @@ -112,9 +112,15 @@ def _as_array(self, x):
x = np.asarray(x, dtype=float)
except TypeError:
x = self.params_to_array(x)
finally:
x = x.ravel()
assert x.size == self.dim, 'x must have the same dimensions'

x = x.ravel()
try:
assert x.size == self.dim
except AssertionError:
raise ValueError(
"Size of array ({}) is different than the ".format(len(x)) +
"expected number of parameters ({}).".format(len(self.keys))
)
return x

def register(self, params, target):
Expand Down Expand Up @@ -243,6 +249,6 @@ def set_bounds(self, new_bounds):
new_bounds : dict
A dictionary with the parameter name and its new bounds
"""
for row, key in enumerate(self._keys):
for row, key in enumerate(self.keys):
if key in new_bounds:
self._bounds[row] = new_bounds[key]
48 changes: 24 additions & 24 deletions bayes_opt/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,30 +156,6 @@ def load_logs(optimizer, logs):
return optimizer


def unique_rows(a):
"""
A function to trim repeated rows that may appear when optimizing.
This is necessary to avoid the sklearn GP object from breaking
:param a: array to trim repeated rows from
:return: mask of unique rows
"""
if a.size == 0:
return np.empty((0,))

# Sort array and kep track of where things should go back to
order = np.lexsort(a.T)
reorder = np.argsort(order)

a = a[order]
diff = np.diff(a, axis=0)
ui = np.ones(len(a), 'bool')
ui[1:] = (diff != 0).any(axis=1)

return ui[reorder]


def ensure_rng(random_state=None):
"""
Creates a random number generator based on an optional seed. This can be
Expand Down Expand Up @@ -262,3 +238,27 @@ def underline(cls, s):
def yellow(cls, s):
"""Wrap text in yellow."""
return cls._wrap_colour(s, cls.YELLOW)


# def unique_rows(a):
# """
# A function to trim repeated rows that may appear when optimizing.
# This is necessary to avoid the sklearn GP object from breaking

# :param a: array to trim repeated rows from

# :return: mask of unique rows
# """
# if a.size == 0:
# return np.empty((0,))

# # Sort array and kep track of where things should go back to
# order = np.lexsort(a.T)
# reorder = np.argsort(order)

# a = a[order]
# diff = np.diff(a, axis=0)
# ui = np.ones(len(a), 'bool')
# ui[1:] = (diff != 0).any(axis=1)

# return ui[reorder]
69 changes: 69 additions & 0 deletions tests/test_acceptance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
# import numpy as np

# from bayes_opt import BayesianOptimization
# from bayes_opt.util import ensure_rng


# def test_simple_optimization():
# """
# ...
# """
# def f(x, y):
# return -x ** 2 - (y - 1) ** 2 + 1


# optimizer = BayesianOptimization(
# f=f,
# pbounds={"x": (-3, 3), "y": (-3, 3)},
# random_state=12356,
# verbose=0,
# )

# optimizer.maximize(init_points=0, n_iter=25)

# max_target = optimizer.max["target"]
# max_x = optimizer.max["params"]["x"]
# max_y = optimizer.max["params"]["y"]

# assert (1 - max_target) < 1e-3
# assert np.abs(max_x - 0) < 1e-1
# assert np.abs(max_y - 1) < 1e-1


# def test_intermediate_optimization():
# """
# ...
# """
# def f(x, y, z):
# x_factor = np.exp(-(x - 2) ** 2) + (1 / (x ** 2 + 1))
# y_factor = np.exp(-(y - 6) ** 2 / 10)
# z_factor = (1 + 0.2 * np.cos(z)) / (1 + z ** 2)
# return (x_factor + y_factor) * z_factor

# optimizer = BayesianOptimization(
# f=f,
# pbounds={"x": (-7, 7), "y": (-7, 7), "z": (-7, 7)},
# random_state=56,
# verbose=0,
# )

# optimizer.maximize(init_points=0, n_iter=150)

# max_target = optimizer.max["target"]
# max_x = optimizer.max["params"]["x"]
# max_y = optimizer.max["params"]["y"]
# max_z = optimizer.max["params"]["z"]

# assert (2.640 - max_target) < 0
# assert np.abs(2 - max_x) < 1e-1
# assert np.abs(6 - max_y) < 1e-1
# assert np.abs(0 - max_z) < 1e-1


# if __name__ == '__main__':
# r"""
# CommandLine:
# python tests/test_bayesian_optimization.py
# """
# import pytest
# pytest.main([__file__])
Loading

0 comments on commit 39d1d42

Please sign in to comment.