Skip to content

Commit

Permalink
[naga spv-out] Ensure loops generated by SPIRV backend are bounded
Browse files Browse the repository at this point in the history
If it is undefined behaviour for loops to be infinite, then, when
encountering an infinite loop, downstream compilers are able to make
certain optimizations that may be unsafe. For example, omitting bounds
checks. To prevent this, we must ensure that any loops emitted by our
backends are provably bounded. We already do this for both the MSL and
HLSL backends. This patch makes us do so for SPIR-V as well.

The construct used is the same as for HLSL and MSL backends: use a
vec2<u32> to emulate a 64-bit counter, which is incremented every
iteration and breaks after 2^64 iterations.

While the implementation is fairly verbose for the SPIR-V backend, the
logic is simple enough. The one point of note is that SPIR-V requires
`OpVariable` instructions with a `Function` storage class to be
located at the start of the first block of the function. We therefore
remember the IDs generated for each loop counter variable in a
function whilst generating the function body's code. The instructions
to declare these variables are then emitted in `Function::to_words()`
prior to emitting the function's body.

As this may negatively impact shader performance, this workaround can
be disabled using the same mechanism as for other backends: e.g. calling
Device::create_shader_module_trusted() and setting the
ShaderRuntimeChecks::force_loop_bounding flag to false.
  • Loading branch information
jamienicol committed Feb 25, 2025
1 parent e0f0185 commit 67c18e8
Show file tree
Hide file tree
Showing 20 changed files with 2,989 additions and 2,269 deletions.
11 changes: 11 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,17 @@ pub enum PollError {
By @cwfitzgerald in [#6942](https://github.com/gfx-rs/wgpu/pull/6942).
By @cwfitzgerald in [#7030](https://github.com/gfx-rs/wgpu/pull/7030).
#### Naga
##### Ensure loops generated by SPIR-V and HLSL Naga backends are bounded
Make sure that all loops in shaders generated by these naga backends are bounded
to avoid undefined behaviour due to infinite loops. Note that this may have a
performance cost. As with the existing implementation for the MSL backend, this
can be disabled by using `Device::create_shader_module_trusted()`.
By @jamienicol in [#6929](https://github.com/gfx-rs/wgpu/pull/6929) and [#7080](https://github.com/gfx-rs/wgpu/pull/7080).
### New Features
#### General
Expand Down
153 changes: 153 additions & 0 deletions naga/src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,155 @@ impl Writer {
}

impl BlockContext<'_> {
/// Generates code to ensure that a loop is bounded. Should be called immediately
/// after adding the OpLoopMerge instruction to `block`. This function will
/// [`consume()`](crate::back::spv::Function::consume) `block` and append its
/// instructions to a new [`Block`], which will be returned to the caller for it to
/// consume prior to writing the loop body.
///
/// Additionally this function will populate [`force_loop_bounding_vars`](crate::back::spv::Function::force_loop_bounding_vars),
/// ensuring that [`Function::to_words()`](crate::back::spv::Function::to_words) will
/// declare the required variables.
///
/// See [`crate::back::msl::Writer::gen_force_bounded_loop_statements`] for details
/// of why this is required.
fn write_force_bounded_loop_instructions(&mut self, mut block: Block, merge_id: Word) -> Block {
    // Intern the type and constant IDs the generated code needs: u32,
    // vec2<u32>, a Function-storage pointer to vec2<u32>, bool, vec2<bool>,
    // and the constants 0u, vec2(0u, 0u), 1u, u32::MAX and vec2(u32::MAX).
    let uint_type_id = self.writer.get_uint_type_id();
    let uint2_type_id = self.writer.get_uint2_type_id();
    let uint2_ptr_type_id = self
        .writer
        .get_uint2_pointer_type_id(spirv::StorageClass::Function);
    let bool_type_id = self.writer.get_bool_type_id();
    let bool2_type_id = self.writer.get_bool2_type_id();
    let zero_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(0));
    let zero_uint2_const_id = self.writer.get_constant_composite(
        LookupType::Local(LocalType::Numeric(NumericType::Vector {
            size: crate::VectorSize::Bi,
            scalar: crate::Scalar::U32,
        })),
        &[zero_uint_const_id, zero_uint_const_id],
    );
    let one_uint_const_id = self.writer.get_constant_scalar(crate::Literal::U32(1));
    let max_uint_const_id = self
        .writer
        .get_constant_scalar(crate::Literal::U32(u32::MAX));
    let max_uint2_const_id = self.writer.get_constant_composite(
        LookupType::Local(LocalType::Numeric(NumericType::Vector {
            size: crate::VectorSize::Bi,
            scalar: crate::Scalar::U32,
        })),
        &[max_uint_const_id, max_uint_const_id],
    );

    // The counter variable itself, zero-initialized. Its OpVariable
    // instruction cannot be emitted here: SPIR-V requires Function-storage
    // OpVariables to appear at the start of the function's first block, so
    // we record it in `force_loop_bounding_vars` for `Function::to_words()`
    // to declare later.
    let loop_counter_var_id = self.gen_id();
    if self.writer.flags.contains(WriterFlags::DEBUG) {
        self.writer
            .debugs
            .push(Instruction::name(loop_counter_var_id, "loop_bound"));
    }
    let var = super::LocalVariable {
        id: loop_counter_var_id,
        instruction: Instruction::variable(
            uint2_ptr_type_id,
            loop_counter_var_id,
            spirv::StorageClass::Function,
            Some(zero_uint2_const_id),
        ),
    };
    self.function.force_loop_bounding_vars.push(var);

    // Terminate the caller's block (which ends with OpLoopMerge) with a
    // branch into the block that performs the bound check.
    let break_if_block = self.gen_id();

    self.function
        .consume(block, Instruction::branch(break_if_block));
    block = Block::new(break_if_block);

    // Load the current loop counter value from its variable. We use a vec2<u32> to
    // simulate a 64-bit counter.
    let load_id = self.gen_id();
    block.body.push(Instruction::load(
        uint2_type_id,
        load_id,
        loop_counter_var_id,
        None,
    ));

    // If both the high and low u32s have reached u32::MAX then break. ie
    // if (all(eq(loop_counter, vec2(u32::MAX)))) { break; }
    let eq_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool2_type_id,
        eq_id,
        max_uint2_const_id,
        load_id,
    ));
    let all_eq_id = self.gen_id();
    block.body.push(Instruction::relational(
        spirv::Op::All,
        bool_type_id,
        all_eq_id,
        eq_id,
    ));

    // Structured conditional: branch to the loop's merge block (i.e. break)
    // when the counter is saturated, otherwise fall through to the block
    // that increments the counter.
    let inc_counter_block_id = self.gen_id();
    block.body.push(Instruction::selection_merge(
        inc_counter_block_id,
        spirv::SelectionControl::empty(),
    ));
    self.function.consume(
        block,
        Instruction::branch_conditional(all_eq_id, merge_id, inc_counter_block_id),
    );
    block = Block::new(inc_counter_block_id);

    // To simulate a 64-bit counter we always increment the low u32, and increment
    // the high u32 when the low u32 overflows. ie
    // counter += vec2(select(0u, 1u, counter.y == u32::MAX), 1u);
    // Component 1 (y) holds the low word; component 0 (x) holds the high word.
    let low_id = self.gen_id();
    block.body.push(Instruction::composite_extract(
        uint_type_id,
        low_id,
        load_id,
        &[1],
    ));
    let low_overflow_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IEqual,
        bool_type_id,
        low_overflow_id,
        low_id,
        max_uint_const_id,
    ));
    // carry = low == u32::MAX ? 1u : 0u
    let carry_bit_id = self.gen_id();
    block.body.push(Instruction::select(
        uint_type_id,
        carry_bit_id,
        low_overflow_id,
        one_uint_const_id,
        zero_uint_const_id,
    ));
    // increment = vec2(carry, 1u); adding it bumps the low word every
    // iteration and the high word only on low-word wraparound.
    let increment_id = self.gen_id();
    block.body.push(Instruction::composite_construct(
        uint2_type_id,
        increment_id,
        &[carry_bit_id, one_uint_const_id],
    ));
    let result_id = self.gen_id();
    block.body.push(Instruction::binary(
        spirv::Op::IAdd,
        uint2_type_id,
        result_id,
        load_id,
        increment_id,
    ));
    // Store the incremented counter back for the next iteration.
    block
        .body
        .push(Instruction::store(loop_counter_var_id, result_id, None));

    // Return the open block; the caller consumes it when writing the loop body.
    block
}

/// Cache an expression for a value.
pub(super) fn cache_expression_value(
&mut self,
Expand Down Expand Up @@ -2558,6 +2707,10 @@ impl BlockContext<'_> {
continuing_id,
spirv::SelectionControl::NONE,
));

if self.force_loop_bounding {
block = self.write_force_bounded_loop_instructions(block, merge_id);
}
self.function.consume(block, Instruction::branch(body_id));

// We can ignore the `BlockExitDisposition` returned here because,
Expand Down
10 changes: 10 additions & 0 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,8 @@ struct Function {
signature: Option<Instruction>,
parameters: Vec<FunctionArgument>,
variables: crate::FastHashMap<Handle<crate::LocalVariable>, LocalVariable>,
/// List of local variables used as a counters to ensure that all loops are bounded.
force_loop_bounding_vars: Vec<LocalVariable>,

/// A map taking an expression that yields a composite value (array, matrix)
/// to the temporary variables we have spilled it to, if any. Spilling
Expand Down Expand Up @@ -726,6 +728,8 @@ struct BlockContext<'w> {

/// Tracks the constness of `Expression`s residing in `self.ir_function.expressions`
expression_constness: ExpressionConstnessTracker,

force_loop_bounding: bool,
}

impl BlockContext<'_> {
Expand Down Expand Up @@ -779,6 +783,7 @@ pub struct Writer {
flags: WriterFlags,
bounds_check_policies: BoundsCheckPolicies,
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,
force_loop_bounding: bool,
void_type: Word,
//TODO: convert most of these into vectors, addressable by handle indices
lookup_type: crate::FastHashMap<LookupType, Word>,
Expand Down Expand Up @@ -882,6 +887,10 @@ pub struct Options<'a> {
/// Dictates the way workgroup variables should be zero initialized
pub zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode,

/// If set, loops will have code injected into them, forcing the compiler
/// to think the number of iterations is bounded.
pub force_loop_bounding: bool,

pub debug_info: Option<DebugInfo<'a>>,
}

Expand All @@ -900,6 +909,7 @@ impl Default for Options<'_> {
capabilities: None,
bounds_check_policies: BoundsCheckPolicies::default(),
zero_initialize_workgroup_memory: ZeroInitializeWorkgroupMemoryMode::Polyfill,
force_loop_bounding: true,
debug_info: None,
}
}
Expand Down
33 changes: 33 additions & 0 deletions naga/src/back/spv/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ impl Function {
for local_var in self.variables.values() {
local_var.instruction.to_words(sink);
}
for local_var in self.force_loop_bounding_vars.iter() {
local_var.instruction.to_words(sink);
}
for internal_var in self.spilled_composites.values() {
internal_var.instruction.to_words(sink);
}
Expand Down Expand Up @@ -71,6 +74,7 @@ impl Writer {
flags: options.flags,
bounds_check_policies: options.bounds_check_policies,
zero_initialize_workgroup_memory: options.zero_initialize_workgroup_memory,
force_loop_bounding: options.force_loop_bounding,
void_type,
lookup_type: crate::FastHashMap::default(),
lookup_function: crate::FastHashMap::default(),
Expand Down Expand Up @@ -111,6 +115,7 @@ impl Writer {
flags: self.flags,
bounds_check_policies: self.bounds_check_policies,
zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory,
force_loop_bounding: self.force_loop_bounding,
capabilities_available: take(&mut self.capabilities_available),
binding_map: take(&mut self.binding_map),

Expand Down Expand Up @@ -267,6 +272,14 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the SPIR-V type ID for `vec2<u32>`, interning the type on first use.
pub(super) fn get_uint2_type_id(&mut self) -> Word {
    let vec2u = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::U32,
    };
    self.get_type_id(LocalType::Numeric(vec2u).into())
}

pub(super) fn get_uint3_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
Expand All @@ -283,6 +296,17 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the SPIR-V type ID for a pointer to `vec2<u32>` in the given
/// storage `class`, interning the type on first use.
pub(super) fn get_uint2_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
    let pointee = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::U32,
    };
    let ptr_type = LocalType::LocalPointer {
        base: pointee,
        class,
    };
    self.get_type_id(ptr_type.into())
}

pub(super) fn get_uint3_pointer_type_id(&mut self, class: spirv::StorageClass) -> Word {
let local_type = LocalType::LocalPointer {
base: NumericType::Vector {
Expand All @@ -299,6 +323,14 @@ impl Writer {
self.get_type_id(local_type.into())
}

/// Returns the SPIR-V type ID for `vec2<bool>`, interning the type on first use.
pub(super) fn get_bool2_type_id(&mut self) -> Word {
    let vec2b = NumericType::Vector {
        size: crate::VectorSize::Bi,
        scalar: crate::Scalar::BOOL,
    };
    self.get_type_id(LocalType::Numeric(vec2b).into())
}

pub(super) fn get_bool3_type_id(&mut self) -> Word {
let local_type = LocalType::Numeric(NumericType::Vector {
size: crate::VectorSize::Tri,
Expand Down Expand Up @@ -839,6 +871,7 @@ impl Writer {

// Steal the Writer's temp list for a bit.
temp_list: std::mem::take(&mut self.temp_list),
force_loop_bounding: self.force_loop_bounding,
writer: self,
expression_constness: super::ExpressionConstnessTracker::from_arena(
&ir_function.expressions,
Expand Down
50 changes: 37 additions & 13 deletions naga/tests/out/spv/6220-break-from-loop.spvasm
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
; SPIR-V
; Version: 1.1
; Generator: rspirv
; Bound: 26
; Bound: 46
OpCapability Shader
OpCapability Linkage
%1 = OpExtInstImport "GLSL.std.450"
Expand All @@ -13,31 +13,55 @@ OpMemoryModel Logical GLSL450
%8 = OpConstant %3 4
%9 = OpConstant %3 1
%11 = OpTypePointer Function %3
%18 = OpTypeBool
%17 = OpTypeInt 32 0
%18 = OpTypeVector %17 2
%19 = OpTypePointer Function %18
%20 = OpTypeBool
%21 = OpTypeVector %20 2
%22 = OpConstant %17 0
%23 = OpConstantComposite %18 %22 %22
%24 = OpConstant %17 1
%25 = OpConstant %17 4294967295
%26 = OpConstantComposite %18 %25 %25
%5 = OpFunction %2 None %6
%4 = OpLabel
%10 = OpVariable %11 Function %7
%27 = OpVariable %19 Function %23
OpBranch %12
%12 = OpLabel
OpBranch %13
%13 = OpLabel
OpLoopMerge %14 %16 None
OpBranch %28
%28 = OpLabel
%29 = OpLoad %18 %27
%30 = OpIEqual %21 %26 %29
%31 = OpAll %20 %30
OpSelectionMerge %32 None
OpBranchConditional %31 %14 %32
%32 = OpLabel
%33 = OpCompositeExtract %17 %29 1
%34 = OpIEqual %20 %33 %25
%35 = OpSelect %17 %34 %24 %22
%36 = OpCompositeConstruct %18 %35 %24
%37 = OpIAdd %18 %29 %36
OpStore %27 %37
OpBranch %15
%15 = OpLabel
%17 = OpLoad %3 %10
%19 = OpSLessThan %18 %17 %8
OpSelectionMerge %20 None
OpBranchConditional %19 %20 %21
%21 = OpLabel
%38 = OpLoad %3 %10
%39 = OpSLessThan %20 %38 %8
OpSelectionMerge %40 None
OpBranchConditional %39 %40 %41
%41 = OpLabel
OpBranch %14
%20 = OpLabel
OpBranch %22
%22 = OpLabel
%40 = OpLabel
OpBranch %42
%42 = OpLabel
OpBranch %14
%16 = OpLabel
%24 = OpLoad %3 %10
%25 = OpIAdd %3 %24 %9
OpStore %10 %25
%44 = OpLoad %3 %10
%45 = OpIAdd %3 %44 %9
OpStore %10 %45
OpBranch %13
%14 = OpLabel
OpReturn
Expand Down
Loading

0 comments on commit 67c18e8

Please sign in to comment.