From e0e9155763119ef1a81e4087e464303eb3088fc0 Mon Sep 17 00:00:00 2001
From: Sebastian Walter
Date: Sun, 7 Jul 2024 18:20:24 +0200
Subject: [PATCH] fix validation batch limit

---
 python/text_utils/api/trainer.py | 17 +++++++++++++----
 1 file changed, 13 insertions(+), 4 deletions(-)

diff --git a/python/text_utils/api/trainer.py b/python/text_utils/api/trainer.py
index 00aca3d..bf38100 100644
--- a/python/text_utils/api/trainer.py
+++ b/python/text_utils/api/trainer.py
@@ -704,6 +704,11 @@ def prepare_data_loader(
             1,
             train_cfg["batch_limit"] // info.world_size
         )
+    limit_cfg = {
+        "batch_limit": train_cfg["batch_limit"]
+    }
+    if "batch_limit_type" in train_cfg:
+        limit_cfg["batch_limit_type"] = train_cfg["batch_limit_type"]
 
     # pop some configs not used by the dataloader
     max_length = train_cfg.pop("max_length")
@@ -720,6 +725,7 @@ def prepare_data_loader(
 
     pipeline_cfg = train_cfg.pop("pipeline")
 
+    # for validation always turn off shuffling and turn on sorting
     if isinstance(val_cfg, int):
         # if validation is a split of the training set
         train_limit = train_cfg.get("limit", None)
@@ -736,16 +742,17 @@ def prepare_data_loader(
             distributed=(info.rank, info.world_size),
             **train_cfg,
         )
-        # for validation always turn off shuffling, turn on sorting, and
         # specify the val limit
         val_loader = prepare_data_loader(
             pipeline_cfg,
             *training,
+            seed=seed,
             limit=val_cfg,
             max_length=max_length,
             distributed=(info.rank, info.world_size),
+            sort=True,
             shuffle=False,
-            sort=True
+            **limit_cfg
         )
 
     elif isinstance(val_cfg, list):
         (
             *training,
         ) = prepare_data_loader(
             pipeline_cfg,
             *train_files,
@@ -756,7 +763,7 @@ def prepare_data_loader(
             seed=seed,
             max_length=max_length,
             distributed=(info.rank, info.world_size),
-            **train_cfg,
+            **train_cfg
         )
         (
             *validation,
@@ -769,10 +776,12 @@ def prepare_data_loader(
         val_loader = prepare_data_loader(
             pipeline_cfg,
             *validation,
+            seed=seed,
             max_length=max_length,
             distributed=(info.rank, info.world_size),
+            sort=True,
             shuffle=False,
-            sort=True
+            **limit_cfg
         )
 
     else:
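
The change makes validation batching consistent with training: before this patch, both validation loaders were built without any batch limit (and without an explicit seed), so they fell back to defaults even when training used a per-rank batch_limit and a batch_limit_type. The patch collects the two limit keys into limit_cfg before the rest of train_cfg is popped apart, then forwards them, together with seed, sort=True, and shuffle=False, to both validation call sites. Below is a minimal, self-contained sketch of that forwarding pattern, with a stub standing in for the real prepare_data_loader; all values and file names are hypothetical and only illustrate the flow:

    from typing import Any

    def prepare_data_loader(*files: str, **kwargs: Any) -> dict[str, Any]:
        # stand-in for the loader factory called in trainer.py; the keyword
        # names below mirror the call sites in the hunks above
        return {"files": files, **kwargs}

    world_size = 4  # hypothetical number of distributed ranks

    # hypothetical training config; in trainer.py this comes from the run config
    train_cfg = {
        "batch_limit": 8192,
        "batch_limit_type": "padded_item_size",
    }

    # split the global batch limit across ranks, as in the context of the first hunk
    train_cfg["batch_limit"] = max(1, train_cfg["batch_limit"] // world_size)

    # collect only the limit-related keys, exactly as the patch does, so the
    # validation loader batches the same way as the training loader
    limit_cfg = {
        "batch_limit": train_cfg["batch_limit"]
    }
    if "batch_limit_type" in train_cfg:
        limit_cfg["batch_limit_type"] = train_cfg["batch_limit_type"]

    # validation loader: seeded and deterministic (sorted, unshuffled), with
    # the same per-rank batch limit as training
    val_loader = prepare_data_loader(
        "val.jsonl",  # hypothetical validation file
        seed=22,
        sort=True,
        shuffle=False,
        **limit_cfg,
    )
    print(val_loader)

Building limit_cfg as a separate dict, rather than spreading **train_cfg into the validation calls, keeps training-only options such as shuffle or limit from leaking into validation while still sharing the batching behavior.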