Skip to content

Commit

Permalink
refactor make train command
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Apr 9, 2024
1 parent 30b5169 commit 1285105
Show file tree
Hide file tree
Showing 2 changed files with 181 additions and 50 deletions.
179 changes: 129 additions & 50 deletions dpgen2/op/run_dp_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,123 @@
)


def _make_train_command(
dp_command,
train_script_name,
impl,
do_init_model,
init_model,
finetune_mode,
finetune_args,
init_model_with_finetune,
):
# find checkpoint
if impl == "tensorflow" and os.path.isfile("checkpoint"):
checkpoint = "model.ckpt"
elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0:
checkpoint = "model.ckpt-%s.pt" % max(
[int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")]
)
else:
checkpoint = None
# case of restart
if checkpoint is not None:
command = dp_command + ["train", "--restart", checkpoint, train_script_name]
return command
# case of init model and finetune
assert checkpoint is None
do_init_model_or_train_init = do_init_model or finetune_mode == "train-init"
case_init_model = do_init_model_or_train_init and (not init_model_with_finetune)
case_finetune = finetune_mode == "finetune" or (
do_init_model_or_train_init and init_model_with_finetune
)
if case_init_model:
init_flag = "--init-frz-model" if impl == "tensorflow" else "--init-model"
command = dp_command + [
"train",
init_flag,
str(init_model),
train_script_name,
]
elif case_finetune:
command = (
dp_command
+ [
"train",
train_script_name,
"--finetune",
str(init_model),
]
+ finetune_args.split()
)
else:
command = dp_command + ["train", train_script_name]
return command


def _make_train_command_old(
    dp_command,
    train_script_name,
    impl,
    do_init_model,
    init_model,
    finetune_mode,
    finetune_args,
    init_model_with_finetune,
):
    """Legacy construction of the ``train`` command line.

    Kept as a reference implementation with the same signature as
    ``_make_train_command``. Branches are evaluated in order: restart from
    a checkpoint found in the current working directory, initialize from
    ``init_model``, finetune from ``init_model``, and finally plain training.

    Returns a new list built from ``dp_command``; ``dp_command`` itself is
    not modified.
    """
    # Restart (tensorflow): a "checkpoint" file exists in the working directory.
    if impl == "tensorflow" and os.path.isfile("checkpoint"):
        command = dp_command + [
            "train",
            "--restart",
            "model.ckpt",
            train_script_name,
        ]
    # Restart (pytorch): at least one "model.ckpt-<digit>*.pt" file exists.
    elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0:
        # f[11:-3] drops the 11-character "model.ckpt-" prefix and the
        # 3-character ".pt" suffix; the largest integer found is used to
        # rebuild the checkpoint file name.
        checkpoint = "model.ckpt-%s.pt" % max(
            [int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")]
        )
        command = dp_command + [
            "train",
            "--restart",
            checkpoint,
            train_script_name,
        ]
    # Init-model case: start from init_model without using --finetune.
    elif (
        do_init_model or finetune_mode == "train-init"
    ) and not init_model_with_finetune:
        if impl == "pytorch":
            command = dp_command + [
                "train",
                "--init-model",
                str(init_model),
                train_script_name,
            ]
        else:
            # Every impl other than "pytorch" uses the frozen-model flag.
            command = dp_command + [
                "train",
                "--init-frz-model",
                str(init_model),
                train_script_name,
            ]
    # Finetune case: explicit finetune mode, or init-model routed through
    # --finetune because init_model_with_finetune is set.
    elif finetune_mode == "finetune" or (
        (do_init_model or finetune_mode == "train-init") and init_model_with_finetune
    ):
        command = (
            dp_command
            + [
                "train",
                train_script_name,
                "--finetune",
                str(init_model),
            ]
            # finetune_args is a single string; split on whitespace into tokens.
            + finetune_args.split()
        )
    # Default: plain training from the input script.
    else:
        command = dp_command + ["train", train_script_name]

    return command


class RunDPTrain(OP):
r"""Execute a DP training task. Train and freeze a DP model.
Expand Down Expand Up @@ -141,6 +258,7 @@ def execute(
iter_data_new_exp = _expand_all_multi_sys_to_sys(iter_data[-1:])
iter_data_exp = iter_data_old_exp + iter_data_new_exp
work_dir = Path(task_name)
init_model_with_finetune = config["init_model_with_finetune"]

# update the input script
input_script = Path(task_path) / train_script_name
Expand Down Expand Up @@ -204,56 +322,17 @@ def clean_before_quit():
json.dump(train_dict, fp, indent=4)

# train model
if impl == "tensorflow" and os.path.isfile("checkpoint"):
command = dp_command + [
"train",
"--restart",
"model.ckpt",
train_script_name,
]
elif impl == "pytorch" and len(glob.glob("model.ckpt-[0-9]*.pt")) > 0:
checkpoint = "model.ckpt-%s.pt" % max(
[int(f[11:-3]) for f in glob.glob("model.ckpt-[0-9]*.pt")]
)
command = dp_command + [
"train",
"--restart",
checkpoint,
train_script_name,
]
elif (do_init_model or finetune_mode == "train-init") and not config[
"init_model_with_finetune"
]:
if impl == "pytorch":
command = dp_command + [
"train",
"--init-model",
str(init_model),
train_script_name,
]
else:
command = dp_command + [
"train",
"--init-frz-model",
str(init_model),
train_script_name,
]
elif finetune_mode == "finetune" or (
(do_init_model or finetune_mode == "train-init")
and config["init_model_with_finetune"]
):
command = (
dp_command
+ [
"train",
train_script_name,
"--finetune",
str(init_model),
]
+ finetune_args.split()
)
else:
command = dp_command + ["train", train_script_name]
command = _make_train_command_old(
dp_command,
train_script_name,
impl,
do_init_model,
init_model,
finetune_mode,
finetune_args,
init_model_with_finetune,
)

ret, out, err = run_command(command)
if ret != 0:
clean_before_quit()
Expand Down
52 changes: 52 additions & 0 deletions tests/op/test_run_dp_train.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import itertools
import json
import os
import shutil
Expand Down Expand Up @@ -35,6 +36,8 @@
from dpgen2.op.run_dp_train import (
RunDPTrain,
_get_data_size_of_all_mult_sys,
_make_train_command,
_make_train_command_old,
)

# isort: on
Expand Down Expand Up @@ -939,3 +942,52 @@ def test_exec_v2_empty_dir(self, mocked_run):
with open(out["script"]) as fp:
jdata = json.load(fp)
self.assertDictEqual(jdata, self.expected_odict_v2)


class TestMakeTrainCommand(unittest.TestCase):
    """Check that ``_make_train_command`` matches ``_make_train_command_old``."""

    def setUp(self):
        # Both functions probe the current working directory for checkpoint
        # files, so run inside a scratch directory instead of polluting the
        # directory the tests were launched from.
        import tempfile

        self._start_dir = os.getcwd()
        self._scratch_dir = tempfile.mkdtemp()
        os.chdir(self._scratch_dir)

    def tearDown(self):
        # Always leave the scratch directory before deleting it, even when a
        # test failed half-way and left checkpoint files behind.
        os.chdir(self._start_dir)
        shutil.rmtree(self._scratch_dir, ignore_errors=True)

    def test_consistency_impl(self):
        """Old and new builders agree for every combination of inputs."""
        dp_command = ["foo"]
        train_script_name = "bar.json"
        finetune_args = "piz"
        init_model = "fox.pt"

        # Files whose presence makes each implementation take the restart path.
        checkpoint_files = {
            "tensorflow": ["checkpoint"],
            "pytorch": ["model.ckpt-000.pt", "model.ckpt-001.pt"],
        }

        for restart, impl, do_init, finetune_mode, with_finetune in itertools.product(
            [True, False],
            ["tensorflow", "pytorch"],
            [True, False],
            # "no" is neither "finetune" nor "train-init", so it is the only
            # way to reach the plain-training fallthrough branch.
            ["no", "finetune", "train-init"],
            [True, False],
        ):
            created = checkpoint_files[impl] if restart else []
            with self.subTest(
                restart=restart,
                impl=impl,
                do_init_model=do_init,
                finetune_mode=finetune_mode,
                init_model_with_finetune=with_finetune,
            ):
                for name in created:
                    Path(name).write_text("")
                try:
                    args = [
                        dp_command,
                        train_script_name,
                        impl,
                        do_init,
                        init_model,
                        finetune_mode,
                        finetune_args,
                        with_finetune,
                    ]
                    cmd_new = _make_train_command(*args)
                    cmd_old = _make_train_command_old(*args)

                    self.assertEqual(cmd_old, cmd_new)
                    # Guard against both builders silently missing the
                    # checkpoint and agreeing on the wrong branch.
                    self.assertEqual("--restart" in cmd_new, restart)
                finally:
                    # Remove the files even on failure so the next
                    # combination starts from a clean directory.
                    for name in created:
                        if os.path.exists(name):
                            os.remove(name)

0 comments on commit 1285105

Please sign in to comment.