CSE lifts strict_float-wrapped operations out of their strict_float wrappers. #8563
I debugged this a bit further, and learned a few things:
Here is a quick hack for fixing CSE and strict floats:

diff --git a/src/CSE.cpp b/src/CSE.cpp
index df055c4bd..e5acbaa56 100644
--- a/src/CSE.cpp
+++ b/src/CSE.cpp
@@ -80,6 +80,7 @@ class GVN : public IRMutator {
public:
struct Entry {
Expr expr;
+ bool strict_float = false;
int use_count = 0;
// All consumer Exprs for which this is the last child Expr.
map<Expr, int, IRGraphDeepCompare> uses;
@@ -144,6 +145,7 @@ public:
class ComputeUseCounts : public IRGraphVisitor {
GVN &gvn;
bool lift_all;
+ bool in_strict_float{false};
public:
ComputeUseCounts(GVN &g, bool l)
@@ -153,6 +155,15 @@ public:
using IRGraphVisitor::include;
using IRGraphVisitor::visit;
+ void visit(const Call *op) override {
+ if (op->is_intrinsic(Call::strict_float)) {
+ ScopedValue<bool> bind(in_strict_float, true);
+ IRGraphVisitor::visit(op);
+ } else {
+ IRGraphVisitor::visit(op);
+ }
+ }
+
void include(const Expr &e) override {
// If it's not the sort of thing we want to extract as a let,
// just use the generic visitor to increment use counts for
@@ -167,7 +178,9 @@ public:
// Find this thing's number.
auto iter = gvn.output_numbering.find(e);
if (iter != gvn.output_numbering.end()) {
- gvn.entries[iter->second]->use_count++;
+ auto &entry = gvn.entries[iter->second];
+ entry->use_count++;
+ entry->strict_float |= in_strict_float;
} else {
internal_error << "Expr not in shallow numbering: " << e << "\n";
}
@@ -321,14 +334,14 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
debug(4) << "Canonical form without lets " << e << "\n";
// Figure out which ones we'll pull out as lets and variables.
- vector<pair<string, Expr>> lets;
+ vector<std::tuple<string, Expr, bool>> lets;
vector<Expr> new_version(gvn.entries.size());
map<Expr, Expr, ExprCompare> replacements;
for (size_t i = 0; i < gvn.entries.size(); i++) {
const auto &e = gvn.entries[i];
if (e->use_count > 1) {
string name = namer.make_unique_name();
- lets.emplace_back(name, e->expr);
+ lets.emplace_back(name, e->expr, e->strict_float);
// Point references to this expr to the variable instead.
replacements[e->expr] = Variable::make(e->expr.type(), name);
}
@@ -342,11 +355,15 @@ Expr common_subexpression_elimination(const Expr &e_in, bool lift_all) {
debug(4) << "With variables " << e << "\n";
// Wrap the final expr in the lets.
- for (const auto &[var, value] : reverse_view(lets)) {
+ for (const auto &[var, value, expr_strict_float] : reverse_view(lets)) {
// Drop this variable as an acceptable replacement for this expr.
replacer.erase(value);
// Use containing lets in the value.
- e = Let::make(var, replacer.mutate(value), e);
+ if (expr_strict_float) {
+ e = Let::make(var, strict_float(replacer.mutate(value)), e);
+ } else {
+ e = Let::make(var, replacer.mutate(value), e);
+ }
}

The catch here is that once a common subexpression appears even once in a strict-float context, all other occurrences of that subexpression will use the lifted strict_float-wrapped let, so uses that were not strict become strict as well.
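A hypothetical Halide example of that caveat (the names and constants are mine, not from the patch): t2 below has one strict and one non-strict use, so with the patch above the single lifted let would be strict_float-wrapped for both uses.

#include "Halide.h"

using namespace Halide;

// t2 is a common subexpression: used once inside strict_float and once
// outside it. The patch marks its GVN entry strict, so the let that CSE
// emits is strict_float-wrapped for *both* uses.
Func over_strict_example() {
    Var x;
    Expr t = cast<float>(x) / 3.0f;
    Expr t2 = t * t;
    Func f;
    f(x) = strict_float(t2 + 1.0f) + t2;
    return f;
}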
Despite fixes and enabling strict float, I still have an issue. The final lowered stmt is:

gpu_thread<CUDA> (.thread_id_x, 0, 256) {
let sin_approx$4.s0.i.v11.base.s = min(sin_approx$4.s0.i.v10.block_id_x*256, sin_approx$4.extent.0 + -256)
let t118 = (float32)strict_float((float32)strict_float(float32(((sin_approx$4.min.0 + sin_approx$4.s0.i.v11.base.s) + .thread_id_x)))/(float32)strict_float(1048576.000000f))
let t119 = (float32)strict_float((float32)strict_float((float32)strict_float(-1.046150f)*(float32)strict_float((float32)strict_float(1.000000f) - t118)) + (float32)strict_float((float32)strict_float(1.046150f)*t118))
let t120 = (float32)strict_float((float32)floor_f32(t119*0.636620f))
let t121 = strict_float(int32(t120) % 4)
let t122 = (float32)strict_float(t119 - (t120*1.570796f))
let t123 = (float32)strict_float(select((t121 == 1) || (t121 == 3), 1.570796f - t122, t122))
let t124 = (float32)strict_float(t123*t123)
let t125 = (float32)strict_float(t123*((t124*((t124*((t124*-0.000185f) + 0.008312f)) + -0.166655f)) + 0.999999f))
sin_approx$4[.thread_id_x + sin_approx$4.s0.i.v11.base.s] = (float32)strict_float(select(1 < t121, 0.000000f - t125, t125))
}

Which produces this LLVM IR:

; Function Attrs: mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write)
define void @_kernel_sin_approx__4_s0_i_v10_block_id_x(ptr noalias nocapture writeonly %"sin_approx$4", i32 %"sin_approx$4.extent.0", i32 %"sin_approx$4.min.0") local_unnamed_addr #1 {
entry:
%"sin_approx$42" = addrspacecast ptr %"sin_approx$4" to ptr addrspace(1)
%"sin_approx$4.s0.i.v10.block_id_x" = tail call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
%.thread_id_x = tail call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
%a15 = shl nsw i32 %"sin_approx$4.s0.i.v10.block_id_x", 8
%b17 = add nsw i32 %"sin_approx$4.extent.0", -256
%"sin_approx$4.s0.i.v11.base.s" = tail call i32 @llvm.smin.i32(i32 %a15, i32 %b17)
%0 = add i32 %"sin_approx$4.s0.i.v11.base.s", %.thread_id_x
%1 = add i32 %0, %"sin_approx$4.min.0"
%2 = sitofp i32 %1 to float
%t118 = fmul float %2, 0x3EB0000000000000
%3 = fsub float 1.000000e+00, %t118
%4 = fmul float %3, 0x3FF0BD0820000000
%5 = fmul float %t118, 0x3FF0BD0820000000
%6 = fsub float %5, %4
%7 = fmul float %6, 0x3FE45F3060000000
%8 = tail call float @llvm.nvvm.floor.ftz.f(float %7)
%9 = fptosi float %8 to i32
%t121 = and i32 %9, 2
%10 = fmul float %8, 0x3FF921FB60000000
%t122 = fsub float %6, %10
%11 = and i32 %9, 1
%.not = icmp eq i32 %11, 0
%12 = fsub float 0x3FF921FB60000000, %t122
%t123 = select i1 %.not, float %t122, float %12
%t124 = fmul float %t123, %t123
%13 = fmul float %t124, 0x3F28395C00000000
%14 = fsub float 0x3F8105A920000000, %13
%15 = fmul float %t124, %14
%16 = fadd float %15, 0xBFC554F4A0000000
%17 = fmul float %t124, %16
%18 = fadd float %17, 0x3FEFFFFE00000000
%t125 = fmul float %t123, %18
%.not1 = icmp eq i32 %t121, 0
%19 = fsub float 0.000000e+00, %t125
%20 = select i1 %.not1, float %t125, float %19
%21 = sext i32 %0 to i64
%22 = getelementptr inbounds float, ptr addrspace(1) %"sin_approx$42", i64 %21
store float %20, ptr addrspace(1) %22, align 4, !tbaa !6
ret void
}
; Function Attrs: nocallback nofree nosync nounwind speculatable willreturn memory(none)
declare i32 @llvm.smin.i32(i32 %0, i32 %1) #2
attributes #0 = { mustprogress nocallback nofree nosync nounwind speculatable willreturn memory(none) "nvptx-f32ftz"="true" }
attributes #1 = { mustprogress nofree norecurse nosync nounwind willreturn memory(argmem: write) "nvptx-f32ftz"="true" "reciprocal-estimates"="none" "target-cpu" "target-features" "tune-cpu" }
attributes #2 = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
!nvvm.internalize.after.link = !{}
!nvvm.annotations = !{!0, !1, !0, !2, !2, !2, !2, !3, !3, !2, !4}
!llvm.module.flags = !{!5}
!0 = !{null, !"align", i32 8}
!1 = !{null, !"align", i32 8, !"align", i32 65544, !"align", i32 131080}
!2 = !{null, !"align", i32 16}
!3 = !{null, !"align", i32 16, !"align", i32 65552, !"align", i32 131088}
!4 = !{ptr @_kernel_sin_approx__4_s0_i_v10_block_id_x, !"kernel", i32 1}
!5 = !{i32 4, !"nvvm-reflect-ftz", i32 1}
!6 = !{!7, !7, i64 0}
!7 = !{!"sin_approx$4", !8, i64 0}
!8 = !{!"Halide buffer"} And a PTX kernel that looks like this:
I'm still not getting accuracy that matches the CPU target. Am I missing some flags or details? I notice the use of fma (fused multiply-add), but that should be more precise than a separate multiply and add, correct? The polynomials were optimized using double precision, so I expect the FMA to improve the result, not degrade it.
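That expectation is right in isolation: an FMA rounds once where a separate multiply and add round twice, so it is the more accurate of the two; but it also produces a different result, which by itself breaks bit-exact matching with a backend that doesn't contract. A minimal C++ sketch of the single-rounding difference (assuming FLT_EVAL_METHOD == 0, so float expressions aren't silently evaluated in double):

#include <cmath>
#include <cstdio>

int main() {
    // (1 + 2^-23) * (1 - 2^-23) = 1 - 2^-46 exactly.
    float a = 1.0f + 0x1p-23f;
    float b = 1.0f - 0x1p-23f;
    // Separate multiply: the product rounds to 1.0f, and the subsequent
    // subtraction cancels it to zero -- all information is lost.
    float two_step = a * b - 1.0f;
    // Fused multiply-add rounds only once, so it returns -2^-46 exactly.
    float fused = std::fma(a, b, -1.0f);
    std::printf("two_step = %g, fused = %g\n", two_step, fused);
}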
I have one more idea to try. If that fails, then my GPU is not IEEE compliant or something... I seemingly get exactly the same results on Vulkan as I get on CUDA. Will report when I've had the chance to try it.
Update: I had wrapped the arguments to my transcendentals in strict_float:

float lower = -10.0f, upper = 10.0f;
int steps = 10000;
Expr t = i / float(steps);
Expr arg = strict_float(lower * (1.0f - t) + upper * t);
Expr result = fast_sin(arg);

This was intended to avoid the arguments varying across backends due to differences in optimization. While I think my code is correct, and the idea is correct, this failed: I do get different results on CUDA than on CPU. Precalculating the inputs … Either way, this on its own seems to be a strict_float-related issue, which I haven't debugged yet. I'm now back to working on the PR for transcendentals.
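For reference, a minimal sketch of the "precalculating the inputs" idea under my own assumptions (a 1-D pipeline over Halide::fast_sin; none of these names are from the original post): compute the arguments on the host, so every backend consumes bit-identical inputs.

#include "Halide.h"

int main() {
    const float lower = -10.0f, upper = 10.0f;
    const int steps = 10000;

    // Evaluate the arguments once, on the host.
    Halide::Buffer<float> args(steps);
    for (int i = 0; i < steps; i++) {
        float t = i / float(steps);
        args(i) = lower * (1.0f - t) + upper * t;
    }

    // Every backend now sees the same buffer of bits as input.
    Halide::Var i;
    Halide::Func f;
    f(i) = Halide::fast_sin(args(i));
    Halide::Buffer<float> out = f.realize({steps});
}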
It seems that the …

The issue I first noticed (original post at the top of this issue page) arises from the fact that I generate new …

Overall, I think the way …

To summarize, I think we still have a bug when we do this:

Var x;
Expr t = x / 10304.0f;
Expr t2 = t * t;
Expr t3 = t2 * t;
Func f;
f(x) = strict_float(t3 + t2);

I would expect the whole Func to be computed under strict-float conditions, not just the final addition. Right now CSE will lift t2 and t3 out of the strict_float as ordinary lets. The documentation for strict_float says:

/** Makes a best effort attempt to preserve IEEE floating-point
 * semantics in evaluating an expression. May not be implemented for
 * all backends. (E.g. it is difficult to do this for C++ code
 * generation as it depends on the compiler flags used to compile the
 * generated code.) */

which suggests my code example above should indeed be fully compiled under strict-float conditions. Perhaps we need a …

What I'm advocating, I think, is that …
Yeah, strict_float is problematic to keep around in the IR. I think instead of strict_float, we should have pure math intrinsics for strict floating-point multiplication, addition, subtraction, division, and anything else that rounds. The simplifier will not have rules for these intrinsics, so it won't mess with them. Instead of being an intrinsic, strict_float will be a function that takes an Expr and mutates it, replacing all rounding floating-point ops with the equivalent strict math intrinsics.
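A minimal sketch of what that could look like, with hypothetical intrinsic names strict_add / strict_mul and a freestanding make_strict_float helper standing in for whatever would actually be added (none of this exists in Halide today):

#include "Halide.h"

using namespace Halide;
using namespace Halide::Internal;

class MakeStrict : public IRMutator {
    using IRMutator::visit;

    Expr visit(const Add *op) override {
        Expr a = mutate(op->a), b = mutate(op->b);
        if (op->type.is_float()) {
            // A pure intrinsic the simplifier has no rules for, so it
            // can't reassociate or otherwise rewrite the addition.
            return Call::make(op->type, "strict_add", {a, b}, Call::PureIntrinsic);
        }
        return Add::make(a, b);
    }

    Expr visit(const Mul *op) override {
        Expr a = mutate(op->a), b = mutate(op->b);
        if (op->type.is_float()) {
            return Call::make(op->type, "strict_mul", {a, b}, Call::PureIntrinsic);
        }
        return Mul::make(a, b);
    }

    // Sub, Div, and any other rounding ops would be handled the same way.
};

// strict_float then becomes a plain Expr -> Expr function rather than
// something that survives in the IR for CSE to mishandle.
Expr make_strict_float(const Expr &e) {
    return MakeStrict().mutate(e);
}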
I think there is an important nuance to this. What I described would only be problematic when the …

This wouldn't be true for …
Debugging precision errors in my transcendentals revealed this behavior during lowering:
This causes LLVM IR to be produced that looks like this (a whole lot further down):
Notice the reassoc, contract and afn flags, which are basically the "fast-math" flags: https://llvm.org/docs/LangRef.html#rewrite-based-flags

Also notice that the issue is introduced after the CUDA-specific "Injecting warp shuffles..." pass.
For reference, final simplification before device codegen:
Offending code: Halide/src/CSE.cpp, lines 317 to 344 at b937537.
I'm not seeing this issue on x86_64...