From e274dfa98df972b913e5702e9df531e69d35489c Mon Sep 17 00:00:00 2001
From: "romain.biessy" <romain.biessy@intel.com>
Date: Thu, 4 Jan 2024 17:31:43 +0100
Subject: [PATCH] Invalidate cache in our benchmarks

---
 test/bench/portfft/launch_bench.hpp | 50 ++++++++++++++++++++++-------
 test/bench/utils/bench_utils.hpp    | 25 +++++++++++++++
 2 files changed, 64 insertions(+), 11 deletions(-)

diff --git a/test/bench/portfft/launch_bench.hpp b/test/bench/portfft/launch_bench.hpp
index 4a8ef377..0706e107 100644
--- a/test/bench/portfft/launch_bench.hpp
+++ b/test/bench/portfft/launch_bench.hpp
@@ -57,29 +57,45 @@ void bench_dft_average_host_time_impl(benchmark::State& state, sycl::queue q, po
   double ops = cooley_tukey_ops_estimate(N, N_transforms);
   std::size_t bytes_transferred = global_mem_transactions<complex_type, complex_type>(N_transforms, N, N);
 
-  auto in_dev = make_shared<forward_t>(num_elements, q);
+  std::size_t input_size_bytes = num_elements * sizeof(forward_t);
+  std::size_t output_size_bytes =
+      desc.placement == portfft::placement::OUT_OF_PLACE ? num_elements * sizeof(complex_type) : 0;
+  auto global_mem_size = q.get_device().get_info<sycl::info::device::global_mem_size>();
+  const std::size_t num_inputs =
+      get_average_host_num_inputs(input_size_bytes, output_size_bytes, global_mem_size, runs_to_average);
+
+  std::vector<std::shared_ptr<forward_t>> device_inputs;
+  for (std::size_t i = 0; i < num_inputs; ++i) {
+    device_inputs.push_back(make_shared<forward_t>(num_elements, q));
+  }
+  auto in_dev0 = device_inputs[0];
   std::shared_ptr<complex_type> out_dev =
       desc.placement == portfft::placement::OUT_OF_PLACE ? make_shared<complex_type>(num_elements, q) : nullptr;
 
   auto committed = desc.commit(q);
   q.wait();
 
+  std::vector<forward_t> host_forward_data;
 #ifdef PORTFFT_VERIFY_BENCHMARKS
   auto [forward_data, backward_data, forward_data_imag, backward_data_imag] =
       gen_fourier_data<portfft::direction::FORWARD, portfft::complex_storage::INTERLEAVED_COMPLEX>(
           desc, portfft::detail::layout::PACKED, portfft::detail::layout::PACKED, 0.f);
-  q.copy(forward_data.data(), in_dev.get(), num_elements).wait();
+  q.copy(forward_data.data(), in_dev0.get(), num_elements).wait();
+  host_forward_data = std::move(forward_data);
+#else
+  host_forward_data.resize(num_elements);
 #endif  // PORTFFT_VERIFY_BENCHMARKS
 
   // warmup
-  auto event = desc.placement == portfft::placement::IN_PLACE ? committed.compute_forward(in_dev.get())
-                                                              : committed.compute_forward(in_dev.get(), out_dev.get());
+  auto event = desc.placement == portfft::placement::IN_PLACE ? committed.compute_forward(in_dev0.get())
+                                                              : committed.compute_forward(in_dev0.get(), out_dev.get());
   event.wait();
 
 #ifdef PORTFFT_VERIFY_BENCHMARKS
   std::vector<complex_type> host_output(num_elements);
-  q.copy(desc.placement == portfft::placement::IN_PLACE ? reinterpret_cast<complex_type*>(in_dev.get()) : out_dev.get(),
-         host_output.data(), num_elements)
+  q.copy(
+       desc.placement == portfft::placement::IN_PLACE ? reinterpret_cast<complex_type*>(in_dev0.get()) : out_dev.get(),
+       host_output.data(), num_elements)
       .wait();
   verify_dft<portfft::direction::FORWARD, portfft::complex_storage::INTERLEAVED_COMPLEX>(desc, backward_data,
                                                                                          host_output, 1e-2);
@@ -92,21 +108,27 @@ void bench_dft_average_host_time_impl(benchmark::State& state, sycl::queue q, po
     // calculation of flops
     dependencies.clear();
 
+    // Write to the inputs to invalidate cache
+    for (auto in_dev : device_inputs) {
+      q.copy(host_forward_data.data(), in_dev.get(), num_elements);
+    }
+    q.wait_and_throw();
+
     std::chrono::time_point<std::chrono::high_resolution_clock> start;
     std::chrono::time_point<std::chrono::high_resolution_clock> end;
     if (desc.placement == portfft::placement::IN_PLACE) {
       start = std::chrono::high_resolution_clock::now();
-      dependencies.emplace_back(committed.compute_forward(in_dev.get()));
+      dependencies.emplace_back(committed.compute_forward(device_inputs[0].get()));
       for (std::size_t r = 1; r != runs; r += 1) {
-        dependencies[0] = committed.compute_forward(in_dev.get(), dependencies);
+        dependencies[0] = committed.compute_forward(device_inputs[r % num_inputs].get(), dependencies);
       }
       dependencies[0].wait();
       end = std::chrono::high_resolution_clock::now();
     } else {
       start = std::chrono::high_resolution_clock::now();
-      dependencies.emplace_back(committed.compute_forward(in_dev.get(), out_dev.get()));
+      dependencies.emplace_back(committed.compute_forward(device_inputs[0].get(), out_dev.get()));
       for (std::size_t r = 1; r != runs; r += 1) {
-        dependencies[0] = committed.compute_forward(in_dev.get(), out_dev.get(), dependencies);
+        dependencies[0] = committed.compute_forward(device_inputs[r % num_inputs].get(), out_dev.get(), dependencies);
       }
       dependencies[0].wait();
       end = std::chrono::high_resolution_clock::now();
@@ -162,13 +184,17 @@ void bench_dft_device_time_impl(benchmark::State& state, sycl::queue q, portfft:
       desc.placement == portfft::placement::OUT_OF_PLACE ? make_shared<complex_type>(num_elements, q) : nullptr;
 
   auto committed = desc.commit(q);
-
   q.wait();
+
+  std::vector<forward_t> host_forward_data;
 #ifdef PORTFFT_VERIFY_BENCHMARKS
   auto [forward_data, backward_data, forward_data_imag, backward_data_imag] =
       gen_fourier_data<portfft::direction::FORWARD, portfft::complex_storage::INTERLEAVED_COMPLEX>(
           desc, portfft::detail::layout::PACKED, portfft::detail::layout::PACKED, 0.f);
   q.copy(forward_data.data(), in_dev.get(), num_elements).wait();
+  host_forward_data = std::move(forward_data);
+#else
+  host_forward_data.resize(num_elements);
 #endif  // PORTFFT_VERIFY_BENCHMARKS
 
   auto compute = [&]() {
@@ -188,6 +214,8 @@ void bench_dft_device_time_impl(benchmark::State& state, sycl::queue q, portfft:
 #endif  // PORTFFT_VERIFY_BENCHMARKS
 
   for (auto _ : state) {
+    // Write to the input to invalidate cache
+    q.copy(host_forward_data.data(), in_dev.get(), num_elements).wait();
     sycl::event e = compute();
     e.wait();
     auto start = e.get_profiling_info<sycl::info::event_profiling::command_start>();
diff --git a/test/bench/utils/bench_utils.hpp b/test/bench/utils/bench_utils.hpp
index 04e27c20..87fbb31b 100644
--- a/test/bench/utils/bench_utils.hpp
+++ b/test/bench/utils/bench_utils.hpp
@@ -38,6 +38,31 @@
  */
 static constexpr std::size_t runs_to_average = 10;
 
+/**
+ * Get the number of distinct inputs to allocate for the average_host benchmark.
+ * Using a different input for each call to compute avoids measuring timings that benefit from cached data.
+ * Returns \p target_num_inputs if the inputs and the output fit in 90% of the global memory, otherwise falls back to
+ * a single input and prints a warning to stderr. Declared inline since it is defined in a header.
+ *
+ * @param input_size_bytes Size of the FFT input in bytes
+ * @param output_size_bytes Upper bound estimation of the size of the FFT output in bytes
+ * @param global_mem_size Global memory size available on the device
+ * @param target_num_inputs Target number of inputs
+ * @return Number of inputs to allocate: either \p target_num_inputs or 1
+ */
+inline std::size_t get_average_host_num_inputs(std::size_t input_size_bytes, std::size_t output_size_bytes,
+                                               std::size_t global_mem_size, std::size_t target_num_inputs) {
+  const std::size_t desired_allocation_size = input_size_bytes * target_num_inputs + output_size_bytes;
+  const std::size_t allocation_size_threshold = static_cast<std::size_t>(0.9 * static_cast<double>(global_mem_size));
+  const std::size_t num_inputs = desired_allocation_size <= allocation_size_threshold ? target_num_inputs : 1;
+  if (num_inputs < target_num_inputs) {
+    std::cerr << "Warning: Not enough global memory to allocate " << target_num_inputs
+              << " input(s). The results may appear better than they would be in a real application due to the "
+                 "device's cache." << std::endl;
+  }
+  return num_inputs;
+}
+
 // Handle an exception by passing the message onto `SkipWithError`.
 // It is expected that this will be placed so the benchmark ends after this is called,
 // allowing the test to exit gracefully with an error message before moving onto the next test.