Skip to content

Commit

Permalink
add advice about dealing with non-invertible Hessians (#875)
Browse files Browse the repository at this point in the history
Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
  • Loading branch information
kylesayrs authored Oct 30, 2024
1 parent c62f2e3 commit a268a25
Showing 1 changed file with 14 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,13 +164,20 @@ def compress(
Losses = torch.zeros(self.rows, device=self.dev)

# compute inverse hessian in place to save memory
damp = percdamp * torch.mean(torch.diag(self.H))
diag = torch.arange(self.columns, device=self.dev)
self.H[diag, diag] += damp
self.H = torch.linalg.cholesky(self.H)
self.H = torch.cholesky_inverse(self.H)
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H
try:
damp = percdamp * torch.mean(torch.diag(self.H))
diag = torch.arange(self.columns, device=self.dev)
self.H[diag, diag] += damp
self.H = torch.linalg.cholesky(self.H)
self.H = torch.cholesky_inverse(self.H)
self.H = torch.linalg.cholesky(self.H, upper=True)
Hinv = self.H
except torch._C._LinAlgError:
raise ValueError(
"Failed to invert hessian due to numerical instability. Consider "
"increasing GPTQModifier.dampening_frac, increasing the number "
"of calibration samples, or shuffling the calibration dataset"
)

# See section 3.4 of https://arxiv.org/abs/2203.07259
for i1 in range(0, self.columns, blocksize):
Expand Down

0 comments on commit a268a25

Please sign in to comment.