|
26 | 26 | #include "types.h" |
27 | 27 |
|
28 | 28 | #include <executorch/extension/llm/runner/llm_runner_helper.h> |
| 29 | +#include <executorch/extension/llm/runner/stats.h> |
29 | 30 | #include <executorch/extension/llm/runner/util.h> |
30 | 31 | #include <executorch/extension/llm/runner/wav_loader.h> |
31 | 32 | #include <executorch/extension/llm/tokenizers/third-party/llama.cpp-unicode/include/unicode.h> |
@@ -334,6 +335,10 @@ std::vector<Token> greedy_decode_executorch( |
334 | 335 | int main(int argc, char** argv) { |
335 | 336 | gflags::ParseCommandLineFlags(&argc, &argv, true); |
336 | 337 |
|
| 338 | + // Initialize stats for benchmarking |
| 339 | + ::executorch::extension::llm::Stats stats; |
| 340 | + stats.model_load_start_ms = ::executorch::extension::llm::time_in_ms(); |
| 341 | + |
337 | 342 | TimestampOutputMode timestamp_mode; |
338 | 343 | try { |
339 | 344 | timestamp_mode = parse_timestamp_output_mode(FLAGS_timestamps); |
@@ -362,6 +367,8 @@ int main(int argc, char** argv) { |
362 | 367 | ET_LOG(Error, "Failed to load model."); |
363 | 368 | return 1; |
364 | 369 | } |
| 370 | + stats.model_load_end_ms = ::executorch::extension::llm::time_in_ms(); |
| 371 | + stats.inference_start_ms = ::executorch::extension::llm::time_in_ms(); |
365 | 372 |
|
366 | 373 | // Load audio |
367 | 374 | ET_LOG(Info, "Loading audio from: %s", FLAGS_audio_path.c_str()); |
@@ -412,6 +419,10 @@ int main(int argc, char** argv) { |
412 | 419 | ET_LOG(Error, "Encoder forward failed."); |
413 | 420 | return 1; |
414 | 421 | } |
| 422 | + stats.prompt_eval_end_ms = ::executorch::extension::llm::time_in_ms(); |
| 423 | + stats.first_token_ms = |
| 424 | + stats.prompt_eval_end_ms; // For ASR, first token is at end of encoding |
| 425 | + |
415 | 426 | auto& enc_outputs = enc_result.get(); |
416 | 427 | auto f_proj = enc_outputs[0].toTensor(); // [B, T, joint_hidden] |
417 | 428 | int64_t encoded_len = enc_outputs[1].toTensor().const_data_ptr<int64_t>()[0]; |
@@ -488,6 +499,15 @@ int main(int argc, char** argv) { |
488 | 499 | decoded_tokens, *tokenizer); |
489 | 500 | std::cout << "Transcribed text: " << text << std::endl; |
490 | 501 |
|
| 502 | + // Record inference end time and token counts |
| 503 | + stats.inference_end_ms = ::executorch::extension::llm::time_in_ms(); |
| 504 | + stats.num_prompt_tokens = |
| 505 | + encoded_len; // Use encoder output length as "prompt" tokens |
| 506 | + stats.num_generated_tokens = static_cast<int64_t>(decoded_tokens.size()); |
| 507 | + |
| 508 | + // Print PyTorchObserver stats for benchmarking |
| 509 | + ::executorch::extension::llm::print_report(stats); |
| 510 | + |
491 | 511 | #ifdef ET_BUILD_METAL |
492 | 512 | executorch::backends::metal::print_metal_backend_stats(); |
493 | 513 | #endif // ET_BUILD_METAL |
|
0 commit comments