diff --git a/dpgen2/op/run_dp_train.py b/dpgen2/op/run_dp_train.py index f3e25e12..42d702f6 100644 --- a/dpgen2/op/run_dp_train.py +++ b/dpgen2/op/run_dp_train.py @@ -43,6 +43,123 @@ ) +def _make_train_command( + dp_command, + train_script_name, + impl, + do_init_model, + init_model, + finetune_mode, + finetune_args, + init_model_with_finetune, +): + # find checkpoint + if impl == "tensorflow" and os.path.isfile("checkpoint"): + checkpoint = "model.ckpt" + elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0: + checkpoint = "model.ckpt-%s.pt" % max( + [int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")] + ) + else: + checkpoint = None + # case of restart + if checkpoint is not None: + command = dp_command + ["train", "--restart", checkpoint, train_script_name] + return command + # case of init model and finetune + assert checkpoint is None + do_init_model_or_train_init = do_init_model or finetune_mode == "train-init" + case_init_model = do_init_model_or_train_init and (not init_model_with_finetune) + case_finetune = finetune_mode == "finetune" or ( + do_init_model_or_train_init and init_model_with_finetune + ) + if case_init_model: + init_flag = "--init-frz-model" if impl == "tensorflow" else "--init-model" + command = dp_command + [ + "train", + init_flag, + str(init_model), + train_script_name, + ] + elif case_finetune: + command = ( + dp_command + + [ + "train", + train_script_name, + "--finetune", + str(init_model), + ] + + finetune_args.split() + ) + else: + command = dp_command + ["train", train_script_name] + return command + + +def _make_train_command_old( + dp_command, + train_script_name, + impl, + do_init_model, + init_model, + finetune_mode, + finetune_args, + init_model_with_finetune, +): + if impl == "tensorflow" and os.path.isfile("checkpoint"): + command = dp_command + [ + "train", + "--restart", + "model.ckpt", + train_script_name, + ] + elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0: + checkpoint = "model.ckpt-%s.pt" % max( + [int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")] + ) + command = dp_command + [ + "train", + "--restart", + checkpoint, + train_script_name, + ] + elif ( + do_init_model or finetune_mode == "train-init" + ) and not init_model_with_finetune: + if impl == "pytorch": + command = dp_command + [ + "train", + "--init-model", + str(init_model), + train_script_name, + ] + else: + command = dp_command + [ + "train", + "--init-frz-model", + str(init_model), + train_script_name, + ] + elif finetune_mode == "finetune" or ( + (do_init_model or finetune_mode == "train-init") and init_model_with_finetune + ): + command = ( + dp_command + + [ + "train", + train_script_name, + "--finetune", + str(init_model), + ] + + finetune_args.split() + ) + else: + command = dp_command + ["train", train_script_name] + + return command + + class RunDPTrain(OP): r"""Execute a DP training task. Train and freeze a DP model. @@ -141,6 +258,7 @@ def execute( iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:]) iter_data_exp = iter_data_old_exp + iter_data_new_exp work_dir = Path(task_name) + init_model_with_finetune = config["init_model_with_finetune"] # update the input script input_script = Path(task_path) / train_script_name @@ -204,56 +322,17 @@ def clean_before_quit(): json.dump(train_dict, fp, indent=4) # train model - if impl == "tensorflow" and os.path.isfile("checkpoint"): - command = dp_command + [ - "train", - "--restart", - "model.ckpt", - train_script_name, - ] - elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0: - checkpoint = "model.ckpt-%s.pt" % max( - [int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")] - ) - command = dp_command + [ - "train", - "--restart", - checkpoint, - train_script_name, - ] - elif (do_init_model or finetune_mode == "train-init") and not config[ - "init_model_with_finetune" - ]: - if impl == "pytorch": - command = dp_command + [ - "train", - "--init-model", - str(init_model), - train_script_name, - ] - else: - command = dp_command + [ - "train", - "--init-frz-model", - str(init_model), - train_script_name, - ] - elif finetune_mode == "finetune" or ( - (do_init_model or finetune_mode == "train-init") - and config["init_model_with_finetune"] - ): - command = ( - dp_command - + [ - "train", - train_script_name, - "--finetune", - str(init_model), - ] - + finetune_args.split() - ) - else: - command = dp_command + ["train", train_script_name] + command = _make_train_command_old( + dp_command, + train_script_name, + impl, + do_init_model, + init_model, + finetune_mode, + finetune_args, + init_model_with_finetune, + ) + ret, out, err = run_command(command) if ret != 0: clean_before_quit() diff --git a/tests/op/test_run_dp_train.py b/tests/op/test_run_dp_train.py index d4366df3..79d020a0 100644 --- a/tests/op/test_run_dp_train.py +++ b/tests/op/test_run_dp_train.py @@ -1,3 +1,4 @@ +import itertools import json import os import shutil @@ -35,6 +36,8 @@ from dpgen2.op.run_dp_train import ( RunDPTrain, _get_data_size_of_all_mult_sys, + _make_train_command, + _make_train_command_old, ) # isort: on @@ -939,3 +942,52 @@ def test_exec_v2_empty_dir(self, mocked_run): with open(out["script"]) as fp: jdata = json.load(fp) self.assertDictEqual(jdata, self.expected_odict_v2) + + +class TestMakeTrainCommand(unittest.TestCase): + def setUp(self): + pass + + def tearDown(self): + pass + + def test_consistency_impl(self): + dp_command = ["foo"] + train_script_name = "bar.json" + finetune_args = "piz" + init_model = "fox.pt" + + # restart, impl, do_init, finetune_model, init_model_w_finetune + for res, ii, dim, fm, imwf in itertools.product( + [True, False], + ["tensorflow", "pytorch"], + [True, False], + ["finetune", "train-init"], + [True, False], + ): + if res: + if ii == "tensorflow": + Path("checkpoint").write_text("") + else: + Path("model.ckpt-000.pt").write_text("") + Path("model.ckpt-001.pt").write_text("") + + args = [ + dp_command, + train_script_name, + ii, + dim, + init_model, + fm, + finetune_args, + imwf, + ] + cmd_new = _make_train_command(*args) + cmd_old = _make_train_command_old(*args) + + self.assertEqual(cmd_old, cmd_new) + + if res: + for ii in ["model.ckpt-000.pt", "model.ckpt-001.pt", "checkpoint"]: + if os.path.exists(ii): + os.remove(ii)