Restore from master weights (& allow restoring from a checkpoint of different precision) #702
Conversation
…rministically by also saving RNG state of last update)
6387e66 to 4d77ece
@@ -61,6 +61,16 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
     );
 }

+template <typename Tp>
+__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+                                        ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
unused check_identical?
oops, that used to be for debug logs that I removed, thanks
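For context, here is a minimal sketch of what a kernel like init_from_master_kernel plausibly does, assuming a bf16 parameter type and the usual bit-trick form of stochastic rounding; the hash, the omitted stride handling, and the function name are simplifications, not the PR's actual code:

```cuda
#include <cuda_bf16.h>

// Sketch only: re-materialize low-precision params from their FP32 master
// copies via stochastic rounding, so that given the same seed the result
// can match what the last optimizer update produced.
__global__ void init_from_master_sketch(__nv_bfloat16* params_memory, const float* master_params_memory,
                                        size_t num_parameters, unsigned int seed) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_parameters) { return; }
    // cheap per-element hash of (seed, idx); llm.c uses its own noise hash,
    // this multiply is just a stand-in
    unsigned int r = (seed ^ (unsigned int)idx) * 2654435761u;
    // stochastic rounding to the bf16 grid: add 16 bits of noise to the fp32
    // bit pattern, then clear the low 16 mantissa bits -- a carry into the
    // kept bits implements the probabilistic round-up
    // (edge cases like Inf/NaN are ignored in this sketch)
    unsigned int bits = __float_as_uint(master_params_memory[idx]);
    bits = (bits + (r & 0xFFFFu)) & 0xFFFF0000u;
    params_memory[idx] = __float2bfloat16(__uint_as_float(bits));  // exact: low bits are zero
}
```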
@@ -1532,7 +1554,7 @@ int main(int argc, char *argv[]) {
     gpt2_init_common(&model);
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
-        gpt2_build_from_checkpoint(&model, filename_buffer);
+        gpt2_build_from_checkpoint(&model, filename_buffer, true);
At this point model.use_master_weights is not yet initialized with use_master_weights; that happens below, so it still holds the default (true). This variable is used inside gpt2_build_from_checkpoint, so this is probably a bug?
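If that's right, one possible shape of a fix (a sketch, not an actual patch from this PR) is to move the assignment above the call, so the flag is set before gpt2_build_from_checkpoint reads it:

```diff
@@ int main(int argc, char *argv[]) @@
     gpt2_init_common(&model);
+    model.use_master_weights = use_master_weights; // hypothetical: set before it is read below
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
         gpt2_build_from_checkpoint(&model, filename_buffer, true);
```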
This is fully deterministic for new checkpoints where the new rng_state_last_update is saved: stochastic rounding from the master weights is done with the exact same seeds as the last optimizer update, while the actual final rng_state is restored again afterwards, in case anything else changed it between that update and saving the checkpoint.
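Concretely, the restore path looks roughly like the sketch below; rng_state and rng_state_last_update are the fields this PR describes, while the struct layout and the helper name are assumptions:

```c
// Sketch: restore low-precision weights deterministically from FP32 masters.
// init_weights_from_master is a hypothetical helper standing in for the
// kernel launch above.
void restore_params_from_master(GPT2 *model) {
    unsigned long long final_rng_state = model->rng_state;  // state at checkpoint-save time
    model->rng_state = model->rng_state_last_update;        // replay the last update's seeds
    init_weights_from_master(model);                        // stochastic-round masters -> params
    model->rng_state = final_rng_state;                     // keep the true final state, in case
                                                            // anything advanced the RNG between
                                                            // that update and saving
}
```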
In the case where we are resuming from a checkpoint (not just a regular model file) and we have master weights, this simply skips loading the weights from the checkpoint entirely, so it doesn't matter if they aren't even the right number of bytes.
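The load path then branches roughly like this (a sketch; the helper name and byte-count variable are placeholders, and in the real code the masters come from the optimizer state):

```c
// When resuming with master weights available, skip the stored low-precision
// weights entirely -- which is why their byte count (and hence their saved
// precision) no longer matters -- and rebuild them from the FP32 masters.
if (resuming && model->use_master_weights) {
    fseekCheck(model_file, (long)weights_bytes_in_file, SEEK_CUR); // skip stored weights
} else {
    read_weights_into_model(model, model_file);                    // normal load path
}
```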
It should be useful for checking whether FP32 master weights help with runs exploding, and going forward it will allow FP8 runs to not care too much about the format the non-master weights are saved in, so we don't need to worry about changes breaking compatibility, etc.