Skip to content

Commit

Permalink
Merge pull request #272 from python-adaptive/pickle-LearnerND
Browse files Browse the repository at this point in the history
Make LearnerND pickleable
  • Loading branch information
basnijholt authored Sep 10, 2021
2 parents a612a0e + 700bbc8 commit 20e5986
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
7 changes: 4 additions & 3 deletions adaptive/learner/base_learner.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import abc
from contextlib import suppress
from copy import deepcopy

import cloudpickle

from adaptive.utils import _RequireAttrsABCMeta, load, save

Expand Down Expand Up @@ -191,7 +192,7 @@ def load(self, fname, compress=True):
self._set_data(data)

def __getstate__(self):
return deepcopy(self.__dict__)
return cloudpickle.dumps(self.__dict__)

def __setstate__(self, state):
self.__dict__ = state
self.__dict__ = cloudpickle.loads(state)
19 changes: 11 additions & 8 deletions adaptive/learner/learnerND.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import random
from collections import OrderedDict
from collections.abc import Iterable
from copy import deepcopy

import numpy as np
import scipy.spatial
Expand Down Expand Up @@ -319,6 +320,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
else:
self._bounds_points = sorted(list(map(tuple, itertools.product(*bounds))))
self._bbox = tuple(tuple(map(float, b)) for b in bounds)
self._interior = None

self.ndim = len(self._bbox)

Expand All @@ -337,6 +339,7 @@ def __init__(self, func, bounds, loss_per_simplex=None):
# for the output
self._min_value = None
self._max_value = None
self._old_scale = None
self._output_multiplier = (
1 # If we do not know anything, do not scale the values
)
Expand Down Expand Up @@ -453,7 +456,7 @@ def _simplex_exists(self, simplex):

def inside_bounds(self, point):
"""Check whether a point is inside the bounds."""
if hasattr(self, "_interior"):
if self._interior is not None:
return self._interior.find_simplex(point, tol=1e-8) >= 0
else:
eps = 1e-8
Expand Down Expand Up @@ -988,13 +991,6 @@ def plot_3D(self, with_triangulation=False):

return plotly.offline.iplot(fig)

def _get_data(self):
return self.data

def _set_data(self, data):
if data:
self.tell_many(*zip(*data.items()))

def _get_iso(self, level=0.0, which="surface"):
if which == "surface":
if self.ndim != 3 or self.vdim != 1:
Expand Down Expand Up @@ -1182,3 +1178,10 @@ def _get_plane_color(simplex):
opacity=opacity,
lighting=lighting,
)

def _get_data(self):
return deepcopy(self.__dict__)

def _set_data(self, state):
for k, v in state.items():
setattr(self, k, v)
2 changes: 2 additions & 0 deletions adaptive/tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
IntegratorLearner,
Learner1D,
Learner2D,
LearnerND,
SequenceLearner,
)
from adaptive.runner import simple
Expand Down Expand Up @@ -70,6 +71,7 @@ def balancing_learner(f, learner_type, learner_kwargs):
balancing_learner,
dict(learner_type=Learner1D, learner_kwargs=dict(bounds=(-1, 1))),
),
(LearnerND, dict(bounds=((-1, 1), (-1, 1), (-1, 1)))),
]

serializers = [(pickle, pickleable_f)]
Expand Down
5 changes: 3 additions & 2 deletions adaptive/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import contextmanager
from itertools import product

import cloudpickle
from atomicwrites import AtomicWriter


Expand Down Expand Up @@ -46,7 +47,7 @@ def save(fname, data, compress=True):
if dirname:
os.makedirs(dirname, exist_ok=True)

blob = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
blob = cloudpickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
if compress:
blob = gzip.compress(blob)

Expand All @@ -58,7 +59,7 @@ def load(fname, compress=True):
fname = os.path.expanduser(fname)
_open = gzip.open if compress else open
with _open(fname, "rb") as f:
return pickle.load(f)
return cloudpickle.load(f)


def copy_docstring_from(other):
Expand Down

0 comments on commit 20e5986

Please sign in to comment.