From 899a76f905b28b2f5c0db688e9d480050efbd0a9 Mon Sep 17 00:00:00 2001 From: Xinzijian Liu Date: Mon, 26 Aug 2024 10:42:59 +0800 Subject: [PATCH] Support keywords of pair_style for LmpTemplateTaskGroup (#254) ## Summary by CodeRabbit - **New Features** - Enhanced the input handling to dynamically accommodate additional parameters for the `pair_style` command. - Added new documentation for parameters to improve usability and clarity in the task group configuration. - **Bug Fixes** - Improved the handling of varying input scenarios by refining the functionality of the input model revision process. --------- Signed-off-by: zjgemi Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../task/lmp_template_task_group.py | 23 +++++++++++++++---- .../task/make_task_group_from_config.py | 10 +++++++- 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/dpgen2/exploration/task/lmp_template_task_group.py b/dpgen2/exploration/task/lmp_template_task_group.py index aeecde92..b82e1695 100644 --- a/dpgen2/exploration/task/lmp_template_task_group.py +++ b/dpgen2/exploration/task/lmp_template_task_group.py @@ -43,14 +43,19 @@ def set_lmp( plm_template_fname: Optional[str] = None, revisions: dict = {}, traj_freq: int = 10, + extra_pair_style_args: str = "", ) -> None: self.lmp_template = Path(lmp_template_fname).read_text().split("\n") self.revisions = revisions self.traj_freq = traj_freq + self.extra_pair_style_args = extra_pair_style_args self.lmp_set = True self.model_list = sorted([model_name_pattern % ii for ii in range(numb_models)]) self.lmp_template = revise_lmp_input_model( - self.lmp_template, self.model_list, self.traj_freq + self.lmp_template, + self.model_list, + self.traj_freq, + self.extra_pair_style_args, ) self.lmp_template = revise_lmp_input_dump(self.lmp_template, self.traj_freq) if plm_template_fname is not None: @@ -138,12 +143,20 @@ def find_only_one_key(lmp_lines, key): return found[0] -def revise_lmp_input_model(lmp_lines, task_model_list, trj_freq, deepmd_version="1"): +def revise_lmp_input_model( + lmp_lines, task_model_list, trj_freq, extra_pair_style_args="", deepmd_version="1" +): idx = find_only_one_key(lmp_lines, ["pair_style", "deepmd"]) + if extra_pair_style_args: + extra_pair_style_args = " " + extra_pair_style_args graph_list = " ".join(task_model_list) - lmp_lines[idx] = "pair_style deepmd %s out_freq %d out_file model_devi.out" % ( - graph_list, - trj_freq, + lmp_lines[idx] = ( + "pair_style deepmd %s out_freq %d out_file model_devi.out%s" + % ( + graph_list, + trj_freq, + extra_pair_style_args, + ) ) return lmp_lines diff --git a/dpgen2/exploration/task/make_task_group_from_config.py b/dpgen2/exploration/task/make_task_group_from_config.py index 37b7f8b4..c467fd8e 100644 --- a/dpgen2/exploration/task/make_task_group_from_config.py +++ b/dpgen2/exploration/task/make_task_group_from_config.py @@ -116,6 +116,7 @@ def lmp_template_task_group_args(): doc_plm_template_fname = "The file name of plumed input template" doc_revisions = "The revisions. Should be a dict providing the key - list of desired values pair. Key is the word to be replaced in the templates, and it may appear in both the lammps and plumed input templates. All values in the value list will be enmerated." doc_traj_freq = "The frequency of dumping configurations and thermodynamic states" + doc_extra_pair_style_args = "The extra arguments for pair_style" return [ Argument("conf_idx", list, optional=False, doc=doc_conf_idx, alias=["sys_idx"]), @@ -141,7 +142,7 @@ def lmp_template_task_group_args(): doc=doc_plm_template_fname, alias=["plm_template", "plm"], ), - Argument("revisions", dict, optional=True, default={}), + Argument("revisions", dict, optional=True, default={}, doc=doc_revisions), Argument( "traj_freq", int, @@ -150,6 +151,13 @@ def lmp_template_task_group_args(): doc=doc_traj_freq, alias=["t_freq", "trj_freq", "trj_freq"], ), + Argument( + "extra_pair_style_args", + str, + optional=True, + default="", + doc=doc_extra_pair_style_args, + ), ]