Skip to content

Commit

Permalink
fix: minor
Browse files Browse the repository at this point in the history
  • Loading branch information
LutingWang committed Feb 14, 2024
1 parent c3f9512 commit b8c8f0c
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
11 changes: 6 additions & 5 deletions todd/runners/callbacks/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..types import Memo
from .base import BaseCallback

# TODO: check if the model has grad
# TODO: check if the model has grad after each iteration


@CallbackRegistry.register_()
Expand All @@ -17,21 +17,22 @@ class CheckCallback(BaseCallback):
def before_run(self, memo: Memo) -> None:
if todd.get_rank() == 0:
requires_grad_parameters = [
name for name, parameter in
repr(name)
for name, parameter in
self.trainer.strategy.module.named_parameters()
if parameter.requires_grad
]
self.trainer.logger.debug(
'Requires grad parameters:\n'
'Requires grad parameters\n'
+ ', '.join(requires_grad_parameters),
)

training_modules = [
name for name, module in
repr(name) for name, module in
self.trainer.strategy.module.named_modules() if module.training
]
self.trainer.logger.debug(
'Training modules:\n' + ', '.join(training_modules),
'Training modules\n' + ', '.join(training_modules),
)

super().before_run(memo)
5 changes: 4 additions & 1 deletion todd/runners/callbacks/composed.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
'ComposedCallback',
]

from typing import Any, Iterable, Literal, Mapping
from typing import Any, Iterable, Iterator, Literal, Mapping

from ...base import CallbackRegistry, Config
from ...utils import PriorityQueue
Expand Down Expand Up @@ -30,6 +30,9 @@ def __init__(self, *args, callbacks: Iterable[Config], **kwargs) -> None:
def priority_queue(self) -> PriorityQueue[KT, BaseCallback]:
return self._priority_queue

def __iter__(self) -> Iterator[BaseCallback]:
return iter(self._priority_queue.queue)

def init(self, *args, **kwargs) -> None:
super().init(*args, **kwargs)
for c in self._priority_queue('init'):
Expand Down

0 comments on commit b8c8f0c

Please sign in to comment.