Skip to content

Commit

Permalink
TrainingContainer: Add train_test_split method
Browse files Browse the repository at this point in the history
Calls the corresponding method on TrainingStorage.
  • Loading branch information
pmrv committed Jun 25, 2024
1 parent 03a4912 commit 9c07e02
Showing 1 changed file with 41 additions and 0 deletions.
41 changes: 41 additions & 0 deletions pyiron_potentialfit/atomistics/job/trainingcontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,47 @@ def iter(self, *arrays, wrap_atoms=True):
"""
yield from self._container.iter(*arrays, wrap_atoms=wrap_atoms)

def train_test_split(self, train_size: float, seed=None,
                     run: bool = True,
                     train_name: str = None, test_name: str = None, project=None):
    """
    Split into two random sub sets for training and testing.

    The split itself is delegated to :meth:`.TrainingStorage.train_test_split` on the underlying storage; each of
    the two resulting storages is then included in a freshly created container.

    Args:
        train_size (float): fraction of data points for the training set, must be within (0, 1)
        seed (optional): how to initialize the RNG, see numpy.random.default_rng() for details, but an int will do
        run (bool, optional): whether to immediately run and save the new containers
        train_name (str, optional): default is the name of this container with suffix _train
        test_name (str, optional): default is the name of this container with suffix _test
        project (Project, optional): where to create the new containers; defaults to this project

    Returns:
        (TrainingContainer, TrainingContainer): two training containers for training and testing

    Raises:
        ValueError: if `train_name` and `test_name` are identical.
        ValueError: if either `train_name` or `test_name` already exist in `project`.
        ValueError: from :meth:`.TrainingStorage.train_test_split` if `train_size` cannot be realized
    """
    if project is None:
        project = self.project
    if train_name is None:
        train_name = f"{self.name}_train"
    if test_name is None:
        test_name = f"{self.name}_test"

    # Validate the target names *before* splitting, so that a bad call fails without doing any work.
    # Identical names would collapse in the set below and make both containers collide with each other.
    if train_name == test_name:
        raise ValueError(
            f"train_name and test_name must be different, but both are {train_name!r}!"
        )
    existing = {train_name, test_name}.intersection(project.list_nodes())
    if existing:
        raise ValueError(
            f"Target containers already exist! Conflicting names: {sorted(existing)}"
        )

    train, test = self._container.train_test_split(train_size, seed)
    trainc = project.create.job.TrainingContainer(train_name)
    trainc.include_storage(train)
    testc = project.create.job.TrainingContainer(test_name)
    testc.include_storage(test)

    if run:
        trainc.run()
        testc.run()
    return trainc, testc


class TrainingPlots(StructurePlots):
"""
Expand Down

0 comments on commit 9c07e02

Please sign in to comment.