Outlier detection: catch more outliers by not updating moving average with skipped updates #711

Draft: wants to merge 3 commits into master
Binary file added: dev/cuda/advanced_copy_transpose (contents not shown)
22 changes: 20 additions & 2 deletions llmc/outlier_detector.h
@@ -13,12 +13,13 @@ reconsider this choice in the future, as the compute cost here is minimal.
 #include <math.h>
 
 // use compile-time constant for window size to avoid dynamic memory allocations
-#define OUTLIER_DETECTOR_WINDOW_SIZE 128
+#define OUTLIER_DETECTOR_WINDOW_SIZE 100
 
 typedef struct {
     double buffer[OUTLIER_DETECTOR_WINDOW_SIZE];
     int count;
     int index;
+    int skipped_in_a_row;
     double sum;
     double sum_sq;
 } OutlierDetector;
@@ -33,13 +33,14 @@ void init_detector(OutlierDetector *detector) {
     detector->sum_sq = 0.0;
 }
 
-double update_detector(OutlierDetector *detector, double new_value) {
+double update_detector(OutlierDetector *detector, double new_value, double skip_update_threshold) {
 
     if (detector->count < OUTLIER_DETECTOR_WINDOW_SIZE) {
         // here we are still building up a window of observations
         detector->buffer[detector->count] = new_value;
         detector->sum += new_value;
         detector->sum_sq += new_value * new_value;
+        detector->skipped_in_a_row = 0;
         detector->count++;
         return nan(""); // not enough data yet

@@ -65,6 +67,22 @@ double update_detector(OutlierDetector *detector, double new_value) {
         }
         double z = (new_value - mean) / std_dev;
 
+        if (skip_update_threshold != 0.0 && z > skip_update_threshold) {
+            // let's go back in time and pretend this never happened
+            // i.e. don't let bad outliers affect the threshold for detecting future outliers
+            // otherwise the detector will get less picky and accept things it really shouldn't!
+            // but we do update on consecutive outliers, to avoid getting stuck completely
+            detector->skipped_in_a_row++;
+            if (detector->skipped_in_a_row <= 1) {
+                detector->index = (detector->index - 1 + OUTLIER_DETECTOR_WINDOW_SIZE) % OUTLIER_DETECTOR_WINDOW_SIZE; // +WINDOW keeps the decrement non-negative before the modulo
+                detector->sum += old_value - new_value;
+                detector->sum_sq += (old_value * old_value) - (new_value * new_value);
+                detector->buffer[detector->index] = old_value;
+            }
+        } else {
+            detector->skipped_in_a_row = 0;
+        }
+
         return z;
     }
 }
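Because the detector keeps the window as a running sum and sum_sq, "going back in time" costs only a few arithmetic ops: the skip branch subtracts the outlier's contributions and restores the evicted value. Below is a minimal sketch of how the new three-argument update_detector could be driven; the values and the threshold of 5.0 are illustrative, and it assumes the header builds standalone:

#include <stdio.h>
#include "llmc/outlier_detector.h"

int main(void) {
    OutlierDetector detector;
    init_detector(&detector);

    // fill the window with well-behaved values (returns nan until the window is full)
    for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE; i++) {
        update_detector(&detector, 10.0 + 0.1 * (i % 10), 5.0);
    }

    // a large spike: its z-score exceeds the threshold, so the window stats are rolled back
    double z_spike = update_detector(&detector, 100.0, 5.0);

    // because the spike was not absorbed into the window, the next normal value
    // is judged against uncontaminated statistics and gets a small z-score
    double z_normal = update_detector(&detector, 10.0, 5.0);

    printf("spike z: %.1f, next z: %.1f\n", z_spike, z_normal);
    return 0;
}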
39 changes: 29 additions & 10 deletions train_gpt2.cu
@@ -1394,8 +1394,8 @@ int main(int argc, char *argv[]) {
     int warmup_iterations = 0;
     float final_learning_rate_frac = 1.0f; // final fraction of learning rate, at end of training
     float weight_decay = 0.0f;
-    float skip_update_lossz = 0.0f; // skip update if loss goes above this in zscore
-    float skip_update_gradz = 0.0f; // skip update if grad_norm goes above this in zscore
+    float skip_update_lossz = 8.0f; // skip update if loss goes above this in zscore
+    float skip_update_gradz = 8.0f; // skip update if grad_norm goes above this in zscore
     int val_loss_every = 20; // every how many steps do we eval validation loss?
     int val_max_steps = 20; // how many batches max do we eval for validation loss?
     int sample_every = 20; // every how many steps to do inference?
@@ -1791,23 +1791,42 @@
             // backward pass. all model params accumulate gradients with += inside this inner loop
             gpt2_backward_and_reduce(&model, train_loader.inputs, train_loader.targets, grad_accum_steps, micro_step);
         }
-        float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss)); // loss z-score
+        float zloss = (float)(update_detector(&loss_outlier_detector, (double)model.mean_loss, (double)skip_update_lossz)); // loss z-score
         // fetch the next learning rate
         float step_learning_rate = get_learning_rate(&lr_scheduler, step);
         // calculate the gradient norm and how much we wish to scale the gradient
         float grad_norm = gpt2_calculate_grad_norm(&model, &multi_gpu_config);
-        float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm)); // grad z-score
+        float zgrad = (float)(update_detector(&grad_norm_outlier_detector, (double)grad_norm, (double)skip_update_gradz)); // grad z-score

+        float beta1 = 0.9f;
+        float beta2 = 0.95f;
+        float grad_scale = 1.0f;
         // update the model parameters
         if (isfinite(zloss) && skip_update_lossz != 0.0f && zloss > skip_update_lossz) {
-            printf0("skipping update due to loss z-score of %f\n", zloss);
+            printf0("mostly skipping update due to loss z-score of %f\n", zloss);
+            step_learning_rate *= 0.1f;
+            weight_decay *= 0.2f;
+            beta1 = 0.95f; // same as beta2
         } else if (isfinite(zgrad) && skip_update_gradz != 0.0f && zgrad > skip_update_gradz) {
-            printf0("skipping update due to grad z-score of %f\n", zgrad);
+            printf0("mostly skipping update due to grad z-score of %f\n", zgrad);
+            step_learning_rate *= 0.1f;
+            weight_decay *= 0.2f;
+            beta1 = 0.95f; // same as beta2
+        } else if (isfinite(zgrad) && zgrad > 2.0f) {
+            float lr_ratio = min(1.0f, 3.5f / zgrad); // z-scores in (2.0, 3.5] only reduce beta2; beyond 3.5, lr/wd shrink too
+            printf0("reducing beta2 to 0.9 and lr/wd by %.3f due to grad z-score of %f\n", lr_ratio, zgrad);
+            step_learning_rate *= lr_ratio;
+            weight_decay *= lr_ratio;
+            beta2 = 0.9f; // same as beta1
         } else {
-            // clip the gradient norm to a maximum value
-            float grad_clip = 1.0f;
-            float grad_scale = (grad_norm > grad_clip) ? grad_clip / grad_norm : 1.0f;
-            gpt2_update(&model, step_learning_rate, 0.9f, 0.95f, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config);
+            // clip the gradient, relevant only for early steps where the norm is >1.0; improves learning due to smaller m/v
+            // which is a bit silly, but we don't want a regression just from removing it for now...
+            // unlike before, we don't clip when the grad z-score is high, so clipping won't cause instability later in the run
+            grad_scale = (grad_norm > 1.0f) ? 1.0f / grad_norm : 1.0f;
         }
 
+        gpt2_update(&model, step_learning_rate, beta1, beta2, 1e-8f, weight_decay, grad_scale, step+1, &multi_gpu_config);

         cudaCheck(cudaEventRecord(end));
         cudaCheck(cudaEventSynchronize(end)); // wait for the end event to finish to get correct timings
         // --------------- TRAINING SECTION END -------------------
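Taken together, the train_gpt2.cu changes replace a hard "skip the update" with a tiered damping policy around a single gpt2_update call site. Here is a distilled, self-contained sketch of just that decision logic; the names UpdatePolicy and pick_update_policy are hypothetical (not part of this PR), and fminf stands in for the min used in the CUDA file:

#include <math.h>   // isfinite, fminf

// hypothetical container for the knobs the training loop adjusts per step
typedef struct {
    float lr_mult;           // multiplier on the scheduled learning rate
    float wd_mult;           // multiplier on weight decay
    float beta1, beta2;      // Adam betas for this step
    int clip_to_unit_norm;   // whether to scale gradients down to norm 1.0
} UpdatePolicy;

UpdatePolicy pick_update_policy(float zloss, float zgrad, float skip_lossz, float skip_gradz) {
    UpdatePolicy p = { 1.0f, 1.0f, 0.9f, 0.95f, 0 };
    if (isfinite(zloss) && skip_lossz != 0.0f && zloss > skip_lossz) {
        // severe loss outlier: mostly skip (tiny lr/wd, slower first moment)
        p.lr_mult = 0.1f; p.wd_mult = 0.2f; p.beta1 = 0.95f;
    } else if (isfinite(zgrad) && skip_gradz != 0.0f && zgrad > skip_gradz) {
        // severe grad-norm outlier: same damping as above
        p.lr_mult = 0.1f; p.wd_mult = 0.2f; p.beta1 = 0.95f;
    } else if (isfinite(zgrad) && zgrad > 2.0f) {
        // mild grad-norm outlier: faster second moment; lr/wd shrink only once z exceeds 3.5
        float lr_ratio = fminf(1.0f, 3.5f / zgrad);
        p.lr_mult = lr_ratio; p.wd_mult = lr_ratio; p.beta2 = 0.9f;
    } else {
        // normal step: keep the legacy clip of the gradient norm to 1.0
        p.clip_to_unit_norm = 1;
    }
    return p;
}

Hoisting the optimizer call out of the if/else chain is what makes this tiering cheap to express: every branch only picks multipliers and betas, and gpt2_update runs exactly once per step.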