fix: nan failure during training #3159
base: main
Conversation
Handles the case where NaNs occur during training. Also adds a gradient_clip_val parameter to train; setting it > 0 can help resolve the NaN issue.
Codecov Report
Attention: Patch coverage is
Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3159      +/-   ##
==========================================
- Coverage   89.43%   82.59%    -6.85%
==========================================
  Files         185      185
  Lines       16182    16210      +28
==========================================
- Hits        14473    13389    -1084
- Misses       1709     2821    +1112
    )
else:
    self.reason = (
        "\033[31m[Warning] Exception occurred during training (Nan or Inf gradients). "
what's this \033[31m string?
It prints the reason in red to the screen.
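For context, \033[31m is the ANSI escape code for red foreground text. A minimal illustration (not code from this PR):

```python
RED = "\033[31m"   # ANSI escape code: set foreground color to red
RESET = "\033[0m"  # ANSI escape code: reset terminal text attributes

# Print a warning in red, then restore the default color so later output
# is unaffected.
print(f"{RED}[Warning] Exception occurred during training (NaN or Inf gradients).{RESET}")
```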
pyro_param_store = best_state_dict.pop("pyro_param_store", None)
pl_module.module.load_state_dict(best_state_dict)
if pyro_param_store is not None:
    # For scArches shapes are changed and we don't want to overwrite
what's this comment here?
You tell me :-). It came from the resolVI merge.
src/scvi/train/_callbacks.py (Outdated)
    # For scArches shapes are changed and we don't want to overwrite
    # these changed shapes.
    pyro.get_param_store().set_state(pyro_param_store)
print(self.reason)
Can we refine the printing here instead of using two print statements?
We can do it in one line, of course.
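A hedged sketch of collapsing the output into a single print call; the wording of the second message is assumed, not taken from the diff:

```python
# One print call instead of two; self.reason already carries the colored
# warning text, and the trailing sentence here is only illustrative.
print(f"{self.reason}\nRestored the best model state from the saved checkpoint.")
```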
src/scvi/train/_trainer.py (Outdated)
@@ -72,6 +72,8 @@ class Trainer(pl.Trainer):
        and in 'max' mode it will stop when the quantity monitored has stopped increasing.
    enable_progress_bar
        Whether to enable or disable the progress bar.
    gradient_clip_val
does it do anything?
It's working, but does it help avoid the NaN-gradient exception? In my test case, no; it just changed the epoch at which training failed. Basically, it's common practice to use it when something like this happens.
But we can still pass it via the trainer kwargs rather than adding it explicitly to the train function signature, so I'll revert.
See the velovi model; it's already part of it.
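For illustration, a minimal sketch of passing gradient clipping through the trainer kwargs instead of a dedicated parameter; the model class and data setup below are assumptions, not taken from this PR:

```python
import scvi

# Hypothetical setup: any scvi-tools model would do; SCVI is used here only
# as an example, and adata is assumed to be an already loaded AnnData object.
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)

# gradient_clip_val is forwarded through the trainer kwargs to the underlying
# Lightning Trainer, which clips gradient norms above the given value.
model.train(max_epochs=100, gradient_clip_val=1.0)
```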
model.train(
    max_epochs=5,
    callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)],
)
does this give a NaN?
Of course not; it's a placeholder for now.
I was hoping we could come up with unit-test data that triggers the on_exception path in a pytest environment. If we do, we can use it here and check that we get what we expect.
Using the SaveCheckpoint callback with on_exception can save the best model up to the point where training crashed due to NaNs in the loss or gradients.
See an example (using Michal's data):
We can then load it and continue training (with or without tweaking parameters).
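A minimal sketch of that resume workflow, assuming SaveCheckpoint wrote a loadable model directory; the path, model class, and tweaked parameters below are illustrative:

```python
import scvi

# Load the best checkpoint saved before the NaN crash.
# "checkpoints/best_model" is a hypothetical path; use the directory that
# SaveCheckpoint actually wrote for your run.
model = scvi.model.SCVI.load("checkpoints/best_model", adata=adata)

# Continue training, optionally tweaking parameters, e.g. a lower learning
# rate and gradient clipping forwarded to the Lightning Trainer.
model.train(max_epochs=50, plan_kwargs={"lr": 1e-4}, gradient_clip_val=1.0)
```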