Skip to content

Commit

Permalink
Early stopper inconsistent devices fix (#949)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #949

`val` and `self._best_value` can have inconsitent devices in multi-GPU trainings which will fail at early stopper checks

Reviewed By: JKSenthil

Differential Revision: D66160768

fbshipit-source-id: 7f80900343bbdb80b118052452156a4fd5b67b73
  • Loading branch information
Pavel Levin authored and facebook-github-bot committed Jan 8, 2025
1 parent 1fe0a5d commit 6e6824c
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions torchtnt/utils/early_stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing_extensions import final, Literal

_log: logging.Logger = logging.getLogger(__name__)
_log.setLevel(logging.DEBUG)


@final
Expand Down Expand Up @@ -179,11 +180,13 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:
divergence_threshold = divergence_threshold.to(val.device)
improvement_threshold = self.min_delta
if self._threshold_mode == "rel":
base_val = self._best_value if torch.isfinite(self._best_value) else 0.0
base_val = (
self._best_value.to(val.device)
if torch.isfinite(self._best_value)
else 0.0
)
improvement_threshold = self.min_delta.to(val.device) * base_val

improvement_threshold = improvement_threshold.to(val.device)

# Check finite
if self.check_finite and not torch.isfinite(val):
_log.debug(
Expand Down Expand Up @@ -212,7 +215,7 @@ def check(self, val: Union[torch.Tensor, float, int]) -> bool:

# Check if improvement is happening
if self._mode_func(
val - improvement_threshold, self._best_value.to(val.device)
val - improvement_threshold.to(val.device), self._best_value.to(val.device)
):
# Still improving
should_stop = False
Expand Down Expand Up @@ -259,9 +262,12 @@ def _improvement_message(self, val: torch.Tensor) -> str:
"""Formats a log message that informs the user about an improvement in the monitored score."""
if torch.isfinite(self._best_value):
improvement = (
torch.abs(self._best_value - val)
torch.abs(self._best_value.to(val.device) - val)
if self.threshold_mode == "abs"
else torch.abs((self._best_value - val) / (1.0 * self._best_value))
else torch.abs(
(self._best_value.to(val.device) - val)
/ (1.0 * self._best_value.to(val.device))
)
)
msg = (
f"Metric improved by {self.threshold_mode} {improvement} >="
Expand Down

0 comments on commit 6e6824c

Please sign in to comment.