Skip to content

Commit

Permalink
Merge pull request #13 from normal-computing/vi-diag-fix
Browse files Browse the repository at this point in the history
Small bug fix to make sure VI has proper inplace support
  • Loading branch information
SamDuffield authored Feb 9, 2024
2 parents ff08ff3 + ff8e8dd commit 0cf5dd7
Showing 1 changed file with 4 additions and 1 deletion.
5 changes: 4 additions & 1 deletion uqlib/vi/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,10 @@ def update(
)(state.mean, sd_diag, batch, log_posterior, temperature, n_samples, stl)

updates, optimizer_state = optimizer.update(
nelbo_grads, state.optimizer_state, params=[state.mean, state.log_sd_diag]
nelbo_grads,
state.optimizer_state,
params=[state.mean, state.log_sd_diag],
inplace=inplace,
)
mean, log_sd_diag = torchopt.apply_updates(
(state.mean, state.log_sd_diag), updates, inplace=inplace
Expand Down

0 comments on commit 0cf5dd7

Please sign in to comment.