Skip to content

Commit

Permalink
[naga wgsl-in] Do not eagerly concretize global const declarations of…
Browse files Browse the repository at this point in the history
… abstract types

Instead allow the const to be converted and each time it is const
evaluated as part of another expression. This allows an abstract const
to be used as a different type depending on the context.

A consequence of this is that abstract types may now find their way in
to the IR, which we don't want. We therefore additionally now ensure
that the compact pass removes global constants of abstract types. This
will have no *functional* effect on shaders generated by the backends,
as the expressions belonging to the abstract consts in the IR will not
actually be used, as any usage in the input shader will have been
const-evaluated away. Certain unused const declarations will now be
removed, however, as can be seen by the effect on the snapshot
outputs.
  • Loading branch information
jamienicol committed Feb 12, 2025
1 parent 02de75c commit efe6402
Show file tree
Hide file tree
Showing 19 changed files with 175 additions and 243 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924).
- Fix some instances of functions which have a return type but don't return a value being incorrectly validated. By @jamienicol in [#7013](https://github.com/gfx-rs/wgpu/pull/7013).
- Allow abstract expressions to be used in WGSL function return statements. By @jamienicol in [#7035](https://github.com/gfx-rs/wgpu/pull/7035).
- Error if structs have two fields with the same name. By @SparkyPotato in [#7088](https://github.com/gfx-rs/wgpu/pull/7088).
- Allow WGSL const declarations to have abstract types. By @jamienicol in [#7055](https://github.com/gfx-rs/wgpu/pull/7055).

#### General

Expand Down
17 changes: 14 additions & 3 deletions naga/src/compact/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,24 @@ pub fn compact(module: &mut crate::Module) {
log::trace!("tracing special types");
module_tracer.trace_special_types(&module.special_types);

// We treat all named constants as used by definition.
// We treat all named constants as used by definition, unless they have an
// abstract type as we do not want those reaching the IR.
log::trace!("tracing named constants");
for (handle, constant) in module.constants.iter() {
if constant.name.is_some() {
log::trace!("tracing constant {:?}", constant.name.as_ref().unwrap());
module_tracer.constants_used.insert(handle);
module_tracer.global_expressions_used.insert(constant.init);
let mut ty = constant.ty;
while let crate::TypeInner::Array { base, .. } = module.types[ty].inner {
ty = base;
}
if !module.types[ty]
.inner
.scalar()
.is_some_and(|s| s.is_abstract())
{
module_tracer.constants_used.insert(handle);
module_tracer.global_expressions_used.insert(constant.init);
}
}
}

Expand Down
28 changes: 19 additions & 9 deletions naga/src/front/wgsl/lower/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1072,8 +1072,13 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
v.ty.map(|ast| self.resolve_ast_type(ast, &mut ctx))
.transpose()?;

let (ty, initializer) =
self.type_and_init(v.name, v.init, explicit_ty, &mut ctx.as_override())?;
let (ty, initializer) = self.type_and_init(
v.name,
v.init,
explicit_ty,
false,
&mut ctx.as_override(),
)?;

let binding = if let Some(ref binding) = v.binding {
Some(crate::ResourceBinding {
Expand Down Expand Up @@ -1106,7 +1111,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.transpose()?;

let (ty, init) =
self.type_and_init(c.name, Some(c.init), explicit_ty, &mut ectx)?;
self.type_and_init(c.name, Some(c.init), explicit_ty, true, &mut ectx)?;
let init = init.expect("Global const must have init");

let handle = ctx.module.constants.append(
Expand All @@ -1128,7 +1133,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {

let mut ectx = ctx.as_override();

let (ty, init) = self.type_and_init(o.name, o.init, explicit_ty, &mut ectx)?;
let (ty, init) =
self.type_and_init(o.name, o.init, explicit_ty, false, &mut ectx)?;

let id =
o.id.map(|id| self.const_u32(id, &mut ctx.as_const()))
Expand Down Expand Up @@ -1201,6 +1207,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
name: ast::Ident<'source>,
init: Option<Handle<ast::Expression<'source>>>,
explicit_ty: Option<Handle<crate::Type>>,
allow_abstract: bool,
ectx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<(Handle<crate::Type>, Option<Handle<crate::Expression>>), Error<'source>> {
let ty;
Expand Down Expand Up @@ -1234,9 +1241,12 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
initializer = Some(init);
}
(Some(init), None) => {
let concretized = self.expression(init, ectx)?;
ty = ectx.register_type(concretized)?;
initializer = Some(concretized);
let mut init = self.expression_for_abstract(init, ectx)?;
if !allow_abstract {
init = ectx.concretize(init)?;
}
ty = ectx.register_type(init)?;
initializer = Some(init);
}
(None, Some(explicit_ty)) => {
ty = explicit_ty;
Expand Down Expand Up @@ -1477,7 +1487,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
emitter.start(&ctx.function.expressions);
let mut ectx = ctx.as_expression(block, &mut emitter);
let (ty, initializer) =
self.type_and_init(v.name, v.init, explicit_ty, &mut ectx)?;
self.type_and_init(v.name, v.init, explicit_ty, false, &mut ectx)?;

let (const_initializer, initializer) = {
match initializer {
Expand Down Expand Up @@ -1537,7 +1547,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.transpose()?;

let (_ty, init) =
self.type_and_init(c.name, Some(c.init), explicit_ty, ectx)?;
self.type_and_init(c.name, Some(c.init), explicit_ty, false, ectx)?;
let init = init.expect("Local const must have init");

block.extend(emitter.finish(&ctx.function.expressions));
Expand Down
2 changes: 0 additions & 2 deletions naga/tests/out/glsl/constructors.main.Compute.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ struct Foo {
vec4 a;
int b;
};
const vec3 const2_ = vec3(0.0, 1.0, 2.0);
const mat2x2 const3_ = mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0));
const mat2x2 const4_[1] = mat2x2[1](mat2x2(vec2(0.0, 1.0), vec2(2.0, 3.0)));
const bool cz0_ = false;
Expand All @@ -20,7 +19,6 @@ const uvec2 cz4_ = uvec2(0u);
const mat2x2 cz5_ = mat2x2(0.0);
const Foo cz6_[3] = Foo[3](Foo(vec4(0.0), 0), Foo(vec4(0.0), 0), Foo(vec4(0.0), 0));
const Foo cz7_ = Foo(vec4(0.0), 0);
const int cp3_[4] = int[4](0, 1, 2, 3);


void main() {
Expand Down
14 changes: 6 additions & 8 deletions naga/tests/out/hlsl/constructors.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@ ret_Constructarray1_float2x2_ Constructarray1_float2x2_(float2x2 arg0) {
return ret;
}

typedef int ret_Constructarray4_int_[4];
ret_Constructarray4_int_ Constructarray4_int_(int arg0, int arg1, int arg2, int arg3) {
int ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

bool ZeroValuebool() {
return (bool)0;
}
Expand Down Expand Up @@ -51,7 +45,6 @@ Foo ZeroValueFoo() {
return (Foo)0;
}

static const float3 const2_ = float3(0.0, 1.0, 2.0);
static const float2x2 const3_ = float2x2(float2(0.0, 1.0), float2(2.0, 3.0));
static const float2x2 const4_[1] = Constructarray1_float2x2_(float2x2(float2(0.0, 1.0), float2(2.0, 3.0)));
static const bool cz0_ = ZeroValuebool();
Expand All @@ -62,7 +55,6 @@ static const uint2 cz4_ = ZeroValueuint2();
static const float2x2 cz5_ = ZeroValuefloat2x2();
static const Foo cz6_[3] = ZeroValuearray3_Foo_();
static const Foo cz7_ = ZeroValueFoo();
static const int cp3_[4] = Constructarray4_int_(0, 1, 2, 3);

Foo ConstructFoo(float4 arg0, int arg1) {
Foo ret = (Foo)0;
Expand All @@ -71,6 +63,12 @@ Foo ConstructFoo(float4 arg0, int arg1) {
return ret;
}

typedef int ret_Constructarray4_int_[4];
ret_Constructarray4_int_ Constructarray4_int_(int arg0, int arg1, int arg2, int arg3) {
int ret[4] = { arg0, arg1, arg2, arg3 };
return ret;
}

float2x3 ZeroValuefloat2x3() {
return (float2x3)0;
}
Expand Down
28 changes: 3 additions & 25 deletions naga/tests/out/ir/const_assert.compact.ron
Original file line number Diff line number Diff line change
@@ -1,36 +1,14 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
types: [],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
constants: [],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
global_expressions: [],
functions: [
(
name: Some("foo"),
Expand Down
28 changes: 3 additions & 25 deletions naga/tests/out/ir/const_assert.ron
Original file line number Diff line number Diff line change
@@ -1,36 +1,14 @@
(
types: [
(
name: None,
inner: Scalar((
kind: Sint,
width: 4,
)),
),
],
types: [],
special_types: (
ray_desc: None,
ray_intersection: None,
predeclared_types: {},
),
constants: [
(
name: Some("x"),
ty: 0,
init: 0,
),
(
name: Some("y"),
ty: 0,
init: 1,
),
],
constants: [],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(1)),
Literal(I32(2)),
],
global_expressions: [],
functions: [
(
name: Some("foo"),
Expand Down
59 changes: 26 additions & 33 deletions naga/tests/out/ir/local-const.compact.ron
Original file line number Diff line number Diff line change
Expand Up @@ -38,53 +38,28 @@
predeclared_types: {},
),
constants: [
(
name: Some("ga"),
ty: 0,
init: 0,
),
(
name: Some("gb"),
ty: 0,
init: 1,
init: 0,
),
(
name: Some("gc"),
ty: 1,
init: 2,
init: 1,
),
(
name: Some("gd"),
ty: 2,
init: 3,
),
(
name: Some("ge"),
ty: 3,
init: 4,
),
(
name: Some("gf"),
ty: 2,
init: 5,
init: 2,
),
],
overrides: [],
global_variables: [],
global_expressions: [
Literal(I32(4)),
Literal(I32(4)),
Literal(U32(4)),
Literal(F32(4.0)),
Compose(
ty: 3,
components: [
0,
0,
0,
],
),
Literal(F32(2.0)),
],
functions: [
(
Expand All @@ -106,12 +81,22 @@
],
),
Literal(F32(2.0)),
Literal(I32(4)),
Constant(0),
Constant(1),
Constant(2),
Constant(3),
Constant(4),
Constant(5),
Literal(I32(4)),
Literal(I32(4)),
Literal(I32(4)),
Compose(
ty: 3,
components: [
10,
11,
12,
],
),
Literal(F32(2.0)),
],
named_expressions: {
0: "a",
Expand All @@ -124,14 +109,22 @@
7: "bg",
8: "cg",
9: "dg",
10: "eg",
11: "fg",
13: "eg",
14: "fg",
},
body: [
Emit((
start: 4,
end: 5,
)),
Emit((
start: 0,
end: 0,
)),
Emit((
start: 13,
end: 14,
)),
Return(
value: None,
),
Expand Down
Loading

0 comments on commit efe6402

Please sign in to comment.