From f64130e7f7800b94f5b235390fd803f89279329a Mon Sep 17 00:00:00 2001
From: Hansong Zhang
Date: Thu, 4 Apr 2024 15:54:51 -0700
Subject: [PATCH] Expose timestamp stats (#2794)

Summary:
Pull Request resolved: https://github.com/pytorch/executorch/pull/2794

Reviewed By: shoumikhin

Differential Revision: D55604786

Pulled By: kirklandsign

fbshipit-source-id: 4ab1bbb13b746903547d4829f87cda81628ee006
---
 examples/models/llama2/runner/runner.cpp | 110 +++++++++++++----------
 examples/models/llama2/runner/runner.h   |  60 ++++++-------
 2 files changed, 90 insertions(+), 80 deletions(-)

diff --git a/examples/models/llama2/runner/runner.cpp b/examples/models/llama2/runner/runner.cpp
index 2808aa3c9b..af7c25ec67 100644
--- a/examples/models/llama2/runner/runner.cpp
+++ b/examples/models/llama2/runner/runner.cpp
@@ -29,6 +29,8 @@ namespace torch::executor {
 namespace {
 static constexpr auto kTopp = 0.9f;
 
+void printReport(const Runner::Stats& stats);
+std::string statsToJsonString(const Runner::Stats& stats);
 } // namespace
 
 Runner::Runner(
@@ -208,20 +210,21 @@ Result<torch::executor::Tensor> Runner::run_model_step(
 Error Runner::generate(
     const std::string& prompt,
     int32_t seq_len,
-    std::function<void(const std::string&)> callback) {
+    std::function<void(const std::string&)> token_callback,
+    std::function<void(const Stats&)> stats_callback) {
   // Prepare the inputs.
   // Use ones-initialized inputs.
   ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null");
   if (!is_loaded()) {
-    timers_.model_load_start_ms = util::time_in_ms();
+    stats_.model_load_start_ms = util::time_in_ms();
     ET_CHECK_OK_OR_RETURN_ERROR(load());
-    timers_.model_load_end_ms = util::time_in_ms();
+    stats_.model_load_end_ms = util::time_in_ms();
   }
 
   // First token time only measures the time it takes to encode the prompt and
   // return a response token.
-  timers_.inference_start_ms = util::time_in_ms();
+  stats_.inference_start_ms = util::time_in_ms();
   shouldStop_ = false;
 
   // encode the (string) prompt into tokens sequence
@@ -319,9 +322,9 @@ Error Runner::generate(
         run_model_step(cur_token, tokens_managed, start_pos_managed, seq_len);
 
     if (pos == num_prompt_tokens) {
-      timers_.first_token_ms = util::time_in_ms();
+      stats_.first_token_ms = util::time_in_ms();
     } else if (pos == num_prompt_tokens - 1) {
-      timers_.prompt_eval_end_ms = util::time_in_ms();
+      stats_.prompt_eval_end_ms = util::time_in_ms();
     }
 
     ET_CHECK_OK_OR_RETURN_ERROR(logits_res.error());
@@ -345,7 +348,7 @@ Error Runner::generate(
           "Unsupported dtype output %hhd",
           static_cast<int8_t>(logits_tensor.scalar_type()));
     }
-    timers_.aggregate_sampling_time_ms +=
+    stats_.aggregate_sampling_time_ms +=
         util::time_in_ms() - sample_start_time_ms;
 
     // advance the state machine
@@ -364,8 +367,8 @@ Error Runner::generate(
       util::safe_printf(piece);
       fflush(stdout);
 
-      if (callback) {
-        callback(piece);
+      if (token_callback) {
+        token_callback(piece);
       }
 
       if (shouldStop_) {
@@ -379,93 +382,102 @@ Error Runner::generate(
       break;
     }
   }
-  timers_.inference_end_ms = util::time_in_ms();
+  stats_.inference_end_ms = util::time_in_ms();
   printf("\n");
 
   if (pos == seq_len) {
     ET_LOG(Info, "Sequence length (%i tokens) reached!", seq_len);
   }
 
-  timers_.printReport(num_prompt_tokens, pos - num_prompt_tokens);
+  stats_.num_prompt_tokens = num_prompt_tokens;
+  stats_.num_generated_tokens = pos - num_prompt_tokens;
+  printReport(stats_);
+  if (stats_callback) {
+    stats_callback(stats_);
+  }
 
   delete[] prompt_tokens;
   return Error::Ok;
 }
 
-void Runner::TimeStamps::printReport(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
-  printf(
-      "PyTorchObserver %s\n",
-      toJsonString(num_prompt_tokens, num_generated_tokens).c_str());
+namespace {
+void printReport(const Runner::Stats& stats) {
+  printf("PyTorchObserver %s\n", statsToJsonString(stats).c_str());
   ET_LOG(
       Info,
       "\tPrompt Tokens: %" PRIu64 " Generated Tokens: %" PRIu64,
-      num_prompt_tokens,
-      num_generated_tokens);
+      stats.num_prompt_tokens,
+      stats.num_generated_tokens);
   ET_LOG(
       Info,
       "\tModel Load Time:\t\t%f (seconds)",
-      ((double)(model_load_end_ms - model_load_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
-  double inference_time_ms = (double)(inference_end_ms - inference_start_ms);
+      ((double)(stats.model_load_end_ms - stats.model_load_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));
+  double inference_time_ms =
+      (double)(stats.inference_end_ms - stats.inference_start_ms);
   ET_LOG(
       Info,
       "\tTotal inference time:\t\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      inference_time_ms / SCALING_FACTOR_UNITS_PER_SECOND,
+      inference_time_ms / stats.SCALING_FACTOR_UNITS_PER_SECOND,
 
-      (num_generated_tokens) / (double)(inference_end_ms - inference_start_ms) *
-          SCALING_FACTOR_UNITS_PER_SECOND);
-  double prompt_eval_time = (double)(prompt_eval_end_ms - inference_start_ms);
+      (stats.num_generated_tokens) /
+          (double)(stats.inference_end_ms - stats.inference_start_ms) *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
+  double prompt_eval_time =
+      (double)(stats.prompt_eval_end_ms - stats.inference_start_ms);
   ET_LOG(
       Info,
       "\t\tPrompt evaluation:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      prompt_eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      (num_prompt_tokens) / prompt_eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      prompt_eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      (stats.num_prompt_tokens) / prompt_eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
 
-  double eval_time = (double)(inference_end_ms - prompt_eval_end_ms);
+  double eval_time =
+      (double)(stats.inference_end_ms - stats.prompt_eval_end_ms);
   ET_LOG(
       Info,
       "\t\tGenerated %" PRIu64
       " tokens:\t%f (seconds)\t\t Rate: \t%f (tokens/second)",
-      num_generated_tokens,
-      eval_time / SCALING_FACTOR_UNITS_PER_SECOND,
-      num_generated_tokens / eval_time * SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_generated_tokens,
+      eval_time / stats.SCALING_FACTOR_UNITS_PER_SECOND,
+      stats.num_generated_tokens / eval_time *
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
 
   // Time to first token is measured from the start of inference, excluding
   // model load time.
   ET_LOG(
       Info,
       "\tTime to first generated token:\t%f (seconds)",
-      ((double)(first_token_ms - inference_start_ms) /
-       SCALING_FACTOR_UNITS_PER_SECOND));
+      ((double)(stats.first_token_ms - stats.inference_start_ms) /
+       stats.SCALING_FACTOR_UNITS_PER_SECOND));
   ET_LOG(
       Info,
       "\tSampling time over %" PRIu64 " tokens:\t%f (seconds)",
-      num_prompt_tokens + num_generated_tokens,
-      (double)aggregate_sampling_time_ms / SCALING_FACTOR_UNITS_PER_SECOND);
+      stats.num_prompt_tokens + stats.num_generated_tokens,
+      (double)stats.aggregate_sampling_time_ms /
+          stats.SCALING_FACTOR_UNITS_PER_SECOND);
 }
 
-const std::string Runner::TimeStamps::toJsonString(
-    const int64_t& num_prompt_tokens,
-    const int64_t& num_generated_tokens) {
+std::string statsToJsonString(const Runner::Stats& stats) {
   std::stringstream ss;
-  ss << "{\"prompt_tokens\":" << num_prompt_tokens << ","
-     << "\"generated_tokens\":" << num_generated_tokens << ","
-     << "\"model_load_start_ms\":" << model_load_start_ms << ","
-     << "\"model_load_end_ms\":" << model_load_end_ms << ","
-     << "\"inference_start_ms\":" << inference_start_ms << ","
-     << "\"inference_end_ms\":" << inference_end_ms << ","
-     << "\"prompt_eval_end_ms\":" << prompt_eval_end_ms << ","
-     << "\"first_token_ms\":" << first_token_ms << ","
-     << "\"aggregate_sampling_time_ms\":" << aggregate_sampling_time_ms << ","
+  ss << "{\"prompt_tokens\":" << stats.num_prompt_tokens << ","
+     << "\"generated_tokens\":" << stats.num_generated_tokens << ","
+     << "\"model_load_start_ms\":" << stats.model_load_start_ms << ","
+     << "\"model_load_end_ms\":" << stats.model_load_end_ms << ","
+     << "\"inference_start_ms\":" << stats.inference_start_ms << ","
+     << "\"inference_end_ms\":" << stats.inference_end_ms << ","
+     << "\"prompt_eval_end_ms\":" << stats.prompt_eval_end_ms << ","
+     << "\"first_token_ms\":" << stats.first_token_ms << ","
+     << "\"aggregate_sampling_time_ms\":" << stats.aggregate_sampling_time_ms
+     << ","
      << "\"SCALING_FACTOR_UNITS_PER_SECOND\":"
-     << SCALING_FACTOR_UNITS_PER_SECOND << "}";
+     << stats.SCALING_FACTOR_UNITS_PER_SECOND << "}";
   return ss.str();
 }
+} // namespace
 
 void Runner::stop() {
   shouldStop_ = true;
diff --git a/examples/models/llama2/runner/runner.h b/examples/models/llama2/runner/runner.h
index 34339a7c03..08f5e33c47 100644
--- a/examples/models/llama2/runner/runner.h
+++ b/examples/models/llama2/runner/runner.h
@@ -31,12 +31,39 @@ class Runner {
       const std::string& tokenizer_path,
       const float temperature = 0.8f);
 
+  struct Stats {
+    // Scaling factor for timestamps - in this case, we use ms.
+    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
+    // Time stamps for the different stages of the execution
+    // model_load_start_ms: Start of model loading.
+    long model_load_start_ms;
+    // model_load_end_ms: End of model loading.
+    long model_load_end_ms;
+    // inference_start_ms: Immediately after the model is loaded (or we check
+    // for model load), measure the inference time.
+    long inference_start_ms;
+    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
+    // before the inference loop starts
+    long prompt_eval_end_ms;
+    // first_token: Timestamp when the first generated token is emitted
+    long first_token_ms;
+    // inference_end_ms: End of inference/generation.
+    long inference_end_ms;
+    // Keep a running total of the time spent in sampling.
+    long aggregate_sampling_time_ms;
+    // Token count from prompt
+    int64_t num_prompt_tokens;
+    // Token count from generated (total - prompt)
+    int64_t num_generated_tokens;
+  };
+
   bool is_loaded() const;
   Error load();
   Error generate(
       const std::string& prompt,
       int32_t seq_len = 128,
-      std::function<void(const std::string&)> callback = {});
+      std::function<void(const std::string&)> token_callback = {},
+      std::function<void(const Stats&)> stats_callback = {});
   void stop();
 
  private:
@@ -68,36 +95,7 @@ class Runner {
   std::unique_ptr<Tokenizer> tokenizer_;
   std::unique_ptr<Sampler> sampler_;
   bool shouldStop_{false};
-
-  struct TimeStamps {
-    // Scaling factor for timestamps - in this case, we use ms.
-    const long SCALING_FACTOR_UNITS_PER_SECOND = 1000;
-    // Time stamps for the different stages of the execution
-    // model_load_start_ms: Start of model loading.
-    long model_load_start_ms;
-    // model_load_end_ms: End of model loading.
-    long model_load_end_ms;
-    // inference_start_ms: Immediately after the model is loaded (or we check
-    // for model load), measure the inference time.
-    long inference_start_ms;
-    // prompt_eval_end_ms: Prompt array allocation and tokenization. Ends right
-    // before the inference loop starts
-    long prompt_eval_end_ms;
-    // first_token: Timestamp when the first generated token is emitted
-    long first_token_ms;
-    // inference_end_ms: End of inference/generation.
-    long inference_end_ms;
-    // Keep a running total of the time spent in sampling.
-    long aggregate_sampling_time_ms;
-
-    void printReport(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-    const std::string toJsonString(
-        const int64_t& num_prompt_tokens,
-        const int64_t& num_generated_tokens);
-  };
-  TimeStamps timers_;
+  Stats stats_;
 };
 
 } // namespace torch::executor
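
Usage note: with this patch, stats_callback fires once per generate() call, right after printReport() logs, and receives the same Stats instance. A minimal caller-side sketch (not part of the patch) follows; the constructor's leading model-path argument and both file paths are assumptions, since the hunk above only shows the tail of the constructor signature, while the callback signatures follow the header diff:

    // Hypothetical caller of the new Runner::generate() overload.
    #include <cstdio>
    #include <string>

    #include "runner.h" // examples/models/llama2/runner/runner.h

    using torch::executor::Error;
    using torch::executor::Runner;

    int main() {
      // Placeholder paths; the model path is assumed to be the first
      // constructor argument.
      Runner runner("llama2.pte", "tokenizer.bin", /*temperature=*/0.8f);
      Error err = runner.generate(
          "Once upon a time",
          /*seq_len=*/128,
          // token_callback: called for each decoded piece as it is emitted.
          [](const std::string& piece) { fputs(piece.c_str(), stdout); },
          // stats_callback: called once after generation with the filled-in
          // Stats; rates are derived by differencing timestamp pairs.
          [](const Runner::Stats& stats) {
            const double gen_time_s =
                (double)(stats.inference_end_ms - stats.prompt_eval_end_ms) /
                stats.SCALING_FACTOR_UNITS_PER_SECOND;
            printf(
                "\n%lld tokens in %.2f s (%.2f tokens/s)\n",
                (long long)stats.num_generated_tokens,
                gen_time_s,
                stats.num_generated_tokens / gen_time_s);
          });
      return err == Error::Ok ? 0 : 1;
    }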
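
For harnesses that scrape stdout instead of registering stats_callback, printReport() still emits one machine-readable line per run, with exactly the keys written by statsToJsonString() above (numeric values elided here; SCALING_FACTOR_UNITS_PER_SECOND is the literal constant 1000):

    PyTorchObserver {"prompt_tokens":...,"generated_tokens":...,"model_load_start_ms":...,"model_load_end_ms":...,"inference_start_ms":...,"inference_end_ms":...,"prompt_eval_end_ms":...,"first_token_ms":...,"aggregate_sampling_time_ms":...,"SCALING_FACTOR_UNITS_PER_SECOND":1000}

Every *_ms field except aggregate_sampling_time_ms is an absolute util::time_in_ms() timestamp, so consumers should difference pairs (e.g. inference_end_ms - prompt_eval_end_ms) rather than read any single field as a duration; aggregate_sampling_time_ms is already a running total.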