Skip to content

Commit

Permalink
Ensure neighbor compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
pmrv committed Jun 25, 2024
1 parent 9c07e02 commit aa8edf8
Showing 1 changed file with 21 additions and 4 deletions.
25 changes: 21 additions & 4 deletions pyiron_potentialfit/atomistics/job/trainingcontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,11 @@ def train_test_split(self, train_size, seed=None):
# somewhat inefficient, but probably good enough for normal training set sizes
idx = np.arange(len(self))
rng.shuffle(idx)
test_idx = idx[:brk]
train_idx = idx[brk:]
train_idx = idx[:brk]
test_idx = idx[brk:]
return (
self.sample(lambda f, i: i in train_idx),
self.sample(lambda f, i: i in test_idx),
self.sample(lambda f, i: i in train_idx)
)


Expand Down Expand Up @@ -458,6 +458,15 @@ def include_storage(self, storage: TrainingStorage):
Args:
storage (:class:`.TrainingStorage`): structures to add
"""
# Check whether storage defines neighbor information and whether it is
# compatible with our input
info = storage.has_array("indices")
if info is not None and info["shape"][0] != self.input.num_neighbors:
storage = storage.copy()
storage.del_array("indices")
storage.del_array("distances")
storage.del_array("vecs")
storage.del_array("shells")
self._container.extend(storage)

def _get_structure(self, frame=-1, wrap_atoms=True):
Expand Down Expand Up @@ -656,8 +665,16 @@ def train_test_split(self, train_size: float, seed=None,

train, test = self._container.train_test_split(train_size, seed)
trainc = project.create.job.TrainingContainer(train_name)
trainc.include_storage(train)
testc = project.create.job.TrainingContainer(test_name)

# make sure that the split containers do not try to override any
# neighbor information we may have saved before
trainc.input.save_neighbors = self.input.save_neighbors
trainc.input.num_neighbors = self.input.num_neighbors
testc.input.save_neighbors = self.input.save_neighbors
testc.input.num_neighbors = self.input.num_neighbors

trainc.include_storage(train)
testc.include_storage(test)

if run:
Expand Down

0 comments on commit aa8edf8

Please sign in to comment.