
Restore from master weights (& allow restoring from a checkpoint of different precision) #702

Merged: 6 commits from restore_from_master_weights into karpathy:master on Jul 30, 2024

Conversation

ademeure (Contributor)

This is fully deterministic for new checkpoints, where the new rng_state_last_update field is saved so that stochastic rounding from the master weights uses exactly the same seeds as the original optimizer update (the actual final rng_state is restored again afterwards, in case anything else changed it between that update and saving the checkpoint).

When resuming from a checkpoint (not just a regular model file) and master weights are available, loading the low-precision weights from the checkpoint is skipped entirely, so it doesn't matter if they aren't even the right number of bytes.

It should be useful for checking whether FP32 master weights help with runs exploding, and going forward it allows FP8 runs not to care too much about the format the non-master weights are saved in, so we don't need to worry about format changes breaking checkpoint compatibility.
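Conceptually, the resume path this PR describes looks roughly like the sketch below. This is a simplified illustration, not the actual llm.c code: the struct fields and helper names (skip_low_precision_weights, load_master_weights_and_rng, init_from_master) are hypothetical stand-ins for the real checkpoint I/O and the new kernel.

```c
// Simplified sketch of the resume-from-master-weights flow described above.
// All names here are illustrative stand-ins, not the actual llm.c functions.
#include <stddef.h>

typedef struct {
    int use_master_weights;
    unsigned long long rng_state;             // "live" RNG state saved with the checkpoint
    unsigned long long rng_state_last_update; // RNG state at the last optimizer update
    float* master_params;                     // FP32 master copy of the weights
    void*  params;                            // low-precision (e.g. BF16/FP8) weights
    size_t num_parameters;
} Model;

// hypothetical helpers standing in for checkpoint I/O and the new CUDA kernel
void skip_low_precision_weights(void* checkpoint);
void load_master_weights_and_rng(Model* m, void* checkpoint);
void init_from_master(Model* m, unsigned long long seed); // stochastic rounding on the GPU

void resume_from_checkpoint(Model* m, void* checkpoint) {
    if (m->use_master_weights) {
        // don't read the low-precision weights at all, so their dtype/byte
        // size in the checkpoint no longer matters
        skip_low_precision_weights(checkpoint);
        load_master_weights_and_rng(m, checkpoint);
        unsigned long long rng_final = m->rng_state;
        // redo the stochastic rounding with the seeds of the last update,
        // reproducing the low-precision weights bit-for-bit
        init_from_master(m, m->rng_state_last_update);
        // restore the final RNG state so subsequent training is unaffected by
        // anything that advanced the RNG between that update and checkpointing
        m->rng_state = rng_final;
    }
}
```

The key point is that the low-precision weights are never trusted from disk: they are recomputed from the FP32 master copy with the same rounding seeds, which is what makes the resume bit-identical.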

@ademeure force-pushed the restore_from_master_weights branch from 6387e66 to 4d77ece on July 25, 2024
@@ -61,6 +61,16 @@ __global__ void adamw_kernel3(Tp* params_memory, float* master_params_memory, Tg
     );
 }
 
+template <typename Tp>
+__global__ void init_from_master_kernel(Tp* params_memory, float* master_params_memory, size_t num_parameters,
+                                        ptrdiff_t w_stride, ptrdiff_t s_stride, unsigned int seed, bool check_identical) {
karpathy (Owner)

unused check_identical?

ademeure (Contributor, Author)

oops, that used to be for debug logs that I removed, thanks
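For context, the body of a kernel with the signature shown in this hunk might look roughly like the simplified, unvectorized sketch below (the unused check_identical parameter discussed above is omitted, and the sketch is specialized to BF16). The hash and rounding helpers are illustrative; the actual kernel in llm.c uses its own packed 128-bit accesses and stochastic_rounding helper.

```cuda
#include <cuda_bf16.h>
#include <cstddef>

// Simple integer hash to derive a per-element random value from (seed, idx).
// Illustrative only; llm.c has its own RNG helpers.
__device__ unsigned int hash_u32(unsigned int a, unsigned int b) {
    unsigned int h = (a * 0x9E3779B9u) ^ (b * 0x85EBCA6Bu);
    h ^= h >> 16; h *= 0x7FEB352Du;
    h ^= h >> 15; h *= 0x846CA68Bu;
    h ^= h >> 16;
    return h;
}

// Stochastic rounding of an FP32 value to BF16: add random bits below the
// BF16 mantissa, then truncate, so rounding up happens with probability
// proportional to how close the value is to the next representable BF16.
// (For negative inputs this operates on the magnitude, which is still unbiased.)
__device__ __nv_bfloat16 stochastic_round_bf16(float x, unsigned int rnd) {
    unsigned int bits = __float_as_uint(x);
    bits += rnd & 0xFFFFu;   // the low 16 bits are dropped by the BF16 truncation
    return __float2bfloat16_rz(__uint_as_float(bits));
}

// Unvectorized sketch of init_from_master_kernel, specialized to BF16.
__global__ void init_from_master_kernel_sketch(__nv_bfloat16* params_memory,
                                               const float* master_params_memory,
                                               size_t num_parameters,
                                               ptrdiff_t w_stride, ptrdiff_t s_stride,
                                               unsigned int seed) {
    size_t idx = (size_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= num_parameters) { return; }
    // blockIdx.y selects a slice; the low-precision and master buffers may be
    // laid out with different strides, hence the two separate offsets
    params_memory        += blockIdx.y * w_stride;
    master_params_memory += blockIdx.y * s_stride;
    // derive the per-element random value from the same seed that the
    // optimizer update used, so every rounding decision is reproduced exactly
    unsigned int rnd = hash_u32(seed, (unsigned int)idx);
    params_memory[idx] = stochastic_round_bf16(master_params_memory[idx], rnd);
}
```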

@@ -1532,7 +1554,7 @@ int main(int argc, char *argv[]) {
     gpt2_init_common(&model);
     if (resuming == 1) {
         // if `-y 1` was set, then we are resuming from the latest checkpoint
-        gpt2_build_from_checkpoint(&model, filename_buffer);
+        gpt2_build_from_checkpoint(&model, filename_buffer, true);
karpathy (Owner)
at this point model.use_master_weights is not yet initialized from use_master_weights (that happens below), so it is still the default (true). This variable is used inside gpt2_build_from_checkpoint, so this is probably a bug?
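One way to address this (an illustrative sketch of the main() fragment, not necessarily the fix that landed) would be to assign the CLI option to the model before the checkpoint is built, so gpt2_build_from_checkpoint sees the user's choice rather than the default:

```c
// Illustrative reordering only: set the flag before gpt2_build_from_checkpoint()
// reads model.use_master_weights internally.
gpt2_init_common(&model);
model.use_master_weights = use_master_weights;  // value parsed from the command line
if (resuming == 1) {
    // if `-y 1` was set, then we are resuming from the latest checkpoint
    gpt2_build_from_checkpoint(&model, filename_buffer, true);
}
```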

@karpathy merged commit b2ae847 into karpathy:master on Jul 30, 2024
13 checks passed