Skip to content

Commit

Permalink
Handle scientific notation in metric values (#953)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #953

Previously, this code would crash if the metric value had scientific notation (like `6.486097566010406e+18`), which can cause job crashes (like https://www.internalfb.com/mlhub/pipelines/runs/mast/f676386628-alandu-hpo_mast_3342b13186?job_attempt=7&version=0&tab=debug&env=PRODUCTION).

This updates the regex to handle scientific notation and adds a test case.

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D67541312

fbshipit-source-id: d4bc637a9a0f93323cf55877025844f9b1651428
  • Loading branch information
alanhdu authored and facebook-github-bot committed Dec 20, 2024
1 parent 2e762a1 commit ba2fb54
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
9 changes: 9 additions & 0 deletions tests/utils/test_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,15 @@ def test_from_str(self) -> None:
metric_data=MetricData("mean_loss_squared", 0.0),
),
),
(
"foo/bar/epoch_1_train_step_2_eval_step_3_eval_loss=6.486097566010406e+18",
CheckpointPath(
"foo/bar/",
epoch=1,
step={Phase.TRAIN: 2, Phase.EVALUATE: 3},
metric_data=MetricData("eval_loss", 6.486097566010406e18),
),
),
]
for path, expected_ckpt in valid_paths:
parsed_ckpt = CheckpointPath.from_str(path)
Expand Down
6 changes: 4 additions & 2 deletions torchtnt/utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,14 @@ class CheckpointPath:
- phase-aware and metric-aware- <dirpath>/epoch_<epoch>_train_step_<train_step>_eval_step_<eval_step>_<metric_name>=<metric_value>
"""

_FLOAT_REGEX: str = r"-?\d+\.?\d*(?:e[\+\-]\d+)?"

PHASE_NAIVE_REGEX: Pattern = re.compile(
r"^(.+)epoch_(\d+)_step_(\d+)(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
rf"^(.+)epoch_(\d+)_step_(\d+)(?:_(\w+)=({_FLOAT_REGEX}))?\/?$"
)

PHASE_AWARE_REGEX: Pattern = re.compile(
r"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=(-?\d+\.?\d*))?\/?$"
rf"^(.+)epoch_(\d+)(?:_train_step_(\d+))?(?:_eval_step_(\d+))?(?:_predict_step_(\d+))?(?:_(\w+)=({_FLOAT_REGEX}))?\/?$"
)

def __init__(
Expand Down

0 comments on commit ba2fb54

Please sign in to comment.