-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1286 from pyiron/executable_container
Implement ExecutableJobContainer
- Loading branch information
Showing
6 changed files
with
319 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
import cloudpickle | ||
import numpy as np | ||
from pyiron_base.jobs.job.template import TemplateJob | ||
|
||
|
||
class ExecutableContainerJob(TemplateJob): | ||
""" | ||
The ExecutableContainerJob is designed to wrap any kind of external executable into a pyiron job object by providing | ||
a write_input(input_dict, working_directory) and a collect_output(working_directory) function. | ||
Example: | ||
>>> def write_input(input_dict, working_directory="."): | ||
>>> with open(os.path.join(working_directory, "input_file"), "w") as f: | ||
>>> f.write(str(input_dict["energy"])) | ||
>>> | ||
>>> | ||
>>> def collect_output(working_directory="."): | ||
>>> with open(os.path.join(working_directory, "output_file"), "r") as f: | ||
>>> return {"energy": float(f.readline())} | ||
>>> | ||
>>> | ||
>>> from pyiron_base import Project | ||
>>> pr = Project("test") | ||
>>> pr.create_job_class( | ||
>>> class_name="CatJob", | ||
>>> write_input_funct=write_input, | ||
>>> collect_output_funct=collect_output, | ||
>>> default_input_dict={"energy": 1.0}, | ||
>>> executable_str="cat input_file > output_file", | ||
>>> ) | ||
>>> job = self.project.create.job.CatJob(job_name="job_test") | ||
>>> job.input["energy"] = 2.0 | ||
>>> job.run() | ||
>>> job.output | ||
""" | ||
|
||
def __init__(self, project, job_name): | ||
super().__init__(project, job_name) | ||
self._write_input_funct = None | ||
self._collect_output_funct = None | ||
|
||
def set_job_type( | ||
self, | ||
write_input_funct, | ||
executable_str, | ||
collect_output_funct, | ||
default_input_dict=None, | ||
): | ||
""" | ||
Set the pre-defined write_input() and collect_output() function plus a dictionary of default inputs and an | ||
executable string. | ||
Args: | ||
write_input_funct (callable): The write input function write_input(input_dict, working_directory) | ||
executable_str (str): Call to an external executable | ||
collect_output_funct (callable): The collect output function collect_output(working_directory) | ||
default_input_dict (dict/None): Default input for the newly created job class | ||
Returns: | ||
callable: Function which requires a project and a job_name as input and returns a job object | ||
""" | ||
self._write_input_funct = write_input_funct | ||
self.executable = executable_str | ||
self._collect_output_funct = collect_output_funct | ||
if default_input_dict is not None: | ||
self.input.update(default_input_dict) | ||
|
||
def write_input(self): | ||
self._write_input_funct( | ||
input_dict=self.input.to_builtin(), working_directory=self.working_directory | ||
) | ||
|
||
def collect_output(self): | ||
self.output.update( | ||
self._collect_output_funct(working_directory=self.working_directory) | ||
) | ||
self.to_hdf() | ||
|
||
def to_hdf(self, hdf=None, group_name=None): | ||
super().to_hdf(hdf=hdf, group_name=group_name) | ||
if self._write_input_funct is not None: | ||
self.project_hdf5["write_input_function"] = np.void( | ||
cloudpickle.dumps(self._write_input_funct) | ||
) | ||
if self._collect_output_funct is not None: | ||
self.project_hdf5["collect_output_function"] = np.void( | ||
cloudpickle.dumps(self._collect_output_funct) | ||
) | ||
|
||
def from_hdf(self, hdf=None, group_name=None): | ||
super().from_hdf(hdf=hdf, group_name=group_name) | ||
self._write_input_funct = cloudpickle.loads( | ||
self.project_hdf5["write_input_function"] | ||
) | ||
self._collect_output_funct = cloudpickle.loads( | ||
self.project_hdf5["collect_output_function"] | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from pyiron_base.utils.instance import static_isinstance | ||
|
||
|
||
def create_job_factory( | ||
write_input_funct, | ||
executable_str, | ||
collect_output_funct, | ||
default_input_dict=None, | ||
): | ||
""" | ||
Create a new job class based on pre-defined write_input() and collect_output() function plus a dictionary of | ||
default inputs and an executable string. | ||
Args: | ||
write_input_funct (callable): The write input function write_input(input_dict, working_directory) | ||
executable_str (str): Call to an external executable | ||
collect_output_funct (callable): The collect output function collect_output(working_directory) | ||
default_input_dict (dict/None): Default input for the newly created job class | ||
Example: | ||
>>> def write_input(input_dict, working_directory="."): | ||
>>> with open(os.path.join(working_directory, "input_file"), "w") as f: | ||
>>> f.write(str(input_dict["energy"])) | ||
>>> | ||
>>> | ||
>>> def collect_output(working_directory="."): | ||
>>> with open(os.path.join(working_directory, "output_file"), "r") as f: | ||
>>> return {"energy": float(f.readline())} | ||
>>> | ||
>>> | ||
>>> from pyiron_base import Project, create_job_factory | ||
>>> pr = Project("test") | ||
>>> create_catjob = create_job_factory( | ||
>>> write_input_funct=write_input, | ||
>>> collect_output_funct=collect_output, | ||
>>> default_input_dict={"energy": 1.0}, | ||
>>> executable_str="cat input_file > output_file", | ||
>>> ) | ||
>>> job = create_catjob(project=pr, job_name="job_test") | ||
>>> job.input["energy"] = 2.0 | ||
>>> job.run() | ||
>>> job.output | ||
""" | ||
|
||
def job_factory(project, job_name): | ||
""" | ||
Create a job based on the previously defined write_input(), collect_output() and the executable string. | ||
Args: | ||
project (ProjectHDFio/ Project): ProjectHDFio instance which points to the HDF5 file the job is stored in | ||
job_name (str): name of the job, which has to be unique within the project | ||
Returns: | ||
pyiron_base.jobs.flex.executablecontainer.ExecutableContainerJob: pyiron job object | ||
""" | ||
if static_isinstance(project, "pyiron_base.project.generic.Project"): | ||
job = project.create.job.ExecutableContainerJob(job_name=job_name) | ||
elif static_isinstance(project, "pyiron_base.storage.hdfio.ProjectHDFio"): | ||
job = project.project.create.job.ExecutableContainerJob(job_name=job_name) | ||
else: | ||
raise TypeError( | ||
"Expected ProjectHDFio/ Project but recieved", type(project) | ||
) | ||
job.set_job_type( | ||
write_input_funct=write_input_funct, | ||
collect_output_funct=collect_output_funct, | ||
default_input_dict=default_input_dict, | ||
executable_str=executable_str, | ||
) | ||
return job | ||
|
||
return job_factory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import os | ||
from pyiron_base._tests import TestWithProject | ||
from pyiron_base.jobs.job.jobtype import JOB_CLASS_DICT | ||
from pyiron_base import create_job_factory | ||
from pyiron_base.storage.hdfio import ProjectHDFio | ||
|
||
|
||
def write_input(input_dict, working_directory="."): | ||
with open(os.path.join(working_directory, "input_file"), "w") as f: | ||
f.write(str(input_dict["energy"])) | ||
|
||
|
||
def collect_output(working_directory="."): | ||
with open(os.path.join(working_directory, "output_file"), "r") as f: | ||
return {"energy": float(f.readline())} | ||
|
||
|
||
class TestExecutableContainer(TestWithProject): | ||
def test_create_job_class(self): | ||
energy_value = 2.0 | ||
self.project.create_job_class( | ||
class_name="CatJob", | ||
write_input_funct=write_input, | ||
collect_output_funct=collect_output, | ||
default_input_dict={"energy": 1.0}, | ||
executable_str="cat input_file > output_file", | ||
) | ||
job = self.project.create.job.CatJob(job_name="job_test") | ||
job.input["energy"] = energy_value | ||
job.run() | ||
self.assertEqual(job.output["energy"], energy_value) | ||
job_reload = self.project.load(job.job_name) | ||
self.assertEqual(job_reload.input["energy"], energy_value) | ||
self.assertEqual(job_reload.output["energy"], energy_value) | ||
del JOB_CLASS_DICT["CatJob"] | ||
|
||
def test_create_job_factory_with_project(self): | ||
energy_value = 2.0 | ||
create_catjob = create_job_factory( | ||
write_input_funct=write_input, | ||
collect_output_funct=collect_output, | ||
default_input_dict={"energy": 1.0}, | ||
executable_str="cat input_file > output_file", | ||
) | ||
job = create_catjob(project=self.project, job_name="job_test") | ||
job.input["energy"] = energy_value | ||
job.run() | ||
self.assertEqual(job.output["energy"], energy_value) | ||
job_reload = self.project.load(job.job_name) | ||
self.assertEqual(job_reload.input["energy"], energy_value) | ||
self.assertEqual(job_reload.output["energy"], energy_value) | ||
|
||
def test_create_job_factory_with_projecthdfio(self): | ||
energy_value = 2.0 | ||
create_catjob = create_job_factory( | ||
write_input_funct=write_input, | ||
collect_output_funct=collect_output, | ||
default_input_dict={"energy": 1.0}, | ||
executable_str="cat input_file > output_file", | ||
) | ||
job = create_catjob( | ||
project=ProjectHDFio(project=self.project, file_name="any.h5", h5_path=None, mode=None), | ||
job_name="job_test" | ||
) | ||
job.input["energy"] = energy_value | ||
job.run() | ||
self.assertEqual(job.output["energy"], energy_value) | ||
job_reload = self.project.load(job.job_name) | ||
self.assertEqual(job_reload.input["energy"], energy_value) | ||
self.assertEqual(job_reload.output["energy"], energy_value) | ||
|
||
def test_create_job_factory_typeerror(self): | ||
create_catjob = create_job_factory( | ||
write_input_funct=write_input, | ||
collect_output_funct=collect_output, | ||
executable_str="cat input_file > output_file", | ||
) | ||
with self.assertRaises(TypeError): | ||
create_catjob( | ||
project="project", | ||
job_name="job_test" | ||
) |