Skip to content

Commit

Permalink
include dataloader in eval epoch end during FIT (#957)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #957

Reviewed By: diego-urgell

Differential Revision: D67813439

fbshipit-source-id: f1fdbbbb0b9784d0e4bdf6dffd9f52eec1915bab
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Jan 6, 2025
1 parent de119c5 commit 1fe0a5d
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
92 changes: 92 additions & 0 deletions tests/framework/callbacks/test_dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -748,6 +748,98 @@ def test_save_restore_fit_eval_every_n_epochs(self) -> None:
expected_keys_with_dl,
)

def test_save_restore_fit_save_every_n_eval_epochs(self) -> None:
input_dim = 2
dataset_len = 10
batch_size = 2

my_unit = DummyAutoUnit(module=nn.Linear(input_dim, 2))
my_unit.output_mean = DummyMeanMetric()
my_unit.loss = 0.1

train_dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

eval_dataloader = generate_dummy_stateful_dataloader(
dataset_len, input_dim, batch_size
)

with tempfile.TemporaryDirectory() as temp_dir:
dcp_cb = DistributedCheckpointSaver(
temp_dir,
knob_options=KnobOptions(1),
save_every_n_eval_epochs=1,
best_checkpoint_config=BestCheckpointConfig(monitored_metric="loss"),
)

fit(
my_unit,
max_epochs=1,
evaluate_every_n_steps=1,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
callbacks=[dcp_cb],
)

generated_ckpts = os.listdir(temp_dir)
# fbvscode.set_trace()
# Since we are using FIT, the metric value should be included
expected_ckpts_to_dl_mapping = {
"epoch_0_train_step_1_eval_step_5_loss=0.1",
"epoch_0_train_step_2_eval_step_10_loss=0.1",
"epoch_0_train_step_3_eval_step_15_loss=0.1",
"epoch_0_train_step_4_eval_step_20_loss=0.1",
"epoch_0_train_step_5_eval_step_25_loss=0.1",
"epoch_1_train_step_5_eval_step_30_loss=0.1",
}
self.assertCountEqual(generated_ckpts, [*expected_ckpts_to_dl_mapping])

expected_keys = [
"module", # Both train and eval checkpoints save full app_state in fit
"optimizer",
"lr_scheduler",
"train_progress",
"eval_progress",
"predict_progress", # included because of AutoUnit
"output_mean",
"eval_dataloader",
"train_dataloader",
]

for ckpt_path in expected_ckpts_to_dl_mapping:
full_ckpt_path = os.path.join(temp_dir, ckpt_path)
expected_keys_with_dl = list(expected_keys)
storage_reader = FsspecReader(full_ckpt_path)
metadata = storage_reader.read_metadata()
if ckpt_path == "epoch_1_train_step_5_eval_step_30_loss=0.1":
# remove dataloader keys as final checkpoint wont have them
expected_keys_with_dl = expected_keys_with_dl[:-1]
appstate_keys = {
key.split(".")[1] for key in metadata.state_dict_metadata.keys()
}
self.assertCountEqual(
# Get base keys after the app_state wrapper
appstate_keys,
expected_keys_with_dl,
msg=f"key: {ckpt_path}, {expected_keys_with_dl=}, {appstate_keys=},",
)

# Now make sure that the same exact keys are used when restoring
with patch(
"torchtnt.framework.callbacks.dcp_saver.dcp.load"
) as mock_load:
DistributedCheckpointSaver.restore(
full_ckpt_path,
my_unit,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)
self.assertCountEqual(
[*mock_load.call_args[0][0]["app_state"].state_dict().keys()],
expected_keys_with_dl,
)

def test_save_fit_eval_every_n_steps(self) -> None:
input_dim = 2

Expand Down
8 changes: 7 additions & 1 deletion torchtnt/framework/callbacks/dcp_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
DefaultSavePlanner,
)
from torch.distributed.checkpoint.planner import LoadPlanner, SavePlanner
from torchtnt.framework.state import EntryPoint

try:
from torch.distributed.checkpoint.state_dict import _init_optim_state
Expand Down Expand Up @@ -171,7 +172,12 @@ def _checkpoint_impl(
]:
raise RuntimeError(f"Unexpected hook encountered '{hook}'")

intra_epoch = "step_end" in hook
# intra epoch when checkpointing during "_step_end" hook OR
# when checkpointing during "on_eval_epoch_end" hook and the entry point is fit
# since it is still intra epoch with respect to the train epoch
intra_epoch = "step_end" in hook or (
"on_eval_epoch_end" == hook and state.entry_point == EntryPoint.FIT
)
curr_snapshot_wait = hook == "on_train_end"

if planner is None:
Expand Down

0 comments on commit 1fe0a5d

Please sign in to comment.