Skip to content

Commit

Permalink
fix validation batch limit
Browse files Browse the repository at this point in the history
  • Loading branch information
bastiscode committed Jul 7, 2024
1 parent bff228d commit e0e9155
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions python/text_utils/api/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -704,6 +704,11 @@ def prepare_data_loader(
1,
train_cfg["batch_limit"] // info.world_size
)
limit_cfg = {
"batch_limit": train_cfg["batch_limit"]
}
if "batch_limit_type" in train_cfg:
limit_cfg["batch_limit_type"] = train_cfg["batch_limit_type"]

# pop some configs not used by the dataloader
max_length = train_cfg.pop("max_length")
Expand All @@ -720,6 +725,7 @@ def prepare_data_loader(

pipeline_cfg = train_cfg.pop("pipeline")

# for validation always turn off shuffling and turn on sorting
if isinstance(val_cfg, int):
# if validation is a split of the training set
train_limit = train_cfg.get("limit", None)
Expand All @@ -736,16 +742,17 @@ def prepare_data_loader(
distributed=(info.rank, info.world_size),
**train_cfg,
)
# for validation always turn off shuffling, turn on sorting, and
# specify the val limit
val_loader = prepare_data_loader(
pipeline_cfg,
*training,
seed=seed,
limit=val_cfg,
max_length=max_length,
distributed=(info.rank, info.world_size),
sort=True,
shuffle=False,
sort=True
**limit_cfg
)

elif isinstance(val_cfg, list):
Expand All @@ -756,7 +763,7 @@ def prepare_data_loader(
seed=seed,
max_length=max_length,
distributed=(info.rank, info.world_size),
**train_cfg,
**train_cfg
)
(
*validation,
Expand All @@ -769,10 +776,12 @@ def prepare_data_loader(
val_loader = prepare_data_loader(
pipeline_cfg,
*validation,
seed=seed,
max_length=max_length,
distributed=(info.rank, info.world_size),
sort=True,
shuffle=False,
sort=True
**limit_cfg
)

else:
Expand Down

0 comments on commit e0e9155

Please sign in to comment.