Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

argcheck: restrict the type of elements in a list #1364

Merged
merged 3 commits into from
Oct 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 0 additions & 34 deletions dpgen/arginfo.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Callable

from dargs import Argument

from dpgen.dispatcher.Dispatcher import mdata_arginfo
Expand Down Expand Up @@ -43,35 +41,3 @@ def general_mdata_arginfo(name: str, tasks: tuple[str]) -> Argument:
)
)
return Argument(name, dict, sub_fields=sub_fields, doc=doc_run_mdata)


def check_nd_list(dimesion: int = 2) -> Callable:
"""Return a method to check if the input is a nd list.

Parameters
----------
dimesion : int, default=2
dimension of the array

Returns
-------
callable
check function
"""

def check(value, dimension=dimesion):
if value is None:
# do not check null
return True
if dimension:
if not isinstance(value, list):
return False
if dimension > 1:
if not all(check(v, dimension=dimesion - 1) for v in value):
return False
return True

return check


errmsg_nd_list = "Must be a %d-dimension list."
30 changes: 15 additions & 15 deletions dpgen/data/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def init_bulk_abacus_args() -> list[Argument]:
return [
Argument("relax_kpt", str, optional=True, doc=doc_relax_kpt),
Argument("md_kpt", str, optional=True, doc=doc_md_kpt),
Argument("atom_masses", list, optional=True, doc=doc_atom_masses),
Argument("atom_masses", list[float], optional=True, doc=doc_atom_masses),
]


Expand Down Expand Up @@ -105,25 +105,25 @@ def init_bulk_jdata_arginfo() -> Argument:
"init_bulk_jdata",
dict,
[
Argument("stages", list, optional=False, doc=doc_stages),
Argument("elements", list, optional=False, doc=doc_elements),
Argument("potcars", list, optional=True, doc=doc_potcars),
Argument("stages", list[int], optional=False, doc=doc_stages),
Argument("elements", list[str], optional=False, doc=doc_elements),
Argument("potcars", list[str], optional=True, doc=doc_potcars),
Argument("cell_type", str, optional=True, doc=doc_cell_type),
Argument("super_cell", list, optional=False, doc=doc_super_cell),
Argument("super_cell", list[int], optional=False, doc=doc_super_cell),
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
Argument(
"from_poscar", bool, optional=True, default=False, doc=doc_from_poscar
),
Argument("from_poscar_path", str, optional=True, doc=doc_from_poscar_path),
Argument("relax_incar", str, optional=True, doc=doc_relax_incar),
Argument("md_incar", str, optional=True, doc=doc_md_incar),
Argument("scale", list, optional=False, doc=doc_scale),
Argument("scale", list[float], optional=False, doc=doc_scale),
Argument("skip_relax", bool, optional=False, doc=doc_skip_relax),
Argument("pert_numb", int, optional=False, doc=doc_pert_numb),
Argument("pert_box", float, optional=False, doc=doc_pert_box),
Argument("pert_atom", float, optional=False, doc=doc_pert_atom),
Argument("md_nstep", int, optional=False, doc=doc_md_nstep),
Argument("coll_ndata", int, optional=False, doc=doc_coll_ndata),
Argument("type_map", list, optional=True, doc=doc_type_map),
Argument("type_map", list[str], optional=True, doc=doc_type_map),
],
sub_variants=init_bulk_variant_type_args(),
doc=doc_init_bulk,
Expand Down Expand Up @@ -171,11 +171,11 @@ def init_surf_jdata_arginfo() -> Argument:
"init_surf_jdata",
dict,
[
Argument("stages", list, optional=False, doc=doc_stages),
Argument("elements", list, optional=False, doc=doc_elements),
Argument("potcars", list, optional=True, doc=doc_potcars),
Argument("stages", list[int], optional=False, doc=doc_stages),
Argument("elements", list[str], optional=False, doc=doc_elements),
Argument("potcars", list[str], optional=True, doc=doc_potcars),
Argument("cell_type", str, optional=True, doc=doc_cell_type),
Argument("super_cell", list, optional=False, doc=doc_super_cell),
Argument("super_cell", list[int], optional=False, doc=doc_super_cell),
Argument(
"from_poscar", bool, optional=True, default=False, doc=doc_from_poscar
),
Expand All @@ -185,13 +185,13 @@ def init_surf_jdata_arginfo() -> Argument:
Argument("z_min", int, optional=True, doc=doc_z_min),
Argument("vacuum_max", float, optional=False, doc=doc_vacuum_max),
Argument("vacuum_min", float, optional=True, doc=doc_vacuum_min),
Argument("vacuum_resol", list, optional=False, doc=doc_vacuum_resol),
Argument("vacuum_resol", list[float], optional=False, doc=doc_vacuum_resol),
Argument("vacuum_numb", int, optional=True, doc=doc_vacuum_numb),
Argument("mid_point", float, optional=True, doc=doc_mid_point),
Argument("head_ratio", float, optional=True, doc=doc_head_ratio),
Argument("millers", list, optional=False, doc=doc_millers),
Argument("millers", list[list[int]], optional=False, doc=doc_millers),
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
Argument("relax_incar", str, optional=True, doc=doc_relax_incar),
Argument("scale", list, optional=False, doc=doc_scale),
Argument("scale", list[float], optional=False, doc=doc_scale),
Argument("skip_relax", bool, optional=False, doc=doc_skip_relax),
Argument("pert_numb", int, optional=False, doc=doc_pert_numb),
Argument("pert_box", float, optional=False, doc=doc_pert_box),
Expand Down Expand Up @@ -233,7 +233,7 @@ def init_reaction_jdata_arginfo() -> Argument:
"init_reaction_jdata",
dict,
[
Argument("type_map", list, doc=doc_type_map),
Argument("type_map", list[str], doc=doc_type_map),
Argument(
"reaxff",
dict,
Expand Down
93 changes: 52 additions & 41 deletions dpgen/generator/arginfo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import textwrap
from typing import Union

from dargs import Argument, Variant

from dpgen.arginfo import check_nd_list, errmsg_nd_list, general_mdata_arginfo
from dpgen.arginfo import general_mdata_arginfo


def run_mdata_arginfo() -> Argument:
Expand All @@ -26,9 +27,13 @@ def basic_args() -> list[Argument]:
- 2: electron temperature as atom parameter."

return [
Argument("type_map", list, optional=False, doc=doc_type_map),
Argument("type_map", list[str], optional=False, doc=doc_type_map),
Argument(
"mass_map", [list, str], optional=True, default="auto", doc=doc_mass_map
"mass_map",
[list[float], str],
optional=True,
default="auto",
doc=doc_mass_map,
),
Argument("use_ele_temp", int, optional=True, default=0, doc=doc_use_ele_temp),
]
Expand All @@ -45,23 +50,29 @@ def data_args() -> list[Argument]:

return [
Argument("init_data_prefix", str, optional=True, doc=doc_init_data_prefix),
Argument("init_data_sys", list, optional=False, doc=doc_init_data_sys),
Argument("init_data_sys", list[str], optional=False, doc=doc_init_data_sys),
Argument(
"sys_format", str, optional=True, default="vasp/poscar", doc=doc_sys_format
),
Argument(
"init_batch_size", [list, str], optional=True, doc=doc_init_batch_size
"init_batch_size",
[list[Union[int, str]], str],
optional=True,
doc=doc_init_batch_size,
),
Argument("sys_configs_prefix", str, optional=True, doc=doc_sys_configs_prefix),
Argument(
"sys_configs",
list,
list[list[str]],
optional=False,
doc=doc_sys_configs,
extra_check=check_nd_list(2),
extra_check_errmsg=errmsg_nd_list % 2,
),
Argument("sys_batch_size", list, optional=True, doc=doc_sys_batch_size),
Argument(
"sys_batch_size",
list[Union[int, str]],
optional=True,
doc=doc_sys_batch_size,
),
]


Expand Down Expand Up @@ -115,7 +126,7 @@ def training_args() -> list[Argument]:
Argument("numb_models", int, optional=False, doc=doc_numb_models),
Argument(
"training_iter0_model_path",
list,
list[str],
optional=True,
doc=doc_training_iter0_model_path,
),
Expand Down Expand Up @@ -182,21 +193,21 @@ def training_args() -> list[Argument]:
),
Argument(
"model_devi_activation_func",
[None, list],
[None, list[list[str]]],
optional=True,
doc=doc_model_devi_activation_func,
),
Argument("srtab_file_path", str, optional=True, doc=doc_srtab_file_path),
Argument("one_h5", bool, optional=True, default=False, doc=doc_one_h5),
Argument(
"training_init_frozen_model",
list,
list[str],
optional=True,
doc=doc_training_init_frozen_model,
),
Argument(
"training_finetune_model",
list,
list[str],
optional=True,
doc=doc_training_finetune_model,
),
Expand All @@ -218,7 +229,7 @@ def model_devi_jobs_template_args() -> Argument:
Argument("plm", str, optional=True, doc=doc_template_plm),
]
return Argument(
"template", list, args, [], optional=True, repeat=False, doc=doc_template
"template", dict, args, [], optional=True, repeat=False, doc=doc_template
)


Expand All @@ -235,7 +246,7 @@ def model_devi_jobs_rev_mat_args() -> Argument:
Argument("plm", dict, optional=True, doc=doc_rev_mat_plm),
]
return Argument(
"rev_mat", list, args, [], optional=True, repeat=False, doc=doc_rev_mat
"rev_mat", dict, args, [], optional=True, repeat=False, doc=doc_rev_mat
)


Expand Down Expand Up @@ -264,9 +275,9 @@ def model_devi_jobs_args() -> list[Argument]:
model_devi_jobs_template_args(),
model_devi_jobs_rev_mat_args(),
Argument("sys_rev_mat", dict, optional=True, doc=doc_sys_rev_mat),
Argument("sys_idx", list, optional=False, doc=doc_sys_idx),
Argument("temps", list, optional=True, doc=doc_temps),
Argument("press", list, optional=True, doc=doc_press),
Argument("sys_idx", list[int], optional=False, doc=doc_sys_idx),
Argument("temps", list[float], optional=True, doc=doc_temps),
Argument("press", list[float], optional=True, doc=doc_press),
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
Argument("nsteps", int, optional=True, doc=doc_nsteps),
Argument("ensemble", str, optional=True, doc=doc_ensemble),
Expand Down Expand Up @@ -342,26 +353,26 @@ def model_devi_lmp_args() -> list[Argument]:
Argument("model_devi_skip", int, optional=False, doc=doc_model_devi_skip),
Argument(
"model_devi_f_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_lo,
),
Argument(
"model_devi_f_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Argument(
"model_devi_v_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=True,
default=1e10,
doc=doc_model_devi_v_trust_lo,
),
Argument(
"model_devi_v_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=True,
default=1e10,
doc=doc_model_devi_v_trust_hi,
Expand Down Expand Up @@ -510,7 +521,7 @@ def model_devi_amber_args() -> list[Argument]:
repeat=True,
doc=doc_model_devi_jobs,
sub_fields=[
Argument("sys_idx", list, optional=False, doc=doc_sys_idx),
Argument("sys_idx", list[int], optional=False, doc=doc_sys_idx),
Argument("trj_freq", int, optional=False, doc=doc_trj_freq),
Argument(
"restart_from_iter", int, optional=True, doc=doc_restart_from_iter
Expand All @@ -520,32 +531,30 @@ def model_devi_amber_args() -> list[Argument]:
Argument("low_level", str, optional=False, doc=doc_low_level),
Argument("cutoff", float, optional=False, doc=doc_cutoff),
Argument("parm7_prefix", str, optional=True, doc=doc_parm7_prefix),
Argument("parm7", list, optional=False, doc=doc_parm7),
Argument("parm7", list[str], optional=False, doc=doc_parm7),
Argument("mdin_prefix", str, optional=True, doc=doc_mdin_prefix),
Argument("mdin", list, optional=False, doc=doc_mdin),
Argument("qm_region", list, optional=False, doc=doc_qm_region),
Argument("qm_charge", list, optional=False, doc=doc_qm_charge),
Argument("nsteps", list, optional=False, doc=doc_nsteps),
Argument("mdin", list[str], optional=False, doc=doc_mdin),
Argument("qm_region", list[str], optional=False, doc=doc_qm_region),
Argument("qm_charge", list[int], optional=False, doc=doc_qm_charge),
Argument("nsteps", list[int], optional=False, doc=doc_nsteps),
Argument(
"r",
list,
list[list[Union[float, list[float]]]],
optional=False,
doc=doc_r,
extra_check=check_nd_list(2),
extra_check_errmsg=errmsg_nd_list % 2,
),
Argument("disang_prefix", str, optional=True, doc=doc_disang_prefix),
Argument("disang", list, optional=False, doc=doc_disang),
Argument("disang", list[str], optional=False, doc=doc_disang),
# post model devi args
Argument(
"model_devi_f_trust_lo",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_lo,
),
Argument(
"model_devi_f_trust_hi",
[float, list, dict],
[float, list[float], dict],
optional=False,
doc=doc_model_devi_f_trust_hi,
),
Expand Down Expand Up @@ -587,9 +596,11 @@ def fp_style_vasp_args() -> list[Argument]:

return [
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
Argument("fp_incar", str, optional=False, doc=doc_fp_incar),
Argument("fp_aniso_kspacing", list, optional=True, doc=doc_fp_aniso_kspacing),
Argument(
"fp_aniso_kspacing", list[float], optional=True, doc=doc_fp_aniso_kspacing
),
Argument("cvasp", bool, optional=True, doc=doc_cvasp),
Argument("fp_skip_bad_box", str, optional=True, doc=doc_fp_skip_bad_box),
]
Expand All @@ -610,13 +621,13 @@ def fp_style_abacus_args() -> list[Argument]:

return [
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_orb_files", list, optional=True, doc=doc_fp_orb_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
Argument("fp_orb_files", list[str], optional=True, doc=doc_fp_orb_files),
Argument("fp_incar", str, optional=True, doc=doc_fp_incar),
Argument("fp_kpt_file", str, optional=True, doc=doc_fp_kpt_file),
Argument("fp_dpks_descriptor", str, optional=True, doc=doc_fp_dpks_descriptor),
Argument("user_fp_params", dict, optional=True, doc=doc_user_fp_params),
Argument("k_points", list, optional=True, doc=doc_k_points),
Argument("k_points", list[int], optional=True, doc=doc_k_points),
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
]


Expand Down Expand Up @@ -646,7 +657,7 @@ def fp_style_gaussian_args() -> list[Argument]:
)

args = [
Argument("keywords", [str, list], optional=False, doc=doc_keywords),
Argument("keywords", [str, list[str]], optional=False, doc=doc_keywords),
Argument(
"multiplicity",
[int, str],
Expand Down Expand Up @@ -736,7 +747,7 @@ def fp_style_siesta_args() -> list[Argument]:
Argument("cluster_cutoff", float, optional=True, doc=doc_cluster_cutoff),
Argument("fp_params", dict, args, [], optional=False, doc=doc_fp_params_siesta),
Argument("fp_pp_path", str, optional=False, doc=doc_fp_pp_path),
Argument("fp_pp_files", list, optional=False, doc=doc_fp_pp_files),
Argument("fp_pp_files", list[str], optional=False, doc=doc_fp_pp_files),
]


Expand Down
2 changes: 1 addition & 1 deletion dpgen/simplify/arginfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def general_simplify_arginfo() -> Argument:

return [
Argument("labeled", bool, optional=True, default=False, doc=doc_labeled),
Argument("pick_data", [str, list], doc=doc_pick_data),
Argument("pick_data", [str, list[str]], doc=doc_pick_data),
Argument("init_pick_number", int, doc=doc_init_pick_number),
Argument("iter_pick_number", int, doc=doc_iter_pick_number),
Argument(
Expand Down
Loading