Add distributed training log filtering for master rank #1641

Open · wants to merge 2 commits into base: main
24 changes: 24 additions & 0 deletions mmengine/logging/logger.py
@@ -43,6 +43,28 @@ def filter(self, record: LogRecord) -> bool:
            self.seen.add(record.msg)
            return True
        return False


class FilterMasterRank(logging.Filter):
    """Filter log messages from non-master processes in distributed training.

    Args:
        name (str): Name of the filter. Defaults to 'mmengine'.
    """

    def __init__(self, name: str = 'mmengine') -> None:
        super().__init__(name)

    def filter(self, record: LogRecord) -> bool:
        """Filter out log messages emitted by non-master processes.

        Args:
            record (LogRecord): The log record.

        Returns:
            bool: True if the record comes from the master process
            (rank 0), False otherwise.
        """
        return int(os.environ.get('LOCAL_RANK', 0)) == 0


class MMFormatter(logging.Formatter):
@@ -221,6 +243,7 @@ def __init__(self,
        else:
            stream_handler.setLevel(logging.ERROR)
        stream_handler.addFilter(FilterDuplicateWarning(logger_name))
        stream_handler.addFilter(FilterMasterRank(logger_name))
        self.handlers.append(stream_handler)

        if log_file is not None:
@@ -267,6 +290,7 @@
                MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
            file_handler.setLevel(log_level)
            file_handler.addFilter(FilterDuplicateWarning(logger_name))
            file_handler.addFilter(FilterMasterRank(logger_name))
            self.handlers.append(file_handler)
        self._log_file = log_file

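For reviewers who want to poke at the behavior locally, here is a minimal, self-contained sketch replicating the filter above. It assumes the `LOCAL_RANK` environment-variable convention set by `torch.distributed` launchers such as `torchrun`, and that `os` is already imported in `logger.py`:

```python
import logging
import os


class FilterMasterRank(logging.Filter):
    """Standalone copy of the filter above, for local experimentation."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Keep the record only on the master process (local rank 0).
        return int(os.environ.get('LOCAL_RANK', 0)) == 0


handler = logging.StreamHandler()
handler.addFilter(FilterMasterRank('demo'))
logger = logging.getLogger('demo')
logger.addHandler(handler)

os.environ['LOCAL_RANK'] = '1'  # simulate a non-master worker
logger.warning('suppressed on rank 1')
os.environ['LOCAL_RANK'] = '0'  # simulate the master process
logger.warning('printed on rank 0')
```

One thing worth flagging: `LOCAL_RANK` is the per-node rank, so under a multi-node launcher the local rank 0 of every node still logs; filtering on the global rank (the `RANK` variable) would restrict output to a single process.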
4 changes: 2 additions & 2 deletions mmengine/runner/checkpoint.py
@@ -19,7 +19,7 @@
from mmengine.logging import print_log
from mmengine.model import BaseTTAModel, is_model_wrapper
from mmengine.utils import (apply_to, deprecated_function, digit_version,
                            mkdir_or_exist)
                            mkdir_or_exist,)
from mmengine.utils.dl_utils import load_url

# `MMENGINE_HOME` is the highest priority directory to save checkpoints
@@ -810,6 +810,6 @@ def find_latest_checkpoint(path: str) -> Optional[str]:
        with open(save_file) as f:
            last_saved = f.read().strip()
    else:
        print_log('Did not find last_checkpoint to be resumed.')
        print_log('Did not find last_checkpoint to be resumed.', logger='current')
        last_saved = None
    return last_saved
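The `logger='current'` change matters for the new filter: `print_log` called without a logger falls back to plain `print`, which runs on every rank, while `'current'` routes the message through the active `MMLogger` and its handler-level filters. A small sketch (it assumes an `MMLogger` instance already exists, as it does once a `Runner` is built):

```python
from mmengine.logging import MMLogger, print_log

MMLogger.get_instance('mmengine')  # normally created by the Runner

print_log('printed on every rank')  # no logger: falls back to print()
print_log('goes through MMLogger handlers', logger='current')
```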
43 changes: 27 additions & 16 deletions mmengine/runner/loops.py
@@ -162,27 +162,39 @@ def __init__(self, dataloader: DataLoader) -> None:
    def __iter__(self):
        return self

    def __next__(self) -> Sequence[dict]:
    def skip_iter(self, num_iters: int) -> None:
        for _ in range(num_iters):
            self._next_data(skip_loading=True)

    def __next__(self) -> Union[Sequence[dict], None]:
        return self._next_data()

    def _next_data(self, skip_loading=False) -> Union[Sequence[dict], None]:
        data = None
        try:
            data = next(self._iterator)
            if skip_loading:
                self._iterator._next_index()
            else:
                data = next(self._iterator)
        except StopIteration:
            print_log(
                'Reach the end of the dataloader, it will be '
                'restarted and continue to iterate. It is '
                'recommended to use '
                '`mmengine.dataset.InfiniteSampler` to enable the '
                'dataloader to iterate infinitely.',
                logger='current',
                level=logging.WARNING)
                "Reach the end of the dataloader, it will be "
                "restarted and continue to iterate. It is "
                "recommended to use "
                "`mmengine.dataset.InfiniteSampler` to enable the "
                "dataloader to iterate infinitely.",
                logger="current",
                level=logging.WARNING,
            )
            self._epoch += 1
            if hasattr(self._dataloader, 'sampler') and hasattr(
                    self._dataloader.sampler, 'set_epoch'):
            if hasattr(self._dataloader, "sampler") and hasattr(self._dataloader.sampler, "set_epoch"):
                # In case the `_SingleProcessDataLoaderIter` has no sampler,
                # or the data loader uses `SequentialSampler` in PyTorch.
                self._dataloader.sampler.set_epoch(self._epoch)

            elif hasattr(self._dataloader, 'batch_sampler') and hasattr(
                    self._dataloader.batch_sampler.sampler, 'set_epoch'):
            elif hasattr(self._dataloader, "batch_sampler") and hasattr(
                self._dataloader.batch_sampler.sampler, "set_epoch"
            ):
                # In case the `_SingleProcessDataLoaderIter` has no batch
                # sampler. The batch sampler in PyTorch wraps the sampler
                # as its attribute.
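The skip path hinges on `_next_index()`, a private method of PyTorch's `_BaseDataLoaderIter` that advances the sampler without fetching or collating a batch. A minimal sketch of that behavior (private API, shown for a single-process loader only; multi-worker iterators prefetch batches in the background, so the trick does not transfer directly):

```python
from torch.utils.data import DataLoader

loader = DataLoader(list(range(10)), batch_size=2)  # num_workers=0
it = iter(loader)

# Consume the indices of the first two batches without loading any data.
for _ in range(2):
    it._next_index()  # private API: advances the sampler only

print(next(it))  # tensor([4, 5]): loading resumes at the third batch
```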
@@ -280,8 +292,7 @@ def run(self) -> None:
                'that has already been trained',
                logger='current',
                level=logging.WARNING)
            for _ in range(self._iter):
                next(self.dataloader_iterator)
            self.dataloader_iterator.skip_iter(self._iter)
        while self._iter < self._max_iters and not self.stop_training:
            self.runner.model.train()

@@ -299,7 +310,7 @@ def run(self) -> None:
        self.runner.call_hook('after_train')
        return self.runner.model

    def run_iter(self, data_batch: Sequence[dict]) -> None:
    def run_iter(self, data_batch: Union[Sequence[dict], None]) -> None:
        """Iterate one mini-batch.

        Args:
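To make the motivation for `skip_iter` concrete, here is a hypothetical timing comparison of the old and new resume paths; `SlowDataset` and its sleep are illustrative stand-ins for real disk/decode cost, not code from this PR:

```python
import time

from torch.utils.data import DataLoader


class SlowDataset:
    """Toy dataset whose __getitem__ simulates expensive loading."""

    def __len__(self):
        return 64

    def __getitem__(self, idx):
        time.sleep(0.01)  # stand-in for disk read / decode cost
        return idx


loader = DataLoader(SlowDataset(), batch_size=8)

start = time.time()
it = iter(loader)
for _ in range(4):
    next(it)  # old resume path: load and discard four batches
print(f'load-and-discard: {time.time() - start:.3f}s')

start = time.time()
it = iter(loader)
for _ in range(4):
    it._next_index()  # new path: advance sampler indices only
print(f'index-skip: {time.time() - start:.3f}s')
```

With per-sample costs in the tens of milliseconds and a resume point tens of thousands of iterations in, the old load-and-discard loop can add minutes to startup for data that is immediately thrown away.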