Skip to content

Commit

Permalink
Added observer pattern to allow callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
aprams committed Sep 29, 2018
1 parent 317f735 commit 376f205
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 72 deletions.
5 changes: 3 additions & 2 deletions bayes_opt/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .bayesian_optimization import BayesianOptimization
from .bayesian_optimization import BayesianOptimization, Events
from .helpers import UtilityFunction
from .observer import Observer

__all__ = ["BayesianOptimization", "UtilityFunction"]
__all__ = ["BayesianOptimization", "UtilityFunction", "Events", "Observer"]
22 changes: 21 additions & 1 deletion bayes_opt/bayesian_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from sklearn.gaussian_process.kernels import Matern
from .helpers import (UtilityFunction, PrintLog, acq_max, ensure_rng)
from .target_space import TargetSpace
from .observer import Observable


class BayesianOptimization(object):
class BayesianOptimization(Observable):

def __init__(self, f, pbounds, random_state=None, verbose=1):
"""
Expand Down Expand Up @@ -71,6 +72,10 @@ def __init__(self, f, pbounds, random_state=None, verbose=1):
# Verbose
self.verbose = verbose

# Event initialization
events = [Events.INIT_DONE, Events.FIT_STEP_DONE, Events.FIT_DONE]
super(BayesianOptimization, self).__init__(events)

def init(self, init_points):
"""
Initialization method to kick start the optimization process. It is a
Expand Down Expand Up @@ -100,6 +105,9 @@ def init(self, init_points):
# Updates the flag
self.initialized = True

# Notify about finished init method
self.dispatch(Events.INIT_DONE)

def _observe_point(self, x):
y = self.space.observe_point(x)
if self.verbose:
Expand Down Expand Up @@ -303,10 +311,16 @@ def maximize(self,
# Keep track of total number of iterations
self.i += 1

# Notify about finished iteration
self.dispatch(Events.FIT_STEP_DONE)

# Print a final report if verbose active.
if self.verbose:
self.plog.print_summary()

# Notify about finished optimization
self.dispatch(Events.FIT_DONE)

def points_to_csv(self, file_name):
"""
After training all points for which we know target variable
Expand Down Expand Up @@ -353,3 +367,9 @@ def bounds(self):
def dim(self):
warnings.warn("use self.space.dim instead", DeprecationWarning)
return self.space.dim


class Events(object):
INIT_DONE = 'initialized'
FIT_STEP_DONE = 'fit_step_done'
FIT_DONE = 'fit_done'
36 changes: 36 additions & 0 deletions bayes_opt/observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Inspired/Taken from https://www.protechtraining.com/blog/post/879#simple-observer


class Observer:
def update(self, event, instance):
# Avoid circular import
from .bayesian_optimization import Events
if event is Events.INIT_DONE:
print("Initialization completed")
elif event is Events.FIT_STEP_DONE:
print("Optimization step finished, current max: ", instance.res['max'])
elif event is Events.FIT_DONE:
print("Optimization finished, maximum value at: ", instance.res['max'])


class Observable(object):
def __init__(self, events):
# maps event names to subscribers
# str -> dict
self.events = {event: dict()
for event in events}

def get_subscribers(self, event):
return self.events[event]

def register(self, event, who, callback=None):
if callback == None:
callback = getattr(who, 'update')
self.get_subscribers(event)[who] = callback

def unregister(self, event, who):
del self.get_subscribers(event)[who]

def dispatch(self, event):
for subscriber, callback in self.get_subscribers(event).items():
callback(event, self)
114 changes: 46 additions & 68 deletions examples/exploitation vs exploration.ipynb

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion tests/test_helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from sklearn.gaussian_process.kernels import Matern
from bayes_opt.helpers import UtilityFunction, acq_max, ensure_rng


def get_globals():
X = np.array([
[0.00, 0.00],
Expand Down
49 changes: 49 additions & 0 deletions tests/test_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import unittest
from bayes_opt.observer import Observable


class TestObserver():
def __init__(self):
self.counter = 0

def update(self, event, instance):
self.counter += 1

class TestObserverPattern(unittest.TestCase):
def setUp(self):
events = ['a', 'b']
self.observable = Observable(events)
self.observer = TestObserver()

def test_get_subscribers(self):
self.observable.register('a', self.observer)
self.assertTrue(self.observer in self.observable.get_subscribers('a'))
self.assertTrue(len(self.observable.get_subscribers('a').keys()) == 1)
self.assertTrue(len(self.observable.get_subscribers('b').keys()) == 0)

def test_register(self):
self.observable.register('a', self.observer)
self.assertTrue(self.observer in self.observable.get_subscribers('a'))

def test_unregister(self):
self.observable.register('a', self.observer)
self.observable.unregister('a', self.observer)
self.assertTrue(self.observer not in self.observable.get_subscribers('a'))

def test_dispatch(self):
test_observer = TestObserver()
self.observable.register('b', test_observer)
self.observable.dispatch('b')
self.observable.dispatch('b')

self.assertTrue(test_observer.counter == 2)


if __name__ == '__main__':
r"""
CommandLine:
python tests/test_observer.py
"""
# unittest.main()
import pytest
pytest.main([__file__])

0 comments on commit 376f205

Please sign in to comment.