forked from bayesian-optimization/BayesianOptimization
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added observer pattern to allow callbacks
- Loading branch information
Showing
6 changed files
with
155 additions
and
72 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
from .bayesian_optimization import BayesianOptimization | ||
from .bayesian_optimization import BayesianOptimization, Events | ||
from .helpers import UtilityFunction | ||
from .observer import Observer | ||
|
||
__all__ = ["BayesianOptimization", "UtilityFunction"] | ||
__all__ = ["BayesianOptimization", "UtilityFunction", "Events", "Observer"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# Inspired/Taken from https://www.protechtraining.com/blog/post/879#simple-observer | ||
|
||
|
||
class Observer: | ||
def update(self, event, instance): | ||
# Avoid circular import | ||
from .bayesian_optimization import Events | ||
if event is Events.INIT_DONE: | ||
print("Initialization completed") | ||
elif event is Events.FIT_STEP_DONE: | ||
print("Optimization step finished, current max: ", instance.res['max']) | ||
elif event is Events.FIT_DONE: | ||
print("Optimization finished, maximum value at: ", instance.res['max']) | ||
|
||
|
||
class Observable(object): | ||
def __init__(self, events): | ||
# maps event names to subscribers | ||
# str -> dict | ||
self.events = {event: dict() | ||
for event in events} | ||
|
||
def get_subscribers(self, event): | ||
return self.events[event] | ||
|
||
def register(self, event, who, callback=None): | ||
if callback == None: | ||
callback = getattr(who, 'update') | ||
self.get_subscribers(event)[who] = callback | ||
|
||
def unregister(self, event, who): | ||
del self.get_subscribers(event)[who] | ||
|
||
def dispatch(self, event): | ||
for subscriber, callback in self.get_subscribers(event).items(): | ||
callback(event, self) |
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import unittest | ||
from bayes_opt.observer import Observable | ||
|
||
|
||
class TestObserver(): | ||
def __init__(self): | ||
self.counter = 0 | ||
|
||
def update(self, event, instance): | ||
self.counter += 1 | ||
|
||
class TestObserverPattern(unittest.TestCase): | ||
def setUp(self): | ||
events = ['a', 'b'] | ||
self.observable = Observable(events) | ||
self.observer = TestObserver() | ||
|
||
def test_get_subscribers(self): | ||
self.observable.register('a', self.observer) | ||
self.assertTrue(self.observer in self.observable.get_subscribers('a')) | ||
self.assertTrue(len(self.observable.get_subscribers('a').keys()) == 1) | ||
self.assertTrue(len(self.observable.get_subscribers('b').keys()) == 0) | ||
|
||
def test_register(self): | ||
self.observable.register('a', self.observer) | ||
self.assertTrue(self.observer in self.observable.get_subscribers('a')) | ||
|
||
def test_unregister(self): | ||
self.observable.register('a', self.observer) | ||
self.observable.unregister('a', self.observer) | ||
self.assertTrue(self.observer not in self.observable.get_subscribers('a')) | ||
|
||
def test_dispatch(self): | ||
test_observer = TestObserver() | ||
self.observable.register('b', test_observer) | ||
self.observable.dispatch('b') | ||
self.observable.dispatch('b') | ||
|
||
self.assertTrue(test_observer.counter == 2) | ||
|
||
|
||
if __name__ == '__main__': | ||
r""" | ||
CommandLine: | ||
python tests/test_observer.py | ||
""" | ||
# unittest.main() | ||
import pytest | ||
pytest.main([__file__]) |