From c07fab2c1231a907a0b5bf98fb739085f8dd5417 Mon Sep 17 00:00:00 2001 From: Jamie Nicol Date: Thu, 30 Jan 2025 16:28:21 +0000 Subject: [PATCH] [naga wgsl-in] Allow abstract literals to be used as return values When lowering a return statement, call expression_for_abstract() rather than expression() to avoid concretizing the return value. Then, if the function has a return type, call try_automatic_conversions() to attempt to convert our return value to the correct type. This has the unfortunate side effect that some errors that would have been caught by the validator are instead encountered as conversion errors by the parser. This may result in a slightly less descriptive error message in some cases. (See the change to the invalid_functions() test, for example.) --- CHANGELOG.md | 1 + naga/src/front/wgsl/lower/mod.rs | 23 ++++-- naga/tests/in/abstract-types-return.wgsl | 26 +++++++ .../abstract-types-return.main.Compute.glsl | 36 ++++++++++ .../tests/out/hlsl/abstract-types-return.hlsl | 42 +++++++++++ naga/tests/out/hlsl/abstract-types-return.ron | 12 ++++ naga/tests/out/msl/abstract-types-return.msl | 44 ++++++++++++ .../out/spv/abstract-types-return.spvasm | 70 +++++++++++++++++++ .../tests/out/wgsl/abstract-types-return.wgsl | 28 ++++++++ naga/tests/snapshots.rs | 4 ++ naga/tests/wgsl_errors.rs | 19 ++--- 11 files changed, 292 insertions(+), 13 deletions(-) create mode 100644 naga/tests/in/abstract-types-return.wgsl create mode 100644 naga/tests/out/glsl/abstract-types-return.main.Compute.glsl create mode 100644 naga/tests/out/hlsl/abstract-types-return.hlsl create mode 100644 naga/tests/out/hlsl/abstract-types-return.ron create mode 100644 naga/tests/out/msl/abstract-types-return.msl create mode 100644 naga/tests/out/spv/abstract-types-return.spvasm create mode 100644 naga/tests/out/wgsl/abstract-types-return.wgsl diff --git a/CHANGELOG.md b/CHANGELOG.md index 52af51bf8d..6101c7e5bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,7 @@ By @brodycj in [#6924](https://github.com/gfx-rs/wgpu/pull/6924). #### Naga - Fix some instances of functions which have a return type but don't return a value being incorrectly validated. By @jamienicol in [#7013](https://github.com/gfx-rs/wgpu/pull/7013). +- Allow abstract expressions to be used in WGSL function return statements. By @jamienicol in [#7035](https://github.com/gfx-rs/wgpu/pull/7035). #### General diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index d25eb362c1..e405d3e56f 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1672,13 +1672,28 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { } ast::StatementKind::Break => crate::Statement::Break, ast::StatementKind::Continue => crate::Statement::Continue, - ast::StatementKind::Return { value } => { + ast::StatementKind::Return { value: ast_value } => { let mut emitter = Emitter::default(); emitter.start(&ctx.function.expressions); - let value = value - .map(|expr| self.expression(expr, &mut ctx.as_expression(block, &mut emitter))) - .transpose()?; + let value; + if let Some(ast_expr) = ast_value { + let result_ty = ctx.function.result.as_ref().map(|r| r.ty); + let mut ectx = ctx.as_expression(block, &mut emitter); + let expr = self.expression_for_abstract(ast_expr, &mut ectx)?; + + if let Some(result_ty) = result_ty { + let mut ectx = ctx.as_expression(block, &mut emitter); + let resolution = crate::proc::TypeResolution::Handle(result_ty); + let converted = + ectx.try_automatic_conversions(expr, &resolution, Span::default())?; + value = Some(converted); + } else { + value = Some(expr); + } + } else { + value = None; + } block.extend(emitter.finish(&ctx.function.expressions)); crate::Statement::Return { value } diff --git a/naga/tests/in/abstract-types-return.wgsl b/naga/tests/in/abstract-types-return.wgsl new file mode 100644 index 0000000000..da77c129a0 --- /dev/null +++ b/naga/tests/in/abstract-types-return.wgsl @@ -0,0 +1,26 @@ +@compute @workgroup_size(1) +fn main() {} + +fn return_i32_ai() -> i32 { + return 1; +} + +fn return_u32_ai() -> u32 { + return 1; +} + +fn return_f32_ai() -> f32 { + return 1; +} + +fn return_f32_af() -> f32 { + return 1.0; +} + +fn return_vec2f32_ai() -> vec2 { + return vec2(1); +} + +fn return_arrf32_ai() -> array { + return array(1, 1, 1, 1); +} diff --git a/naga/tests/out/glsl/abstract-types-return.main.Compute.glsl b/naga/tests/out/glsl/abstract-types-return.main.Compute.glsl new file mode 100644 index 0000000000..07c75d7c51 --- /dev/null +++ b/naga/tests/out/glsl/abstract-types-return.main.Compute.glsl @@ -0,0 +1,36 @@ +#version 310 es + +precision highp float; +precision highp int; + +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; + + +int return_i32_ai() { + return 1; +} + +uint return_u32_ai() { + return 1u; +} + +float return_f32_ai() { + return 1.0; +} + +float return_f32_af() { + return 1.0; +} + +vec2 return_vec2f32_ai() { + return vec2(1.0); +} + +float[4] return_arrf32_ai() { + return float[4](1.0, 1.0, 1.0, 1.0); +} + +void main() { + return; +} + diff --git a/naga/tests/out/hlsl/abstract-types-return.hlsl b/naga/tests/out/hlsl/abstract-types-return.hlsl new file mode 100644 index 0000000000..fe2de29aee --- /dev/null +++ b/naga/tests/out/hlsl/abstract-types-return.hlsl @@ -0,0 +1,42 @@ +int return_i32_ai() +{ + return 1; +} + +uint return_u32_ai() +{ + return 1u; +} + +float return_f32_ai() +{ + return 1.0; +} + +float return_f32_af() +{ + return 1.0; +} + +float2 return_vec2f32_ai() +{ + return (1.0).xx; +} + +typedef float ret_Constructarray4_float_[4]; +ret_Constructarray4_float_ Constructarray4_float_(float arg0, float arg1, float arg2, float arg3) { + float ret[4] = { arg0, arg1, arg2, arg3 }; + return ret; +} + +typedef float ret_return_arrf32_ai[4]; +ret_return_arrf32_ai return_arrf32_ai() +{ + return Constructarray4_float_(1.0, 1.0, 1.0, 1.0); +} + +[numthreads(1, 1, 1)] +void main() +{ + return; +} diff --git a/naga/tests/out/hlsl/abstract-types-return.ron b/naga/tests/out/hlsl/abstract-types-return.ron new file mode 100644 index 0000000000..a07b03300b --- /dev/null +++ b/naga/tests/out/hlsl/abstract-types-return.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/msl/abstract-types-return.msl b/naga/tests/out/msl/abstract-types-return.msl new file mode 100644 index 0000000000..99072e7473 --- /dev/null +++ b/naga/tests/out/msl/abstract-types-return.msl @@ -0,0 +1,44 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct type_4 { + float inner[4]; +}; + +int return_i32_ai( +) { + return 1; +} + +uint return_u32_ai( +) { + return 1u; +} + +float return_f32_ai( +) { + return 1.0; +} + +float return_f32_af( +) { + return 1.0; +} + +metal::float2 return_vec2f32_ai( +) { + return metal::float2(1.0); +} + +type_4 return_arrf32_ai( +) { + return type_4 {1.0, 1.0, 1.0, 1.0}; +} + +kernel void main_( +) { + return; +} diff --git a/naga/tests/out/spv/abstract-types-return.spvasm b/naga/tests/out/spv/abstract-types-return.spvasm new file mode 100644 index 0000000000..8ce671adfc --- /dev/null +++ b/naga/tests/out/spv/abstract-types-return.spvasm @@ -0,0 +1,70 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 41 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %38 "main" +OpExecutionMode %38 LocalSize 1 1 1 +OpDecorate %7 ArrayStride 4 +%2 = OpTypeVoid +%3 = OpTypeInt 32 1 +%4 = OpTypeInt 32 0 +%5 = OpTypeFloat 32 +%6 = OpTypeVector %5 2 +%8 = OpConstant %4 4 +%7 = OpTypeArray %5 %8 +%11 = OpTypeFunction %3 +%12 = OpConstant %3 1 +%16 = OpTypeFunction %4 +%17 = OpConstant %4 1 +%21 = OpTypeFunction %5 +%22 = OpConstant %5 1.0 +%29 = OpTypeFunction %6 +%30 = OpConstantComposite %6 %22 %22 +%34 = OpTypeFunction %7 +%35 = OpConstantComposite %7 %22 %22 %22 %22 +%39 = OpTypeFunction %2 +%10 = OpFunction %3 None %11 +%9 = OpLabel +OpBranch %13 +%13 = OpLabel +OpReturnValue %12 +OpFunctionEnd +%15 = OpFunction %4 None %16 +%14 = OpLabel +OpBranch %18 +%18 = OpLabel +OpReturnValue %17 +OpFunctionEnd +%20 = OpFunction %5 None %21 +%19 = OpLabel +OpBranch %23 +%23 = OpLabel +OpReturnValue %22 +OpFunctionEnd +%25 = OpFunction %5 None %21 +%24 = OpLabel +OpBranch %26 +%26 = OpLabel +OpReturnValue %22 +OpFunctionEnd +%28 = OpFunction %6 None %29 +%27 = OpLabel +OpBranch %31 +%31 = OpLabel +OpReturnValue %30 +OpFunctionEnd +%33 = OpFunction %7 None %34 +%32 = OpLabel +OpBranch %36 +%36 = OpLabel +OpReturnValue %35 +OpFunctionEnd +%38 = OpFunction %2 None %39 +%37 = OpLabel +OpBranch %40 +%40 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/abstract-types-return.wgsl b/naga/tests/out/wgsl/abstract-types-return.wgsl new file mode 100644 index 0000000000..c25a217dac --- /dev/null +++ b/naga/tests/out/wgsl/abstract-types-return.wgsl @@ -0,0 +1,28 @@ +fn return_i32_ai() -> i32 { + return 1i; +} + +fn return_u32_ai() -> u32 { + return 1u; +} + +fn return_f32_ai() -> f32 { + return 1f; +} + +fn return_f32_af() -> f32 { + return 1f; +} + +fn return_vec2f32_ai() -> vec2 { + return vec2(1f); +} + +fn return_arrf32_ai() -> array { + return array(1f, 1f, 1f, 1f); +} + +@compute @workgroup_size(1, 1, 1) +fn main() { + return; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 5a9cb343df..4f98259c35 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -921,6 +921,10 @@ fn convert_wgsl() { "abstract-types-operators", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, ), + ( + "abstract-types-return", + Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::HLSL | Targets::WGSL, + ), ( "int64", Targets::SPIRV | Targets::HLSL | Targets::WGSL | Targets::METAL, diff --git a/naga/tests/wgsl_errors.rs b/naga/tests/wgsl_errors.rs index a23eaf7047..5abd61cd8f 100644 --- a/naga/tests/wgsl_errors.rs +++ b/naga/tests/wgsl_errors.rs @@ -1174,7 +1174,7 @@ fn invalid_functions() { if function_name == "return_pointer" } - check_validation! { + check( " @group(0) @binding(0) var atom: atomic; @@ -1182,14 +1182,15 @@ fn invalid_functions() { fn return_atomic() -> atomic { return atom; } - ": - Err(naga::valid::ValidationError::Function { - name: function_name, - source: naga::valid::FunctionError::NonConstructibleReturnType, - .. - }) - if function_name == "return_atomic" - } + ", + "error: automatic conversions cannot convert `u32` to `atomic` + ┌─ wgsl:6:19 + │ +6 │ return atom; + │ ^^^^ this expression has type u32 + +", + ); } #[test]