Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSE lifts strict_float-wrapped operations out of their strict_float wrappers. #8563

Open
mcourteaux opened this issue Feb 6, 2025 · 7 comments

Comments

@mcourteaux
Copy link
Contributor

mcourteaux commented Feb 6, 2025

Debugging precision errors on my transcendentals revealed during lowering this behavior:

Bounding small allocations...
Lowering after bounding small allocations:
assert(reinterpret<uint64>((struct halide_buffer_t *)sin_approx$2.buffer) != (uint64)0, halide_error_buffer_argument_is_null("sin_approx$2"))
let sin_approx$2 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)sin_approx$2.buffer, 0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)sin_approx$2.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct((min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0) + -256, max(sin_approx$2.extent.0, 256), 1, 0), (uint64)0)
}
if (!(uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 assert(sin_approx$2.type == (uint32)73730, halide_error_bad_type("Output buffer sin_approx$2", sin_approx$2.type, (uint32)73730))
 assert(sin_approx$2.dimensions == 1, halide_error_bad_dimensions("Output buffer sin_approx$2", sin_approx$2.dimensions, 1))
 assert((256 <= sin_approx$2.extent.0) && (((max(sin_approx$2.extent.0, 256) + (min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0)) + -256) <= (sin_approx$2.extent.0 + sin_approx$2.min.0)), halide_error_access_out_of_bounds("Output buffer sin_approx$2", 0, (min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0) + -256, (max(sin_approx$2.extent.0, 256) + (min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0)) + -257, sin_approx$2.min.0, (sin_approx$2.extent.0 + sin_approx$2.min.0) + -1))
 assert(sin_approx$2.stride.0 == 1, halide_error_constraint_violated("sin_approx$2.stride.0", sin_approx$2.stride.0, "1", 1))
 assert((uint64)abs(int64(sin_approx$2.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("sin_approx$2", (uint64)abs(int64(sin_approx$2.extent.0)), (uint64)2147483647))
 produce sin_approx$2 {
  let halide_copy_to_device_result$5 = halide_copy_to_device((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  assert(halide_copy_to_device_result$5 == 0, halide_copy_to_device_result$5)
  gpu_block<CUDA> (sin_approx$2.s0.i.v10.block_id_x, 0, (sin_approx$2.extent.0 + 255)/256) {
   gpu_thread<CUDA> (.thread_id_x, 0, 256) {
    let sin_approx$2.s0.i.v11.base.s = min(sin_approx$2.s0.i.v10.block_id_x*256, sin_approx$2.extent.0 + -256)
    sin_approx$2[.thread_id_x + sin_approx$2.s0.i.v11.base.s] = 
      let t115 = (float32)strict_float((float32)strict_float(float32(((sin_approx$2.min.0 + sin_approx$2.s0.i.v11.base.s) + .thread_id_x)))/(float32)strict_float(1048576.000000f)) in
      (float32)strict_float(let t118 = (float32)strict_float((float32)strict_float((float32)strict_float(-9.424778f)*(float32)strict_float((float32)strict_float(1.000000f) - t115)) + (float32)strict_float((float32)strict_float(9.424778f)*t115)) in (let t119 = (float32)floor_f32(t118*0.636620f) in (let t120.s = int32(t119) in (let t121 = (t118 - (t119*1.570796f)) in (let t122 = select(((t120.s % 4) == 1) || ((t120.s % 4) == 3), 1.570796f - t121, t121) in (let t124.s = ((t122*t122)*(((t122*t122)*(((t122*t122)*(((t122*t122)*0.000003f) + -0.000198f)) + 0.008333f)) + -0.166667f)) in select(1 < (t120.s % 4), 0.000000f - ((t124.s + 1.000000f)*t122), (t124.s + 1.000000f)*t122)))))))

   }
  }
  _halide_buffer_set_device_dirty((struct halide_buffer_t *)sin_approx$2.buffer, (uint1)1)
 }
}

Injecting warp shuffles...
Lowering after injecting warp shuffles:
assert(reinterpret<uint64>((struct halide_buffer_t *)sin_approx$2.buffer) != (uint64)0, halide_error_buffer_argument_is_null("sin_approx$2"))
let sin_approx$2 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)sin_approx$2.buffer, 0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)sin_approx$2.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct((min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0) + -256, max(sin_approx$2.extent.0, 256), 1, 0), (uint64)0)
}
if (!(uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 assert(sin_approx$2.type == (uint32)73730, halide_error_bad_type("Output buffer sin_approx$2", sin_approx$2.type, (uint32)73730))
 assert(sin_approx$2.dimensions == 1, halide_error_bad_dimensions("Output buffer sin_approx$2", sin_approx$2.dimensions, 1))
 assert(256 <= sin_approx$2.extent.0, halide_error_access_out_of_bounds("Output buffer sin_approx$2", 0, (sin_approx$2.extent.0 + sin_approx$2.min.0) + -256, (sin_approx$2.extent.0 + sin_approx$2.min.0) + -1, sin_approx$2.min.0, (sin_approx$2.extent.0 + sin_approx$2.min.0) + -1))
 assert(sin_approx$2.stride.0 == 1, halide_error_constraint_violated("sin_approx$2.stride.0", sin_approx$2.stride.0, "1", 1))
 assert((uint64)abs(int64(sin_approx$2.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("sin_approx$2", (uint64)abs(int64(sin_approx$2.extent.0)), (uint64)2147483647))
 produce sin_approx$2 {
  let halide_copy_to_device_result$5 = halide_copy_to_device((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  assert(halide_copy_to_device_result$5 == 0, halide_copy_to_device_result$5)
  gpu_block<CUDA> (sin_approx$2.s0.i.v10.block_id_x, 0, (sin_approx$2.extent.0 + 255)/256) {
   gpu_thread<CUDA> (.thread_id_x, 0, 256) {
    let sin_approx$2.s0.i.v11.base.s = min(sin_approx$2.s0.i.v10.block_id_x*256, sin_approx$2.extent.0 + -256)
    let t129 = (float32)strict_float((float32)strict_float(float32(((sin_approx$2.min.0 + sin_approx$2.s0.i.v11.base.s) + .thread_id_x)))/(float32)strict_float(1048576.000000f))
    let t130 = (float32)strict_float((float32)strict_float((float32)strict_float(-9.424778f)*(float32)strict_float((float32)strict_float(1.000000f) - t129)) + (float32)strict_float((float32)strict_float(9.424778f)*t129))
    let t131 = (float32)floor_f32(t130*0.636620f)
    let t132.s = int32(t131)
    let t133 = t130 - (t131*1.570796f)
    let t134 = select(((t132.s % 4) == 1) || ((t132.s % 4) == 3), 1.570796f - t133, t133)
    let t136.s = (t134*t134)*(((t134*t134)*(((t134*t134)*(((t134*t134)*0.000003f) + -0.000198f)) + 0.008333f)) + -0.166667f)
    sin_approx$2[.thread_id_x + sin_approx$2.s0.i.v11.base.s] = (float32)strict_float(select(1 < (t132.s % 4), 0.000000f - ((t136.s + 1.000000f)*t134), (t136.s + 1.000000f)*t134))
   }
  }
  _halide_buffer_set_device_dirty((struct halide_buffer_t *)sin_approx$2.buffer, (uint1)1)
 }
}

This causes LLVM IR to be produced that looks like this (a whole lot further down):

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define void @_kernel_sin_approx__2_s0_i_v10_block_id_x(ptr noalias nocapture writeonly %"sin_approx$2", i32 %"sin_approx$2.extent.0", i32 %"sin_approx$2.min.0") local_unnamed_addr #1 {
entry:
  %"sin_approx$22" = addrspacecast ptr %"sin_approx$2" to ptr addrspace(1)
  %"sin_approx$2.s0.i.v10.block_id_x" = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %.thread_id_x = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %a15 = shl nsw i32 %"sin_approx$2.s0.i.v10.block_id_x", 8
  %b17 = add nsw i32 %"sin_approx$2.extent.0", -256
  %"sin_approx$2.s0.i.v11.base.s" = tail call i32 @llvm.smin.i32(i32 %a15, i32 %b17)
  %0 = add i32 %"sin_approx$2.s0.i.v11.base.s", %.thread_id_x
  %1 = add i32 %0, %"sin_approx$2.min.0"
  %2 = sitofp i32 %1 to float
  %t129 = fmul float %2, 0x3EB0000000000000
  %3 = fsub float 1.000000e+00, %t129
  %4 = fmul float %3, 0x4022D97C80000000
  %5 = fmul float %t129, 0x4022D97C80000000
  %6 = fsub float %5, %4
  %7 = fmul reassoc nnan ninf nsz contract afn float %6, 0x3FE45F3060000000
  %8 = tail call float @llvm.nvvm.floor.ftz.f(float %7)
  %t132.s = fptosi float %8 to i32
  %9 = fmul reassoc nnan ninf nsz contract afn float %8, 0x3FF921FB60000000
  %t133 = fsub reassoc nnan ninf nsz contract afn float %6, %9
  %t148 = and i32 %t132.s, 2
  %10 = and i32 %t132.s, 1
  %.not = icmp eq i32 %10, 0
  %11 = fsub reassoc nnan ninf nsz contract afn float 0x3FF921FB60000000, %t133
  %t134 = select reassoc nnan ninf nsz contract afn i1 %.not, float %t133, float %11
  %t149 = fmul reassoc nnan ninf nsz contract afn float %t134, %t134
  %12 = fmul reassoc nnan ninf nsz contract afn float %t149, 0x3EC5D30200000000
  %13 = fadd reassoc nnan ninf nsz contract afn float %12, 0xBF29F635A0000000
  %14 = fmul reassoc nnan ninf nsz contract afn float %13, %t149
  %15 = fadd reassoc nnan ninf nsz contract afn float %14, 0x3F8110E740000000
  %16 = fmul reassoc nnan ninf nsz contract afn float %15, %t149
  %17 = fadd reassoc nnan ninf nsz contract afn float %16, 0xBFC5555480000000
  %t136.s = fmul reassoc nnan ninf nsz contract afn float %17, %t149
  %18 = fadd reassoc nnan ninf nsz contract afn float %t136.s, 1.000000e+00
  %t145 = fmul reassoc nnan ninf nsz contract afn float %18, %t134
  %.not1 = icmp eq i32 %t148, 0
  %19 = fsub float 0.000000e+00, %t145
  %20 = select i1 %.not1, float %t145, float %19
  %21 = sext i32 %0 to i64
  %22 = getelementptr inbounds float, ptr addrspace(1) %"sin_approx$22", i64 %21
  store float %20, ptr addrspace(1) %22, align 4, !tbaa !6
  ret void
}

Notice the reassoc, contract and afn flags, which is basically the "fast-math" flags: https://llvm.org/docs/LangRef.html#rewrite-based-flags

Also notice that the issue is introduced after the CUDA-specific "Injecting warp shuffles..." pass.

For reference, final simplification before device codegen:

Lowering after final simplification:
assert(reinterpret<uint64>((struct halide_buffer_t *)sin_approx$2.buffer) != (uint64)0, halide_error_buffer_argument_is_null("sin_approx$2"))
let sin_approx$2 = (void *)_halide_buffer_get_host((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.type = (uint32)_halide_buffer_get_type((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.dimensions = _halide_buffer_get_dimensions((struct halide_buffer_t *)sin_approx$2.buffer)
let sin_approx$2.min.0 = _halide_buffer_get_min((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.extent.0 = _halide_buffer_get_extent((struct halide_buffer_t *)sin_approx$2.buffer, 0)
let sin_approx$2.stride.0 = _halide_buffer_get_stride((struct halide_buffer_t *)sin_approx$2.buffer, 0)
if ((uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 (struct halide_buffer_t *)_halide_buffer_init((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_dimension_t *)_halide_buffer_get_shape((struct halide_buffer_t *)sin_approx$2.buffer), reinterpret<(void *)>((uint64)0), (uint64)0, reinterpret<(struct halide_device_interface_t *)>((uint64)0), 2, 32, 1, (struct halide_dimension_t *)make_struct((min(sin_approx$2.extent.0, 256) + sin_approx$2.min.0) + -256, max(sin_approx$2.extent.0, 256), 1, 0), (uint64)0)
}
if (!(uint1)_halide_buffer_is_bounds_query((struct halide_buffer_t *)sin_approx$2.buffer)) {
 assert(sin_approx$2.type == (uint32)73730, halide_error_bad_type("Output buffer sin_approx$2", sin_approx$2.type, (uint32)73730))
 assert(sin_approx$2.dimensions == 1, halide_error_bad_dimensions("Output buffer sin_approx$2", sin_approx$2.dimensions, 1))
 assert(256 <= sin_approx$2.extent.0, let t147 = (sin_approx$2.extent.0 + sin_approx$2.min.0) in halide_error_access_out_of_bounds("Output buffer sin_approx$2", 0, t147 + -256, t147 + -1, sin_approx$2.min.0, t147 + -1))
 assert(sin_approx$2.stride.0 == 1, halide_error_constraint_violated("sin_approx$2.stride.0", sin_approx$2.stride.0, "1", 1))
 assert((uint64)abs(int64(sin_approx$2.extent.0)) <= (uint64)2147483647, halide_error_buffer_allocation_too_large("sin_approx$2", (uint64)abs(int64(sin_approx$2.extent.0)), (uint64)2147483647))
 produce sin_approx$2 {
  let halide_copy_to_device_result$5 = halide_copy_to_device((struct halide_buffer_t *)sin_approx$2.buffer, (struct halide_device_interface_t const *)halide_cuda_device_interface())
  assert(halide_copy_to_device_result$5 == 0, halide_copy_to_device_result$5)
  let t146 = (sin_approx$2.extent.0 + 255)/256
  gpu_block<CUDA> (sin_approx$2.s0.i.v10.block_id_x, 0, t146) {
   gpu_thread<CUDA> (.thread_id_x, 0, 256) {
    let sin_approx$2.s0.i.v11.base.s = min(sin_approx$2.s0.i.v10.block_id_x*256, sin_approx$2.extent.0 + -256)
    let t129 = (float32)strict_float((float32)strict_float(float32(((sin_approx$2.min.0 + sin_approx$2.s0.i.v11.base.s) + .thread_id_x)))/(float32)strict_float(1048576.000000f))
    let t130 = (float32)strict_float((float32)strict_float((float32)strict_float(-9.424778f)*(float32)strict_float((float32)strict_float(1.000000f) - t129)) + (float32)strict_float((float32)strict_float(9.424778f)*t129))
    let t131 = (float32)floor_f32(t130*0.636620f)
    let t132.s = int32(t131)
    let t133 = t130 - (t131*1.570796f)
    let t134 = let t148 = (t132.s % 4) in select((t148 == 1) || (t148 == 3), 1.570796f - t133, t133)
    let t136.s = let t149 = (t134*t134) in (t149*((t149*((t149*((t149*0.000003f) + -0.000198f)) + 0.008333f)) + -0.166667f))
    let t145 = (t136.s + 1.000000f)*t134
    sin_approx$2[.thread_id_x + sin_approx$2.s0.i.v11.base.s] = (float32)strict_float(select(1 < (t132.s % 4), 0.000000f - t145, t145))
   }
  }
  _halide_buffer_set_device_dirty((struct halide_buffer_t *)sin_approx$2.buffer, (uint1)1)
 }
}

Offending code:

Halide/src/CSE.cpp

Lines 317 to 344 in b937537

// Figure out which ones we'll pull out as lets and variables.
vector<pair<string, Expr>> lets;
vector<Expr> new_version(gvn.entries.size());
map<Expr, Expr, ExprCompare> replacements;
for (size_t i = 0; i < gvn.entries.size(); i++) {
const auto &e = gvn.entries[i];
if (e->use_count > 1) {
string name = namer.make_unique_name();
lets.emplace_back(name, e->expr);
// Point references to this expr to the variable instead.
replacements[e->expr] = Variable::make(e->expr.type(), name);
}
debug(4) << i << ": " << e->expr << ", " << e->use_count << "\n";
}
// Rebuild the expr to include references to the variables:
Replacer replacer(replacements);
e = replacer.mutate(e);
debug(4) << "With variables " << e << "\n";
// Wrap the final expr in the lets.
for (const auto &[var, value] : reverse_view(lets)) {
// Drop this variable as an acceptable replacement for this expr.
replacer.erase(value);
// Use containing lets in the value.
e = Let::make(var, replacer.mutate(value), e);
}

I'm not seeing this issue on x86_64...

@mcourteaux
Copy link
Contributor Author

I debugged this a bit further, and learned a few things:

  • CSE does lift out the expressions of their strict-float context, and it does influence the result.
  • StrictifyFloat is running early, so my fast_sin intrinsic gets replaced by the polynomial in a later pass, and the Exprs don't get a bunch of strict_float littered over them.
  • Running a result = common_subexpression_elimination(result) at the end of my approximation function does effectively start the CSE without the context. Even if someone fixed CSE to be aware of the strict_float() context from which it is currently extracting Exprs (like I did below), that still wouldn't work if we randomly run it during lowering passes after StrictifyFloat has done its work.

Here is a quick hack for fixing CSE and strict floats:

diff --git a/src/CSE.cpp b/src/CSE.cpp
index df055c4bd..e5acbaa56 100644
--- a/src/CSE.cpp
+++ b/src/CSE.cpp
@@ -80,6 +80,7 @@ class GVN : public IRMutator {
 public:
     struct Entry {
         Expr expr;
+        bool strict_float = false;
         int use_count = 0;
         // All consumer Exprs for which this is the last child Expr.
         map<Expr, int, IRGraphDeepCompare> uses;
@@ -144,6 +145,7 @@ public:
 class ComputeUseCounts : public IRGraphVisitor {
     GVN &gvn;
     bool lift_all;
+    bool in_strict_float{false};
 
 public:
     ComputeUseCounts(GVN &g, bool l)
@@ -153,6 +155,15 @@ public:
     using IRGraphVisitor::include;
     using IRGraphVisitor::visit;
 
+    void visit(const Call *op) override {
+        if (op->is_intrinsic(Call::strict_float)) {
+            ScopedValue<bool> bind(in_strict_float, true);
+            IRGraphVisitor::visit(op);
+        } else {
+            IRGraphVisitor::visit(op);
+        }
+    }
+
     void include(const Expr &e) override {
         // If it's not the sort of thing we want to extract as a let,
         // just use the generic visitor to increment use counts for
@@ -167,7 +178,9 @@ public:
         // Find this thing's number.
         auto iter = gvn.output_numbering.find(e);
         if (iter != gvn.output_numbering.end()) {
-            gvn.entries[iter->second]->use_count++;
+            auto &entry = gvn.entries[iter->second];
+            entry->use_count++;
+            entry->strict_float |= in_strict_float;
         } else {
             internal_error << "Expr not in shallow numbering: " << e << "\n";
         }
@@ -321,14 +334,14 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
     debug(4) << "Canonical form without lets " << e << "\n";
 
     // Figure out which ones we'll pull out as lets and variables.
-    vector<pair<string, Expr>> lets;
+    vector<std::tuple<string, Expr, bool>> lets;
     vector<Expr> new_version(gvn.entries.size());
     map<Expr, Expr, ExprCompare> replacements;
     for (size_t i = 0; i < gvn.entries.size(); i++) {
         const auto &e = gvn.entries[i];
         if (e->use_count > 1) {
             string name = namer.make_unique_name();
-            lets.emplace_back(name, e->expr);
+            lets.emplace_back(name, e->expr, e->strict_float);
             // Point references to this expr to the variable instead.
             replacements[e->expr] = Variable::make(e->expr.type(), name);
         }
@@ -342,11 +355,15 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
     debug(4) << "With variables " << e << "\n";
 
     // Wrap the final expr in the lets.
-    for (const auto &[var, value] : reverse_view(lets)) {
+    for (const auto &[var, value, expr_strict_float] : reverse_view(lets)) {
         // Drop this variable as an acceptable replacement for this expr.
         replacer.erase(value);
         // Use containing lets in the value.
-        e = Let::make(var, replacer.mutate(value), e);
+        if (expr_strict_float) {
+            e = Let::make(var, strict_float(replacer.mutate(value)), e);
+        } else {
+            e = Let::make(var, replacer.mutate(value), e);
+        }
     }

The catch here is that once a common subexpression appears once in a strict-float context, all other occurrences of that subexpression will use the lifted strict_float-version of that Expr. I don't know if that's acceptable.

@mcourteaux
Copy link
Contributor Author

mcourteaux commented Feb 6, 2025

Despite fixes and enabling strict float, I still have an issue. The final lowered stmt is:

   gpu_thread<CUDA> (.thread_id_x, 0, 256) {
    let sin_approx$4.s0.i.v11.base.s = min(sin_approx$4.s0.i.v10.block_id_x*256, sin_approx$4.extent.0 + -256)
    let t118 = (float32)strict_float((float32)strict_float(float32(((sin_approx$4.min.0 + sin_approx$4.s0.i.v11.base.s) + .thread_id_x)))/(float32)strict_float(1048576.000000f))
    let t119 = (float32)strict_float((float32)strict_float((float32)strict_float(-1.046150f)*(float32)strict_float((float32)strict_float(1.000000f) - t118)) + (float32)strict_float((float32)strict_float(1.046150f)*t118))
    let t120 = (float32)strict_float((float32)floor_f32(t119*0.636620f))
    let t121 = strict_float(int32(t120) % 4)
    let t122 = (float32)strict_float(t119 - (t120*1.570796f))
    let t123 = (float32)strict_float(select((t121 == 1) || (t121 == 3), 1.570796f - t122, t122))
    let t124 = (float32)strict_float(t123*t123)
    let t125 = (float32)strict_float(t123*((t124*((t124*((t124*-0.000185f) + 0.008312f)) + -0.166655f)) + 0.999999f))
    sin_approx$4[.thread_id_x + sin_approx$4.s0.i.v11.base.s] = (float32)strict_float(select(1 < t121, 0.000000f - t125, t125))
   }

Which produces this LLVM IR:

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define void @_kernel_sin_approx__4_s0_i_v10_block_id_x(ptr noalias nocapture writeonly %"sin_approx$4", i32 %"sin_approx$4.extent.0", i32 %"sin_approx$4.min.0") local_unnamed_addr #1 {
entry:
  %"sin_approx$42" = addrspacecast ptr %"sin_approx$4" to ptr addrspace(1)
  %"sin_approx$4.s0.i.v10.block_id_x" = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
  %.thread_id_x = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
  %a15 = shl nsw i32 %"sin_approx$4.s0.i.v10.block_id_x", 8
  %b17 = add nsw i32 %"sin_approx$4.extent.0", -256
  %"sin_approx$4.s0.i.v11.base.s" = tail call i32 @llvm.smin.i32(i32 %a15, i32 %b17)
  %0 = add i32 %"sin_approx$4.s0.i.v11.base.s", %.thread_id_x
  %1 = add i32 %0, %"sin_approx$4.min.0"
  %2 = sitofp i32 %1 to float
  %t118 = fmul float %2, 0x3EB0000000000000
  %3 = fsub float 1.000000e+00, %t118
  %4 = fmul float %3, 0x3FF0BD0820000000
  %5 = fmul float %t118, 0x3FF0BD0820000000
  %6 = fsub float %5, %4
  %7 = fmul float %6, 0x3FE45F3060000000
  %8 = tail call float @llvm.nvvm.floor.ftz.f(float %7)
  %9 = fptosi float %8 to i32
  %t121 = and i32 %9, 2
  %10 = fmul float %8, 0x3FF921FB60000000
  %t122 = fsub float %6, %10
  %11 = and i32 %9, 1
  %.not = icmp eq i32 %11, 0
  %12 = fsub float 0x3FF921FB60000000, %t122
  %t123 = select i1 %.not, float %t122, float %12
  %t124 = fmul float %t123, %t123
  %13 = fmul float %t124, 0x3F28395C00000000
  %14 = fsub float 0x3F8105A920000000, %13
  %15 = fmul float %t124, %14
  %16 = fadd float %15, 0xBFC554F4A0000000
  %17 = fmul float %t124, %16
  %18 = fadd float %17, 0x3FEFFFFE00000000
  %t125 = fmul float %t123, %18
  %.not1 = icmp eq i32 %t121, 0
  %19 = fsub float 0.000000e+00, %t125
  %20 = select i1 %.not1, float %t125, float %19
  %21 = sext i32 %0 to i64
  %22 = getelementptr inbounds float, ptr addrspace(1) %"sin_approx$42", i64 %21
  store float %20, ptr addrspace(1) %22, align 4, !tbaa !6
  ret void
}

; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.smin.i32(i32 %0, i32 %1) #2

attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) "nvptx-f32ftz"="true" }
attributes #1 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write) "nvptx-f32ftz"="true" "reciprocal-estimates"="none" "target-cpu" "target-features" "tune-cpu" }
attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }

!nvvm.internalize.after.link = !{}
!nvvm.annotations = !{!0, !1, !0, !2, !2, !2, !2, !3, !3, !2, !4}
!llvm.module.flags = !{!5}

!0 = !{null, !"align", i32 8}
!1 = !{null, !"align", i32 8, !"align", i32 65544, !"align", i32 131080}
!2 = !{null, !"align", i32 16}
!3 = !{null, !"align", i32 16, !"align", i32 65552, !"align", i32 131088}
!4 = !{ptr @_kernel_sin_approx__4_s0_i_v10_block_id_x, !"kernel", i32 1}
!5 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!6 = !{!7, !7, i64 0}
!7 = !{!"sin_approx$4", !8, i64 0}
!8 = !{!"Halide buffer"}

And a PTX kernel that looks like this:

//
// Generated by LLVM NVPTX Back-End
//

.version 7.0
.target sm_75
.address_size 64

	// .globl	_kernel_sin_approx__4_s0_i_v10_block_id_x // -- Begin function _kernel_sin_approx__4_s0_i_v10_block_id_x
                                        // @_kernel_sin_approx__4_s0_i_v10_block_id_x
.visible .entry _kernel_sin_approx__4_s0_i_v10_block_id_x(
	.param .u64 _kernel_sin_approx__4_s0_i_v10_block_id_x_param_0,
	.param .u32 _kernel_sin_approx__4_s0_i_v10_block_id_x_param_1,
	.param .u32 _kernel_sin_approx__4_s0_i_v10_block_id_x_param_2
)
{
	.reg .pred 	%p<3>;
	.reg .b32 	%r<13>;
	.reg .f32 	%f<21>;
	.reg .b64 	%rd<5>;

// %bb.0:                               // %entry
	ld.param.u64 	%rd1, [_kernel_sin_approx__4_s0_i_v10_block_id_x_param_0];
	cvta.to.global.u64 	%rd2, %rd1;
	ld.param.u32 	%r1, [_kernel_sin_approx__4_s0_i_v10_block_id_x_param_1];
	ld.param.u32 	%r2, [_kernel_sin_approx__4_s0_i_v10_block_id_x_param_2];
	mov.u32 	%r3, %ctaid.x;
	mov.u32 	%r4, %tid.x;
	shl.b32 	%r5, %r3, 8;
	add.s32 	%r6, %r1, -256;
	min.s32 	%r7, %r5, %r6;
	add.s32 	%r8, %r7, %r4;
	add.s32 	%r9, %r8, %r2;
	cvt.rn.f32.s32 	%f1, %r9;
	mul.f32 	%f2, %f1, 0f35800000;
	fma.rn.f32 	%f3, %f1, 0fB5800000, 0f3F800000;
	mul.f32 	%f4, %f3, 0f3F85E841;
	neg.f32 	%f5, %f4;
	fma.rn.f32 	%f6, %f2, 0f3F85E841, %f5;
	mul.f32 	%f7, %f6, 0f3F22F983;
	cvt.rmi.ftz.f32.f32 	%f8, %f7;
	cvt.rzi.s32.f32 	%r10, %f8;
	and.b32  	%r11, %r10, 2;
	fma.rn.f32 	%f9, %f8, 0fBFC90FDB, %f6;
	and.b32  	%r12, %r10, 1;
	setp.eq.b32 	%p1, %r12, 1;
	mov.f32 	%f10, 0f3FC90FDB;
	sub.f32 	%f11, %f10, %f9;
	selp.f32 	%f12, %f11, %f9, %p1;
	mul.f32 	%f13, %f12, %f12;
	fma.rn.f32 	%f14, %f13, 0fB941CAE0, 0f3C082D49;
	fma.rn.f32 	%f15, %f13, %f14, 0fBE2AA7A5;
	fma.rn.f32 	%f16, %f13, %f15, 0f3F7FFFF0;
	mul.f32 	%f17, %f12, %f16;
	setp.eq.s32 	%p2, %r11, 0;
	neg.f32 	%f18, %f12;
	fma.rn.f32 	%f19, %f18, %f16, 0f00000000;
	selp.f32 	%f20, %f17, %f19, %p2;
	mul.wide.s32 	%rd3, %r8, 4;
	add.s64 	%rd4, %rd2, %rd3;
	st.global.f32 	[%rd4], %f20;
	ret;
                                        // -- End function
}

I'm still not getting matching accuracy compared to the CPU target. Am I missing some flags or details...?

I notice the use of the fma (fused multiply add), but that should be more precise than a sequence of multiply and add. Correct? Polynomials were optimized using double precision, so I expect the FMA to improve the result, not degrade.

@mcourteaux
Copy link
Contributor Author

I have one more idea to try. If that fails, then my GPU is not IEEE compliant or something... I seemingly get exactly the same results on Vulkan as I get on CUDA. Will report when I had the chance to try that.

@mcourteaux
Copy link
Contributor Author

Update: I had wrapped the arguments to my transcendentals in strict_float(), which seems to not work.

float lower = -10.0f, upper = 10.0f;
int steps = 10000;
Expr t = i / float(steps);
Expr arg = strict_float(lower * (1.0f - t) + upper * t);
Expr result = fast_sin(arg);

This was intended to avoid varying arguments on different backends, due to differences in optimization. While I think my code is correct, and the idea is correct, this failed. I do get different results on CUDA than on CPU. Precalculating the inputs compute_root() on CPU, and then copying the buffer over to the GPU to test the transcendentals works great.

Either way, this on its own seems to be a strict_float-related issue, which I didn't debug yet. I'm now back to working on the PR for transcendentals.

@mcourteaux
Copy link
Contributor Author

It seems that the StrictifyFloat pass just slams a strict_float() wrapper on any Expr and sub-Expr until the very leaves of the IR tree. As such all common sub-expressions should have the strict_float() on them, and all of their sub-expressions have strict_float(), which makes lifting them out not a problem in general, as the strict_float() was part of what was identified as being "common".

The issue I first noticed (original post on top of this issue page) arises from the fact that I generate new Exprs during one of the lowering passes.

Overall, I think the way StrictifyFloat litters strict_float() on all float Exprs might not be such a nice way to do things. Anything wrapped in strict_float() should mean that all of its sub-expressions are strict float too.

To summarize, I think we still have a bug when we do this:

Var x;
Expr t = x / 10304.0f;
Expr t2 = t * t;
Expr t3 = t2 * t;
Func f;
f(x) = strict_float(t3 + t2);

I would expect the whole func to be computed as under strict float conditions; and not just the final addition. Right now CSE will lift t2 and t3 into Lets which get converted to LetStmts, and they are no longer strict float. As I understand it, strict_float() now is not recursively applied, but just on the direct child Expr. Right now the documentation reads:

/** Makes a best effort attempt to preserve IEEE floating-point
 * semantics in evaluating an expression. May not be implemented for
 * all backends. (E.g. it is difficult to do this for C++ code
 * generation as it depends on the compiler flags used to compile the
 * generated code. */

Which suggests my code example above should indeed be fully compiled under strict-float conditions. Perhaps we need a relax_float() counterpart to stop the recursive impact from propagating further down the IR tree.

What I'm advocating for I think, is that StrictifyFloat should simply put one strict_float() wrapper on every Func's definition Expr in the environment during lowering, instead of littering it everywhere.

@abadams
Copy link
Member

abadams commented Feb 6, 2025

Yeah, strict_float is problematic to keep around in the IR. I think instead of strict_float, we should have pure math intrinsics for strict floating point multiplication, addition, subtract, division, and anything else that rounds. The simplifier will not have rules for these intrinsics, so it won't mess with them. Instead of being an intrinsic, strict_float will be a function that takes an Expr and mutates it, replacing all rounding floating point ops with the equivalent strict math intrinsics.

@mcourteaux
Copy link
Contributor Author

  • Running a result = common_subexpression_elimination(result) at the end of my approximation function does effectively start the CSE without the context. Even if someone fixed CSE to be aware of the strict_float() context from which it is currently extracting Exprs (like I did below), that still wouldn't work if we randomly run it during lowering passes after StrictifyFloat has done its work.

I think there is an important nuance to this. What I described would only be problematic when the Let expressions get converted to LetStmts. A Let expression can be wrapped in a strict_float(), as it's an expression. So when one writes a strict_float(fast_sin(...)), the whole thing can be rewritten with a CSE pass using Let expressions, which as a whole is still wrapped inside the strict_float().

This wouldn't be true for LetStmts. However, the thing is that running CSE on an Expr (like at the end of my polynomial approximation functions) will never produce LetStmts, but only Lets.

@alexreinking alexreinking removed the bug label Feb 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

3 participants