
fix: nan failure during training #3159

Open · wants to merge 22 commits into main
Conversation

@ori-kron-wis (Collaborator) commented Jan 21, 2025

Using the SaveCheckpoint callback with on_exception, we can save the best model reached up to the point training crashed due to NaNs in the loss or gradients.
See an example (using Michal's data):

import scvi
from scvi.train._callbacks import SaveCheckpoint
from scvi.model import SCANVI
import pandas as pd
import numpy as np
import scanpy as sc
import torch
torch.set_float32_matmul_precision('high')

pd.set_option('display.max_rows', 50)
pd.set_option('display.max_columns', 50)
pd.set_option('display.width', 1000)
scvi.settings.seed = 0

early_stopping_kwargs = {
    'early_stopping': True,
    'early_stopping_monitor': 'elbo_validation', #'train_loss'
    'early_stopping_patience': 50,
    'early_stopping_mode': "min",
    'early_stopping_min_delta': 0.0,
    #'check_val_every_n_epoch': 1,
    #'check_finite': True,
}

ScviM = scvi.model.SCVI.load("/home/access/scvi_forScanVI4")

lvae = scvi.model.SCANVI.from_scvi_model(
    ScviM,
    unlabeled_category='unlabeled',
    labels_key="celltypes_steven2",
    linear_classifier=True,
)
lvae.train(batch_size=1024, n_samples_per_label=100, max_epochs=500, gradient_clip_val=0,
           **early_stopping_kwargs, detect_anomaly=False, enable_checkpointing=True,
           callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)])  # breaks at epoch 58

# We now want to load this model and continue training it
model = SCANVI.load("/home/access/.config/JetBrains/PyCharmCE2024.2/scratches/scvi_log/"
                    "2025-01-23_13-37-44_elbo_validation/"
                    "epoch=54-step=53295-elbo_validation=1255.7066650390625/", adata=ScviM.adata)
model.train(batch_size=2048, n_samples_per_label=50, max_epochs=500, gradient_clip_val=1,
            **early_stopping_kwargs, detect_anomaly=False, enable_checkpointing=True, plan_kwargs={"lr": 1e-2},
            callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)])

# running with detect_anomaly=True really slows down the whole thing
print("done")

We can then load the checkpoint and continue training it (with or without parameter tweaking).

@ori-kron-wis self-assigned this Jan 21, 2025
@ori-kron-wis added the on-merge: backport to 1.3.x label Jan 21, 2025
@ori-kron-wis added this to the scvi-tools 1.3 milestone Jan 21, 2025
@ori-kron-wis changed the title "Ori nan crash fix" to "fix: nan failure during training" Jan 21, 2025
codecov bot commented Jan 21, 2025

Codecov Report

Attention: Patch coverage is 35.71429% with 18 lines in your changes missing coverage. Please review.

Project coverage is 82.59%. Comparing base (2f1611c) to head (f977864).

Files with missing lines      Patch %   Lines
src/scvi/train/_callbacks.py  35.71%    18 Missing ⚠️

❗ There is a different number of reports uploaded between BASE (2f1611c) and HEAD (f977864).

HEAD has 53 fewer uploads than BASE:

Flag   BASE (2f1611c)   HEAD (f977864)
       56               3
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3159      +/-   ##
==========================================
- Coverage   89.43%   82.59%   -6.85%     
==========================================
  Files         185      185              
  Lines       16182    16210      +28     
==========================================
- Hits        14473    13389    -1084     
- Misses       1709     2821    +1112     
Files with missing lines      Coverage Δ
src/scvi/train/_callbacks.py  77.05% <35.71%> (-8.16%) ⬇️

... and 27 files with indirect coverage changes

@ori-kron-wis added the on-merge: backport to 1.2.x label and removed the on-merge: backport to 1.3.x label Jan 30, 2025

            )
        else:
            self.reason = (
                "\033[31m[Warning] Exception occurred during training (Nan or Inf gradients). "

Member: what's this \033[31m string?

Collaborator Author: it prints the reason in red on the screen.
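
For context, a minimal sketch of how ANSI escape codes color terminal output; the reset code and exact message here are illustrative, not taken from the PR's code:

# Minimal illustration of ANSI escape codes (not the PR's actual code):
# "\033[31m" switches terminal output to red; "\033[0m" resets the color.
RED = "\033[31m"
RESET = "\033[0m"
print(f"{RED}[Warning] Exception occurred during training (NaN or Inf gradients).{RESET}")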

pyro_param_store = best_state_dict.pop("pyro_param_store", None)
pl_module.module.load_state_dict(best_state_dict)
if pyro_param_store is not None:
    # For scArches shapes are changed and we don't want to overwrite

Member: what's this comment here?

Collaborator Author: you tell me :-). It came from the resolvi merge.

    # For scArches shapes are changed and we don't want to overwrite
    # these changed shapes.
    pyro.get_param_store().set_state(pyro_param_store)
print(self.reason)

Member: can we refine the printing here instead of two print statements?

Collaborator Author: we can do it in one line, of course.
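
For illustration, a hedged sketch of collapsing the two messages into a single print call; the variable names and the second message are assumptions, not taken from the diff:

# Sketch: combine the two messages into one print call.
# `reason` and `detail` are hypothetical stand-ins for the callback's two messages.
reason = "\033[31m[Warning] Exception occurred during training (NaN or Inf gradients)."
detail = "Loaded the best model state found so far.\033[0m"
print(f"{reason} {detail}")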

@@ -72,6 +72,8 @@ class Trainer(pl.Trainer):
         and in 'max' mode it will stop when the quantity monitored has stopped increasing.
     enable_progress_bar
         Whether to enable or disable the progress bar.
     gradient_clip_val

Member: does it do anything?

Collaborator Author: it's working, but does it help avoid the NaN-gradients exception? In my test case, no; it just changed the epoch at which training failed. Basically, it's common practice to use it when such a thing happens.

But we can still use it via the trainer kwargs rather than explicitly in the train function signature. I'll revert.

See the velovi model; it's part of it.
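
For illustration, a hedged sketch of passing the clipping value through the trainer kwargs instead, assuming extra keyword arguments to train() are forwarded to the underlying Lightning Trainer as the comment above suggests; the synthetic dataset is used only to keep the example self-contained:

import scvi

# Sketch: gradient clipping via trainer kwargs rather than a dedicated
# train() argument (assumes extra kwargs are forwarded to the Trainer).
adata = scvi.data.synthetic_iid()  # small synthetic dataset bundled with scvi-tools
scvi.model.SCVI.setup_anndata(adata)
model = scvi.model.SCVI(adata)
model.train(
    max_epochs=5,
    gradient_clip_val=1.0,  # forwarded to the underlying lightning Trainer
)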


model.train(
    max_epochs=5,
    callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)],

Member: does this give a NaN?

Collaborator Author (Feb 10, 2025): of course not; it's a placeholder for now. I was hoping we could come up with unit-test data that triggers the on_exception path in a pytest environment. If we do, we can use it here and check that we get what we expect.
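
For illustration, a hedged sketch of what such a test might look like once a reliable NaN trigger exists. The oversized learning rate is a hypothetical trigger that may not produce non-finite values deterministically, and whether the failure surfaces as an exception is itself an assumption, so the assertion is intentionally minimal:

import pytest
import scvi
from scvi.train._callbacks import SaveCheckpoint


def test_save_checkpoint_on_nan(tmp_path):
    # Hypothetical sketch: provoke non-finite loss/gradients and check that
    # training surfaces an exception while SaveCheckpoint is active.
    adata = scvi.data.synthetic_iid()
    scvi.model.SCVI.setup_anndata(adata)
    model = scvi.model.SCVI(adata)
    # An absurd learning rate is one (unreliable) way to push the loss to
    # NaN/Inf; dedicated test data that fails deterministically is still needed.
    with pytest.raises(Exception):
        model.train(
            max_epochs=50,
            plan_kwargs={"lr": 1e6},
            callbacks=[SaveCheckpoint(monitor="elbo_validation", load_best_on_end=True)],
        )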

Labels
on-merge: backport to 1.2.x
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants