From 9c07e026806a64c9d5e3c2c2d6a7eda888cbbbda Mon Sep 17 00:00:00 2001 From: Marvin Poul Date: Mon, 24 Jun 2024 20:38:55 +0200 Subject: [PATCH] TrainingContainer: Add train_test_split method Calls the corresponding method on TrainingStorage. --- .../atomistics/job/trainingcontainer.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/pyiron_potentialfit/atomistics/job/trainingcontainer.py b/pyiron_potentialfit/atomistics/job/trainingcontainer.py index 5913c323..399fb607 100644 --- a/pyiron_potentialfit/atomistics/job/trainingcontainer.py +++ b/pyiron_potentialfit/atomistics/job/trainingcontainer.py @@ -624,6 +624,47 @@ def iter(self, *arrays, wrap_atoms=True): """ yield from self._container.iter(*arrays, wrap_atoms=wrap_atoms) + def train_test_split(self, train_size: float, seed=None, + run: bool = True, + train_name: str = None, test_name: str = None, project=None): + """ + Split into two random subsets for training and testing. + + Args: + train_size (float): fraction of data points for the training set, must be within (0, 1) + seed (optional): how to initialize the RNG, see numpy.random.default_rng() for details, but an int will do + run (bool, optional): whether to immediately run and save the new containers + train_name (str, optional): default is the name of this container with suffix _train + test_name (str, optional): default is the name of this container with suffix _test + project (Project, optional): where to create the new containers; defaults to this project + + Returns: + (TrainingContainer, TrainingContainer): two training containers for training and testing + + Raises: + ValueError: if either `train_name` or `test_name` already exists in `project`. 
+ ValueError: from :meth:`.TrainingStorage.train_test_split` if `train_size` cannot be realized + """ + if project is None: + project = self.project + if train_name is None: + train_name = f"{self.name}_train" + if test_name is None: + test_name = f"{self.name}_test" + if len({train_name, test_name}.intersection(project.list_nodes())) > 0: + raise ValueError("Target containers already exist!") + + train, test = self._container.train_test_split(train_size, seed) + trainc = project.create.job.TrainingContainer(train_name) + trainc.include_storage(train) + testc = project.create.job.TrainingContainer(test_name) + testc.include_storage(test) + + if run: + trainc.run() + testc.run() + return trainc, testc + class TrainingPlots(StructurePlots): """