diff --git a/CHANGELOG.md b/CHANGELOG.md index 40a625a443..f15634586d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] ### Changed 🛠 +- [PR#1118](https://github.com/EmbarkStudios/rust-gpu/pull/1118) update rspirv to 0.12 + - added support for `VK_KHR_fragment_shader_barycentric` with `bary_coord`/`bary_coord_no_persp` buildins - [PR#1127](https://github.com/EmbarkStudios/rust-gpu/pull/1127) updated `spirv-tools` to `0.10.0`, which follows `vulkan-sdk-1.3.275` - [PR#1101](https://github.com/EmbarkStudios/rust-gpu/pull/1101) added `ignore` and `no_run` to documentation to make `cargo test` pass - [PR#1112](https://github.com/EmbarkStudios/rust-gpu/pull/1112) updated wgpu and winit in example runners diff --git a/Cargo.lock b/Cargo.lock index 252d22e370..fa46939a71 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -317,12 +317,6 @@ dependencies = [ "syn 2.0.48", ] -[[package]] -name = "byteorder" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" - [[package]] name = "bytes" version = "1.5.0" @@ -1025,15 +1019,6 @@ dependencies = [ "slab", ] -[[package]] -name = "fxhash" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" -dependencies = [ - "byteorder", -] - [[package]] name = "gethostname" version = "0.4.3" @@ -1613,7 +1598,7 @@ dependencies = [ "num-traits", "petgraph", "rustc-hash", - "spirv 0.3.0+sdk-1.3.268.0", + "spirv", "termcolor", "thiserror", "unicode-xid", @@ -2112,13 +2097,12 @@ checksum = "216080ab382b992234dda86873c18d4c48358f5cfcb70fd693d7f6f2131b628b" [[package]] name = "rspirv" -version = "0.11.0+1.5.4" +version = "0.12.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1503993b59ca9ae4127365c3293517576d7ce56be9f3d8abb1625c85ddc583ba" +checksum = "69cf3a93856b6e5946537278df0d3075596371b1950ccff012f02b0f7eafec8d" dependencies = [ - "fxhash", - "num-traits", - "spirv 0.2.0+1.5.4", + "rustc-hash", + "spirv", ] [[package]] @@ -2432,16 +2416,6 @@ dependencies = [ "smallvec", ] -[[package]] -name = "spirv" -version = "0.2.0+1.5.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830" -dependencies = [ - "bitflags 1.3.2", - "num-traits", -] - [[package]] name = "spirv" version = "0.3.0+sdk-1.3.268.0" diff --git a/crates/rustc_codegen_spirv-types/Cargo.toml b/crates/rustc_codegen_spirv-types/Cargo.toml index 41b8119a89..5fdaa65aa7 100644 --- a/crates/rustc_codegen_spirv-types/Cargo.toml +++ b/crates/rustc_codegen_spirv-types/Cargo.toml @@ -8,5 +8,5 @@ license.workspace = true repository.workspace = true [dependencies] -rspirv = "0.11" +rspirv = "0.12" serde = { version = "1.0", features = ["derive"] } diff --git a/crates/rustc_codegen_spirv/Cargo.toml b/crates/rustc_codegen_spirv/Cargo.toml index 93b9bd635d..b086cf541e 100644 --- a/crates/rustc_codegen_spirv/Cargo.toml +++ b/crates/rustc_codegen_spirv/Cargo.toml @@ -50,7 +50,7 @@ regex = { version = "1", features = ["perf"] } ar = "0.9.0" either = "1.8.0" indexmap = "1.6.0" -rspirv = "0.11" +rspirv = "0.12" rustc_codegen_spirv-types.workspace = true rustc-demangle = "0.1.21" sanitize-filename = "0.4" diff --git a/crates/rustc_codegen_spirv/src/abi.rs b/crates/rustc_codegen_spirv/src/abi.rs index fbe20ed462..fa8a7799a4 100644 --- a/crates/rustc_codegen_spirv/src/abi.rs +++ b/crates/rustc_codegen_spirv/src/abi.rs @@ -4,7 +4,7 @@ use crate::attr::{AggregatedSpirvAttributes, IntrinsicType}; use crate::codegen_cx::CodegenCx; use crate::spirv_type::SpirvType; -use rspirv::spirv::{StorageClass, Word}; +use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::ErrorGuaranteed; use rustc_index::Idx; @@ -647,7 +647,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx> // spir-v doesn't support zero-sized arrays create_zst(cx, span, ty) } else { - let count_const = cx.constant_u32(span, count as u32); + let count_const = cx.constant_bit32(span, count as u32); let element_spv = cx.lookup_type(element_type); let stride_spv = element_spv .sizeof(cx) @@ -856,13 +856,35 @@ fn trans_intrinsic_type<'tcx>( // let image_format: spirv::ImageFormat = // type_from_variant_discriminant(cx, args.const_at(6)); - fn const_int_value<'tcx, P: FromPrimitive>( + trait FromU128Const: Sized { + fn from_u128_const(n: u128) -> Option; + } + + impl FromU128Const for u32 { + fn from_u128_const(n: u128) -> Option { + u32::from_u128(n) + } + } + + impl FromU128Const for Dim { + fn from_u128_const(n: u128) -> Option { + Dim::from_u32(u32::from_u128(n)?) + } + } + + impl FromU128Const for ImageFormat { + fn from_u128_const(n: u128) -> Option { + ImageFormat::from_u32(u32::from_u128(n)?) + } + } + + fn const_int_value<'tcx, P: FromU128Const>( cx: &CodegenCx<'tcx>, const_: Const<'tcx>, ) -> Result { assert!(const_.ty().is_integral()); let value = const_.eval_bits(cx.tcx, ParamEnv::reveal_all()); - match P::from_u128(value) { + match P::from_u128_const(value) { Some(v) => Ok(v), None => Err(cx .tcx diff --git a/crates/rustc_codegen_spirv/src/attr.rs b/crates/rustc_codegen_spirv/src/attr.rs index b796e86db6..2ef98a7a5a 100644 --- a/crates/rustc_codegen_spirv/src/attr.rs +++ b/crates/rustc_codegen_spirv/src/attr.rs @@ -91,6 +91,7 @@ pub enum SpirvAttribute { DescriptorSet(u32), Binding(u32), Flat, + PerPrimitiveExt, Invariant, InputAttachmentIndex(u32), SpecConstant(SpecConstant), @@ -127,6 +128,7 @@ pub struct AggregatedSpirvAttributes { pub binding: Option>, pub flat: Option>, pub invariant: Option>, + pub per_primitive_ext: Option>, pub input_attachment_index: Option>, pub spec_constant: Option>, @@ -213,6 +215,12 @@ impl AggregatedSpirvAttributes { Binding(value) => try_insert(&mut self.binding, value, span, "#[spirv(binding)]"), Flat => try_insert(&mut self.flat, (), span, "#[spirv(flat)]"), Invariant => try_insert(&mut self.invariant, (), span, "#[spirv(invariant)]"), + PerPrimitiveExt => try_insert( + &mut self.per_primitive_ext, + (), + span, + "#[spirv(per_primitive_ext)]", + ), InputAttachmentIndex(value) => try_insert( &mut self.input_attachment_index, value, @@ -314,6 +322,7 @@ impl CheckSpirvAttrVisitor<'_> { | SpirvAttribute::Binding(_) | SpirvAttribute::Flat | SpirvAttribute::Invariant + | SpirvAttribute::PerPrimitiveExt | SpirvAttribute::InputAttachmentIndex(_) | SpirvAttribute::SpecConstant(_) => match target { Target::Param => { diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 86766ff75f..bb04d62caf 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -175,7 +175,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { | MemorySemantics::SEQUENTIALLY_CONSISTENT } }; - let semantics = self.constant_u32(self.span(), semantics.bits()); + let semantics = self.constant_bit32(self.span(), semantics.bits()); if invalid_seq_cst { self.zombie( semantics.def(self), @@ -196,10 +196,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .constant_u16(self.span(), memset_fill_u16(fill_byte)) .def(self), 32 => self - .constant_u32(self.span(), memset_fill_u32(fill_byte)) + .constant_bit32(self.span(), memset_fill_u32(fill_byte)) .def(self), 64 => self - .constant_u64(self.span(), memset_fill_u64(fill_byte)) + .constant_bit64(self.span(), memset_fill_u64(fill_byte)) .def(self), _ => self.fatal(format!( "memset on integer width {width} not implemented yet" @@ -314,7 +314,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { self.store(pat, ptr, Align::from_bytes(0).unwrap()); } else { for index in 0..count { - let const_index = self.constant_u32(self.span(), index as u32); + let const_index = self.constant_bit32(self.span(), index as u32); let gep_ptr = self.gep(pat.ty, ptr, &[const_index]); self.store(pat, gep_ptr, Align::from_bytes(0).unwrap()); } @@ -428,7 +428,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { } else { let indices = indices .into_iter() - .map(|idx| self.constant_u32(self.span(), idx).def(self)) + .map(|idx| self.constant_bit32(self.span(), idx).def(self)) .collect::>(); self.emit() .access_chain(leaf_ptr_ty, None, ptr.def(self), indices) @@ -904,9 +904,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )) } else if signed { // this cast chain can probably be collapsed, but, whatever, be safe - Operand::LiteralInt32(v as u8 as i8 as i32 as u32) + Operand::LiteralBit32(v as u8 as i8 as i32 as u32) } else { - Operand::LiteralInt32(v as u8 as u32) + Operand::LiteralBit32(v as u8 as u32) } } fn construct_16(self_: &Builder<'_, '_>, signed: bool, v: u128) -> Operand { @@ -915,9 +915,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "Switches to values above u16::MAX not supported: {v:?}" )) } else if signed { - Operand::LiteralInt32(v as u16 as i16 as i32 as u32) + Operand::LiteralBit32(v as u16 as i16 as i32 as u32) } else { - Operand::LiteralInt32(v as u16 as u32) + Operand::LiteralBit32(v as u16 as u32) } } fn construct_32(self_: &Builder<'_, '_>, _signed: bool, v: u128) -> Operand { @@ -926,7 +926,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "Switches to values above u32::MAX not supported: {v:?}" )) } else { - Operand::LiteralInt32(v as u32) + Operand::LiteralBit32(v as u32) } } fn construct_64(self_: &Builder<'_, '_>, _signed: bool, v: u128) -> Operand { @@ -935,7 +935,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { "Switches to values above u64::MAX not supported: {v:?}" )) } else { - Operand::LiteralInt64(v as u64) + Operand::LiteralBit64(v as u64) } } // pass in signed into the closure to be able to unify closure types @@ -1217,7 +1217,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty); // TODO: Default to device scope - let memory = self.constant_u32(self.span(), Scope::Device as u32); + let memory = self.constant_bit32(self.span(), Scope::Device as u32); let semantics = self.ordering_to_semantics_def(order); let result = self .emit() @@ -1347,7 +1347,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let val = self.bitcast(val, access_ty); // TODO: Default to device scope - let memory = self.constant_u32(self.span(), Scope::Device as u32); + let memory = self.constant_bit32(self.span(), Scope::Device as u32); let semantics = self.ordering_to_semantics_def(order); self.validate_atomic(val.ty, ptr.def(self)); self.emit() @@ -1413,7 +1413,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let original_ptr = ptr.def(self); let indices = indices .into_iter() - .map(|idx| self.constant_u32(self.span(), idx).def(self)) + .map(|idx| self.constant_bit32(self.span(), idx).def(self)) .collect::>(); return self .emit() @@ -1433,7 +1433,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { if idx > u32::MAX as u64 { self.fatal("struct_gep bigger than u32::MAX"); } - let index_const = self.constant_u32(self.span(), idx as u32).def(self); + let index_const = self.constant_bit32(self.span(), idx as u32).def(self); self.emit() .access_chain( result_type, @@ -1741,7 +1741,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { ) { let indices = indices .into_iter() - .map(|idx| self.constant_u32(self.span(), idx).def(self)) + .map(|idx| self.constant_bit32(self.span(), idx).def(self)) .collect::>(); self.emit() .access_chain(dest_ty, None, ptr.def(self), indices) @@ -2292,7 +2292,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.validate_atomic(access_ty, dst.def(self)); // TODO: Default to device scope - let memory = self.constant_u32(self.span(), Scope::Device as u32); + let memory = self.constant_bit32(self.span(), Scope::Device as u32); let semantics_equal = self.ordering_to_semantics_def(order); let semantics_unequal = self.ordering_to_semantics_def(failure_order); // Note: OpAtomicCompareExchangeWeak is deprecated, and has the same semantics @@ -2328,7 +2328,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.validate_atomic(access_ty, dst.def(self)); // TODO: Default to device scope let memory = self - .constant_u32(self.span(), Scope::Device as u32) + .constant_bit32(self.span(), Scope::Device as u32) .def(self); let semantics = self.ordering_to_semantics_def(order).def(self); use AtomicRmwBinOp::*; @@ -2424,7 +2424,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // Ignore sync scope (it only has "single thread" and "cross thread") // TODO: Default to device scope let memory = self - .constant_u32(self.span(), Scope::Device as u32) + .constant_bit32(self.span(), Scope::Device as u32) .def(self); let semantics = self.ordering_to_semantics_def(order).def(self); self.emit().memory_barrier(memory, semantics).unwrap(); @@ -2697,7 +2697,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // HACK(eddyb) avoid the logic below that assumes only ID operands if inst.class.opcode == Op::CompositeExtract { - if let (Some(r), &[Operand::IdRef(x), Operand::LiteralInt32(i)]) = + if let (Some(r), &[Operand::IdRef(x), Operand::LiteralBit32(i)]) = (inst.result_id, &inst.operands[..]) { return Some(Inst::CompositeExtract(r, x, i)); diff --git a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs index 1248da7f80..128e3fba0b 100644 --- a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs +++ b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs @@ -31,7 +31,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { constant_offset: u32, ) -> SpirvValue { let actual_index = if constant_offset != 0 { - let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset); + let const_offset_val = self.constant_bit32(DUMMY_SP, constant_offset); self.add(dynamic_index, const_offset_val) } else { dynamic_index @@ -199,7 +199,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // Note that the &[u32] gets split into two arguments - pointer, length let array = args[0]; let byte_index = args[2]; - let two = self.constant_u32(DUMMY_SP, 2); + let two = self.constant_bit32(DUMMY_SP, 2); let word_index = self.lshr(byte_index, two); self.recurse_load_type(result_type, result_type, array, word_index, 0) } @@ -223,7 +223,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { value: SpirvValue, ) -> Result<(), ErrorGuaranteed> { let actual_index = if constant_offset != 0 { - let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset); + let const_offset_val = self.constant_bit32(DUMMY_SP, constant_offset); self.add(dynamic_index, const_offset_val) } else { dynamic_index @@ -367,7 +367,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // Note that the &[u32] gets split into two arguments - pointer, length let array = args[0]; let byte_index = args[2]; - let two = self.constant_u32(DUMMY_SP, 2); + let two = self.constant_bit32(DUMMY_SP, 2); let word_index = self.lshr(byte_index, two); if is_pair { let value_one = args[3]; diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs index 9d8b886ea8..dbee5aae7e 100644 --- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -45,12 +45,12 @@ impl Builder<'_, '_> { let int_ty = SpirvType::Integer(width, false).def(self.span(), self); let (mask_sign, mask_value) = match width { 32 => ( - self.constant_u32(self.span(), 1 << 31), - self.constant_u32(self.span(), u32::max_value() >> 1), + self.constant_bit32(self.span(), 1 << 31), + self.constant_bit32(self.span(), u32::max_value() >> 1), ), 64 => ( - self.constant_u64(self.span(), 1 << 63), - self.constant_u64(self.span(), u64::max_value() >> 1), + self.constant_bit64(self.span(), 1 << 63), + self.constant_bit64(self.span(), u64::max_value() >> 1), ), _ => bug!("copysign must have width 32 or 64, not {}", width), }; @@ -269,10 +269,10 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { self.or(tmp1, tmp2) } 32 => { - let offset8 = self.constant_u32(self.span(), 8); - let offset24 = self.constant_u32(self.span(), 24); - let mask16 = self.constant_u32(self.span(), 0xFF00); - let mask24 = self.constant_u32(self.span(), 0xFF0000); + let offset8 = self.constant_bit32(self.span(), 8); + let offset24 = self.constant_bit32(self.span(), 24); + let mask16 = self.constant_bit32(self.span(), 0xFF00); + let mask24 = self.constant_bit32(self.span(), 0xFF0000); let tmp4 = self.shl(arg, offset24); let tmp3 = self.shl(arg, offset8); let tmp2 = self.lshr(arg, offset8); @@ -284,16 +284,16 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { self.or(res1, res2) } 64 => { - let offset8 = self.constant_u64(self.span(), 8); - let offset24 = self.constant_u64(self.span(), 24); - let offset40 = self.constant_u64(self.span(), 40); - let offset56 = self.constant_u64(self.span(), 56); - let mask16 = self.constant_u64(self.span(), 0xff00); - let mask24 = self.constant_u64(self.span(), 0xff0000); - let mask32 = self.constant_u64(self.span(), 0xff000000); - let mask40 = self.constant_u64(self.span(), 0xff00000000); - let mask48 = self.constant_u64(self.span(), 0xff0000000000); - let mask56 = self.constant_u64(self.span(), 0xff000000000000); + let offset8 = self.constant_bit64(self.span(), 8); + let offset24 = self.constant_bit64(self.span(), 24); + let offset40 = self.constant_bit64(self.span(), 40); + let offset56 = self.constant_bit64(self.span(), 56); + let mask16 = self.constant_bit64(self.span(), 0xff00); + let mask24 = self.constant_bit64(self.span(), 0xff0000); + let mask32 = self.constant_bit64(self.span(), 0xff000000); + let mask40 = self.constant_bit64(self.span(), 0xff00000000); + let mask48 = self.constant_bit64(self.span(), 0xff0000000000); + let mask56 = self.constant_bit64(self.span(), 0xff000000000000); let tmp8 = self.shl(arg, offset56); let tmp7 = self.shl(arg, offset40); let tmp6 = self.shl(arg, offset24); diff --git a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs index 97c4d9cfa1..5141cf2ede 100644 --- a/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs +++ b/crates/rustc_codegen_spirv/src/builder/spirv_asm.rs @@ -5,8 +5,9 @@ use crate::spirv_type::SpirvType; use rspirv::dr; use rspirv::grammar::{reflect, LogicalOperand, OperandKind, OperandQuantifier}; use rspirv::spirv::{ - FPFastMathMode, FragmentShadingRate, FunctionControl, ImageOperands, KernelProfilingInfo, - LoopControl, MemoryAccess, MemorySemantics, Op, RayFlags, SelectionControl, StorageClass, Word, + CooperativeMatrixOperands, FPFastMathMode, FragmentShadingRate, FunctionControl, ImageOperands, + KernelProfilingInfo, LoopControl, MemoryAccess, MemorySemantics, Op, RayFlags, + SelectionControl, StorageClass, Word, }; use rustc_ast::ast::{InlineAsmOptions, InlineAsmTemplatePiece}; use rustc_codegen_ssa::mir::place::PlaceRef; @@ -273,12 +274,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Op::TypeVoid => SpirvType::Void.def(self.span(), self), Op::TypeBool => SpirvType::Bool.def(self.span(), self), Op::TypeInt => SpirvType::Integer( - inst.operands[0].unwrap_literal_int32(), - inst.operands[1].unwrap_literal_int32() != 0, + inst.operands[0].unwrap_literal_bit32(), + inst.operands[1].unwrap_literal_bit32() != 0, ) .def(self.span(), self), Op::TypeFloat => { - SpirvType::Float(inst.operands[0].unwrap_literal_int32()).def(self.span(), self) + SpirvType::Float(inst.operands[0].unwrap_literal_bit32()).def(self.span(), self) } Op::TypeStruct => { self.err("OpTypeStruct in asm! is not supported yet"); @@ -286,12 +287,12 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } Op::TypeVector => SpirvType::Vector { element: inst.operands[0].unwrap_id_ref(), - count: inst.operands[1].unwrap_literal_int32(), + count: inst.operands[1].unwrap_literal_bit32(), } .def(self.span(), self), Op::TypeMatrix => SpirvType::Matrix { element: inst.operands[0].unwrap_id_ref(), - count: inst.operands[1].unwrap_literal_int32(), + count: inst.operands[1].unwrap_literal_bit32(), } .def(self.span(), self), Op::TypeArray => { @@ -322,10 +323,10 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { Op::TypeImage => SpirvType::Image { sampled_type: inst.operands[0].unwrap_id_ref(), dim: inst.operands[1].unwrap_dim(), - depth: inst.operands[2].unwrap_literal_int32(), - arrayed: inst.operands[3].unwrap_literal_int32(), - multisampled: inst.operands[4].unwrap_literal_int32(), - sampled: inst.operands[5].unwrap_literal_int32(), + depth: inst.operands[2].unwrap_literal_bit32(), + arrayed: inst.operands[3].unwrap_literal_bit32(), + multisampled: inst.operands[4].unwrap_literal_bit32(), + sampled: inst.operands[5].unwrap_literal_bit32(), image_format: inst.operands[6].unwrap_image_format(), } .def(self.span(), self), @@ -694,7 +695,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { let index_to_usize = || match *index { // FIXME(eddyb) support more than just literals, // by looking up `IdRef`s as constant integers. - dr::Operand::LiteralInt32(i) => usize::try_from(i).ok(), + dr::Operand::LiteralBit32(i) => usize::try_from(i).ok(), _ => None, }; @@ -1077,9 +1078,15 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } (OperandKind::LiteralInteger, Some(word)) => match word.parse() { - Ok(v) => inst.operands.push(dr::Operand::LiteralInt32(v)), + Ok(v) => inst.operands.push(dr::Operand::LiteralBit32(v)), Err(e) => self.err(format!("invalid integer: {e}")), }, + (OperandKind::LiteralFloat, Some(word)) => match word.parse() { + Ok(v) => inst + .operands + .push(dr::Operand::LiteralBit32(f32::to_bits(v))), + Err(e) => self.err(format!("invalid float: {e}")), + }, (OperandKind::LiteralString, _) => { if let Token::String(value) = token { inst.operands.push(dr::Operand::LiteralString(value)); @@ -1094,34 +1101,34 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { } Ok(match ty { SpirvType::Integer(8, false) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)? as u32) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)? as u32) } SpirvType::Integer(16, false) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)? as u32) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)? as u32) } SpirvType::Integer(32, false) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)?) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)?) } SpirvType::Integer(64, false) => { - dr::Operand::LiteralInt64(w.parse::().map_err(fmt)?) + dr::Operand::LiteralBit64(w.parse::().map_err(fmt)?) } SpirvType::Integer(8, true) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)? as i32 as u32) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)? as i32 as u32) } SpirvType::Integer(16, true) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)? as i32 as u32) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)? as i32 as u32) } SpirvType::Integer(32, true) => { - dr::Operand::LiteralInt32(w.parse::().map_err(fmt)? as u32) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)? as u32) } SpirvType::Integer(64, true) => { - dr::Operand::LiteralInt64(w.parse::().map_err(fmt)? as u64) + dr::Operand::LiteralBit64(w.parse::().map_err(fmt)? as u64) } SpirvType::Float(32) => { - dr::Operand::LiteralFloat32(w.parse::().map_err(fmt)?) + dr::Operand::LiteralBit32(w.parse::().map_err(fmt)?.to_bits()) } SpirvType::Float(64) => { - dr::Operand::LiteralFloat64(w.parse::().map_err(fmt)?) + dr::Operand::LiteralBit64(w.parse::().map_err(fmt)?.to_bits()) } _ => return Err("expected number literal in OpConstant".to_string()), }) @@ -1152,7 +1159,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { inst.operands.push(dr::Operand::IdRef(id)); match tokens.next() { Some(Token::Word(word)) => match word.parse() { - Ok(v) => inst.operands.push(dr::Operand::LiteralInt32(v)), + Ok(v) => inst.operands.push(dr::Operand::LiteralBit32(v)), Err(e) => { self.err(format!("invalid integer: {e}")); } @@ -1357,6 +1364,60 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> { .push(dr::Operand::RayQueryCandidateIntersectionType(x)), Err(()) => self.err(format!("unknown RayQueryCandidateIntersectionType {word}")), }, + (OperandKind::FPDenormMode, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::FPDenormMode(x)), + Err(()) => self.err(format!("unknown FPDenormMode {word}")), + }, + (OperandKind::QuantizationModes, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::QuantizationModes(x)), + Err(()) => self.err(format!("unknown QuantizationModes {word}")), + }, + (OperandKind::FPOperationMode, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::FPOperationMode(x)), + Err(()) => self.err(format!("unknown FPOperationMode {word}")), + }, + (OperandKind::OverflowModes, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::OverflowModes(x)), + Err(()) => self.err(format!("unknown OverflowModes {word}")), + }, + (OperandKind::PackedVectorFormat, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::PackedVectorFormat(x)), + Err(()) => self.err(format!("unknown PackedVectorFormat {word}")), + }, + (OperandKind::HostAccessQualifier, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::HostAccessQualifier(x)), + Err(()) => self.err(format!("unknown HostAccessQualifier {word}")), + }, + (OperandKind::CooperativeMatrixOperands, Some(word)) => { + match parse_bitflags_operand(COOPERATIVE_MATRIX_OPERANDS, word) { + Some(x) => inst + .operands + .push(dr::Operand::CooperativeMatrixOperands(x)), + None => self.err(format!("Unknown CooperativeMatrixOperands {word}")), + } + } + (OperandKind::CooperativeMatrixLayout, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::CooperativeMatrixLayout(x)), + Err(()) => self.err(format!("unknown CooperativeMatrixLayout {word}")), + }, + (OperandKind::CooperativeMatrixUse, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::CooperativeMatrixUse(x)), + Err(()) => self.err(format!("unknown CooperativeMatrixUse {word}")), + }, + (OperandKind::InitializationModeQualifier, Some(word)) => match word.parse() { + Ok(x) => inst + .operands + .push(dr::Operand::InitializationModeQualifier(x)), + Err(()) => self.err(format!("unknown InitializationModeQualifier {word}")), + }, + (OperandKind::LoadCacheControl, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::LoadCacheControl(x)), + Err(()) => self.err(format!("unknown LoadCacheControl {word}")), + }, + (OperandKind::StoreCacheControl, Some(word)) => match word.parse() { + Ok(x) => inst.operands.push(dr::Operand::StoreCacheControl(x)), + Err(()) => self.err(format!("unknown StoreCacheControl {word}")), + }, (kind, None) => match token { Token::Word(_) => bug!(), Token::String(_) => { @@ -1528,6 +1589,29 @@ pub const FRAGMENT_SHADING_RATE: &[(&str, FragmentShadingRate)] = &[ FragmentShadingRate::HORIZONTAL4_PIXELS, ), ]; +pub const COOPERATIVE_MATRIX_OPERANDS: &[(&str, CooperativeMatrixOperands)] = &[ + ("NONE_KHR", CooperativeMatrixOperands::NONE_KHR), + ( + "MATRIX_A_SIGNED_COMPONENTS_KHR", + CooperativeMatrixOperands::MATRIX_A_SIGNED_COMPONENTS_KHR, + ), + ( + "MATRIX_B_SIGNED_COMPONENTS_KHR", + CooperativeMatrixOperands::MATRIX_B_SIGNED_COMPONENTS_KHR, + ), + ( + "MATRIX_C_SIGNED_COMPONENTS_KHR", + CooperativeMatrixOperands::MATRIX_C_SIGNED_COMPONENTS_KHR, + ), + ( + "MATRIX_RESULT_SIGNED_COMPONENTS_KHR", + CooperativeMatrixOperands::MATRIX_RESULT_SIGNED_COMPONENTS_KHR, + ), + ( + "SATURATING_ACCUMULATION_KHR", + CooperativeMatrixOperands::SATURATING_ACCUMULATION_KHR, + ), +]; fn parse_bitflags_operand + Copy>( values: &'static [(&'static str, T)], diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 6a7c8e72a2..f04cf843a6 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -583,10 +583,8 @@ impl<'tcx> BuilderSpirv<'tcx> { } let val = val_with_type.val; let id = match val { - SpirvConst::U32(v) => builder.constant_u32(ty, v), - SpirvConst::U64(v) => builder.constant_u64(ty, v), - SpirvConst::F32(v) => builder.constant_f32(ty, f32::from_bits(v)), - SpirvConst::F64(v) => builder.constant_f64(ty, f64::from_bits(v)), + SpirvConst::U32(v) | SpirvConst::F32(v) => builder.constant_bit32(ty, v), + SpirvConst::U64(v) | SpirvConst::F64(v) => builder.constant_bit64(ty, v), SpirvConst::Bool(v) => { if v { builder.constant_true(ty) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index 83b60b4fdd..7efff374a7 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -35,12 +35,12 @@ impl<'tcx> CodegenCx<'tcx> { self.def_constant(ty, SpirvConst::U32(val as u32)) } - pub fn constant_u32(&self, span: Span, val: u32) -> SpirvValue { + pub fn constant_bit32(&self, span: Span, val: u32) -> SpirvValue { let ty = SpirvType::Integer(32, false).def(span, self); self.def_constant(ty, SpirvConst::U32(val)) } - pub fn constant_u64(&self, span: Span, val: u64) -> SpirvValue { + pub fn constant_bit64(&self, span: Span, val: u64) -> SpirvValue { let ty = SpirvType::Integer(64, false).def(span, self); self.def_constant(ty, SpirvConst::U64(val)) } @@ -163,10 +163,10 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { self.constant_i32(DUMMY_SP, i) } fn const_u32(&self, i: u32) -> Self::Value { - self.constant_u32(DUMMY_SP, i) + self.constant_bit32(DUMMY_SP, i) } fn const_u64(&self, i: u64) -> Self::Value { - self.constant_u64(DUMMY_SP, i) + self.constant_bit64(DUMMY_SP, i) } fn const_u128(&self, i: u128) -> Self::Value { let ty = SpirvType::Integer(128, false).def(DUMMY_SP, self); @@ -336,7 +336,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { self.tcx .sess .fatal("Non-zero scalar_to_backend ptr.offset not supported") - // let offset = self.constant_u64(ptr.offset.bytes()); + // let offset = self.constant_bit64(ptr.offset.bytes()); // self.gep(base_addr, once(offset)) }; if let Primitive::Pointer(_) = layout.primitive() { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 04e8268568..1fad1e4090 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -449,11 +449,12 @@ impl<'tcx> CodegenCx<'tcx> { ), Err(SpecConstant { id, default }) => { let mut emit = self.emit_global(); - let spec_const_id = emit.spec_constant_u32(value_spirv_type, default.unwrap_or(0)); + let spec_const_id = + emit.spec_constant_bit32(value_spirv_type, default.unwrap_or(0)); emit.decorate( spec_const_id, Decoration::SpecId, - [Operand::LiteralInt32(id)], + [Operand::LiteralBit32(id)], ); ( Err("`#[spirv(spec_constant)]` is not an entry-point interface variable"), @@ -669,7 +670,7 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global().decorate( var_id.unwrap(), Decoration::DescriptorSet, - std::iter::once(Operand::LiteralInt32(descriptor_set.value)), + std::iter::once(Operand::LiteralBit32(descriptor_set.value)), ); decoration_supersedes_location = true; } @@ -683,7 +684,7 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global().decorate( var_id.unwrap(), Decoration::Binding, - std::iter::once(Operand::LiteralInt32(binding.value)), + std::iter::once(Operand::LiteralBit32(binding.value)), ); decoration_supersedes_location = true; } @@ -707,6 +708,27 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global() .decorate(var_id.unwrap(), Decoration::Invariant, std::iter::empty()); } + if let Some(per_primitive_ext) = attrs.per_primitive_ext { + if storage_class != Ok(StorageClass::Output) { + self.tcx.sess.span_fatal( + per_primitive_ext.span, + "`#[spirv(per_primitive_ext)]` is only valid on Output variables", + ); + } + if !(execution_model == ExecutionModel::MeshEXT + || execution_model == ExecutionModel::MeshNV) + { + self.tcx.sess.span_fatal( + per_primitive_ext.span, + "`#[spirv(per_primitive_ext)]` is only valid in mesh shaders", + ); + } + self.emit_global().decorate( + var_id.unwrap(), + Decoration::PerPrimitiveEXT, + std::iter::empty(), + ); + } let is_subpass_input = match self.lookup_type(value_spirv_type) { SpirvType::Image { @@ -728,7 +750,7 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global().decorate( var_id.unwrap(), Decoration::InputAttachmentIndex, - std::iter::once(Operand::LiteralInt32(attachment_index.value)), + std::iter::once(Operand::LiteralBit32(attachment_index.value)), ); } else if is_subpass_input { self.tcx @@ -775,7 +797,7 @@ impl<'tcx> CodegenCx<'tcx> { self.emit_global().decorate( var_id.unwrap(), Decoration::Location, - std::iter::once(Operand::LiteralInt32(*location)), + std::iter::once(Operand::LiteralBit32(*location)), ); *location += 1; } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs index 6effea59b9..089a90a02f 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/type_.rs @@ -174,7 +174,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> { fn type_array(&self, ty: Self::Type, len: u64) -> Self::Type { SpirvType::Array { element: ty, - count: self.constant_u64(DUMMY_SP, len), + count: self.constant_bit64(DUMMY_SP, len), } .def(DUMMY_SP, self) } diff --git a/crates/rustc_codegen_spirv/src/custom_decorations.rs b/crates/rustc_codegen_spirv/src/custom_decorations.rs index 90844487d6..ab385a9f87 100644 --- a/crates/rustc_codegen_spirv/src/custom_decorations.rs +++ b/crates/rustc_codegen_spirv/src/custom_decorations.rs @@ -379,8 +379,8 @@ impl<'a> SpanRegenerator<'a> { let (file_id, line_start, line_end, col_start, col_end) = match inst.class.opcode { Op::Line => { let file = inst.operands[0].unwrap_id_ref(); - let line = inst.operands[1].unwrap_literal_int32(); - let col = inst.operands[2].unwrap_literal_int32(); + let line = inst.operands[1].unwrap_literal_bit32(); + let col = inst.operands[2].unwrap_literal_bit32(); (file, line, line, col, col) } Op::ExtInst @@ -397,7 +397,7 @@ impl<'a> SpanRegenerator<'a> { } => { let const_u32 = |operand: Operand| { spv_debug_info.id_to_op_constant_operand[&operand.unwrap_id_ref()] - .unwrap_literal_int32() + .unwrap_literal_bit32() }; ( file.unwrap_id_ref(), diff --git a/crates/rustc_codegen_spirv/src/linker/destructure_composites.rs b/crates/rustc_codegen_spirv/src/linker/destructure_composites.rs index 287537568e..9384eecf60 100644 --- a/crates/rustc_codegen_spirv/src/linker/destructure_composites.rs +++ b/crates/rustc_codegen_spirv/src/linker/destructure_composites.rs @@ -23,13 +23,13 @@ pub fn destructure_composites(function: &mut Function) { for inst in function.all_inst_iter_mut() { if inst.class.opcode == Op::CompositeExtract && inst.operands.len() == 2 { let mut composite = inst.operands[0].unwrap_id_ref(); - let index = inst.operands[1].unwrap_literal_int32(); + let index = inst.operands[1].unwrap_literal_bit32(); let origin = loop { if let Some(inst) = reference.get(&composite) { match inst.class.opcode { Op::CompositeInsert => { - let insert_index = inst.operands[2].unwrap_literal_int32(); + let insert_index = inst.operands[2].unwrap_literal_bit32(); if insert_index == index { break Some(inst.operands[0].unwrap_id_ref()); } diff --git a/crates/rustc_codegen_spirv/src/linker/duplicates.rs b/crates/rustc_codegen_spirv/src/linker/duplicates.rs index 3ca05022f0..692f292036 100644 --- a/crates/rustc_codegen_spirv/src/linker/duplicates.rs +++ b/crates/rustc_codegen_spirv/src/linker/duplicates.rs @@ -267,7 +267,7 @@ pub fn remove_duplicate_types(module: &mut Module) { && (inst.class.opcode != Op::MemberName || member_name_ids.insert(( inst.operands[0].unwrap_id_ref(), - inst.operands[1].unwrap_literal_int32(), + inst.operands[1].unwrap_literal_bit32(), ))) }); } diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs index 2f93e2b340..17042e1e27 100644 --- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs +++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs @@ -517,7 +517,7 @@ impl Renamer<'_> { ); let mut operands = vec![Operand::IdRef(val), Operand::IdRef(prev_comp)]; operands - .extend(var_info.indices.iter().copied().map(Operand::LiteralInt32)); + .extend(var_info.indices.iter().copied().map(Operand::LiteralBit32)); *inst = Instruction::new( Op::CompositeInsert, Some(self.base_var_type), @@ -545,7 +545,7 @@ impl Renamer<'_> { let new_id = id(self.header); let mut operands = vec![Operand::IdRef(current_obj)]; operands - .extend(var_info.indices.iter().copied().map(Operand::LiteralInt32)); + .extend(var_info.indices.iter().copied().map(Operand::LiteralBit32)); *inst = Instruction::new( Op::CompositeExtract, Some(var_info.ty), diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index 3606779ee9..52f4a5450b 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -301,14 +301,14 @@ pub fn link( .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref()); } Op::TypeInt - if inst.operands[0].unwrap_literal_int32() == 32 - && inst.operands[1].unwrap_literal_int32() == 0 => + if inst.operands[0].unwrap_literal_bit32() == 32 + && inst.operands[1].unwrap_literal_bit32() == 0 => { assert!(u32.is_none()); u32 = Some(inst.result_id.unwrap()); } Op::Constant if u32.is_some() && inst.result_type == u32 => { - let value = inst.operands[0].unwrap_literal_int32(); + let value = inst.operands[0].unwrap_literal_bit32(); constants.insert(inst.result_id.unwrap(), value); } _ => {} @@ -357,14 +357,14 @@ pub fn link( .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref()); } Op::TypeInt - if inst.operands[0].unwrap_literal_int32() == 32 - && inst.operands[1].unwrap_literal_int32() == 0 => + if inst.operands[0].unwrap_literal_bit32() == 32 + && inst.operands[1].unwrap_literal_bit32() == 0 => { assert!(u32.is_none()); u32 = Some(inst.result_id.unwrap()); } Op::Constant if u32.is_some() && inst.result_type == u32 => { - let value = inst.operands[0].unwrap_literal_int32(); + let value = inst.operands[0].unwrap_literal_bit32(); constants.insert(inst.result_id.unwrap(), value); } _ => {} diff --git a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs index 668f6f50e9..47a314731c 100644 --- a/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs +++ b/crates/rustc_codegen_spirv/src/linker/peephole_opts.rs @@ -16,7 +16,7 @@ fn composite_count(types: &FxHashMap, ty_id: Word) -> Option< let ty = types.get(&ty_id)?; match ty.class.opcode { Op::TypeStruct => Some(ty.operands.len()), - Op::TypeVector => Some(ty.operands[1].unwrap_literal_int32() as usize), + Op::TypeVector => Some(ty.operands[1].unwrap_literal_bit32() as usize), Op::TypeArray => { let length_id = ty.operands[1].unwrap_id_ref(); let const_inst = types.get(&length_id)?; @@ -28,8 +28,8 @@ fn composite_count(types: &FxHashMap, ty_id: Word) -> Option< return None; } let const_value = match const_inst.operands[0] { - Operand::LiteralInt32(v) => v as usize, - Operand::LiteralInt64(v) => v as usize, + Operand::LiteralBit32(v) => v as usize, + Operand::LiteralBit64(v) => v as usize, _ => bug!(), }; Some(const_value) @@ -67,7 +67,7 @@ pub fn composite_construct(types: &FxHashMap, function: &mut break; } let value = cur_inst.operands[0].unwrap_id_ref(); - let index = cur_inst.operands[2].unwrap_literal_int32() as usize; + let index = cur_inst.operands[2].unwrap_literal_bit32() as usize; if index >= components.len() { // Theoretically shouldn't happen, as it's invalid SPIR-V if the index is out // of bounds, but just stop optimizing instead of panicing here. @@ -131,7 +131,7 @@ fn get_composite_and_index( return None; } let composite = inst.operands[0].unwrap_id_ref(); - let index = inst.operands[1].unwrap_literal_int32(); + let index = inst.operands[1].unwrap_literal_bit32(); let composite_def = defs.get(&composite).or_else(|| types.get(&composite))?; let vector_def = types.get(&composite_def.result_type.unwrap())?; @@ -140,7 +140,7 @@ fn get_composite_and_index( // Width mismatch would be doing something like `vec2(a.x + b.x, a.y + b.y)` where `a` is a // vec4 - if we optimized it to just `a + b`, it'd be incorrect. if vector_def.class.opcode != Op::TypeVector - || vector_width != vector_def.operands[1].unwrap_literal_int32() + || vector_width != vector_def.operands[1].unwrap_literal_bit32() { return None; } @@ -330,7 +330,7 @@ fn process_instruction( if vector_ty_inst.class.opcode != Op::TypeVector { return None; } - let vector_width = vector_ty_inst.operands[1].unwrap_literal_int32(); + let vector_width = vector_ty_inst.operands[1].unwrap_literal_bit32(); // `results` is the defining instruction for each scalar component of the final result. let results = match inst .operands @@ -463,7 +463,7 @@ fn can_fuse_bool( return None; } match inst.operands[0] { - Operand::LiteralInt32(v) => Some(v), + Operand::LiteralBit32(v) => Some(v), _ => None, } } diff --git a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs index 7ca44add13..a828e16e5c 100644 --- a/crates/rustc_codegen_spirv/src/linker/simple_passes.rs +++ b/crates/rustc_codegen_spirv/src/linker/simple_passes.rs @@ -94,7 +94,8 @@ pub fn outgoing_edges(block: &Block) -> impl Iterator + '_ { | Op::Kill | Op::Unreachable | Op::IgnoreIntersectionKHR - | Op::TerminateRayKHR => (0..0).step_by(1), + | Op::TerminateRayKHR + | Op::EmitMeshTasksEXT => (0..0).step_by(1), _ => panic!("Invalid block terminator: {terminator:?}"), }; operand_indices.map(move |i| terminator.operands[i].unwrap_id_ref()) diff --git a/crates/rustc_codegen_spirv/src/linker/specializer.rs b/crates/rustc_codegen_spirv/src/linker/specializer.rs index 3fb003b4b1..658ceeeeeb 100644 --- a/crates/rustc_codegen_spirv/src/linker/specializer.rs +++ b/crates/rustc_codegen_spirv/src/linker/specializer.rs @@ -542,7 +542,7 @@ struct Specializer { // FIXME(eddyb) compact SPIR-V IDs to allow flatter maps. generics: IndexMap, - /// Integer `OpConstant`s (i.e. containing a `LiteralInt32`), to be used + /// Integer `OpConstant`s (i.e. containing a `LiteralBit32`), to be used /// for interpreting `TyPat::IndexComposite` (such as for `OpAccessChain`). int_consts: FxHashMap, } @@ -599,7 +599,7 @@ impl Specializer { // Record all integer `OpConstant`s (used for `IndexComposite`). if inst.class.opcode == Op::Constant { - if let Operand::LiteralInt32(x) = inst.operands[0] { + if let Operand::LiteralBit32(x) = inst.operands[0] { self.int_consts.insert(result_id, x); } } @@ -1190,7 +1190,7 @@ impl<'a> Match<'a> { // known `OpConstant`s (e.g. struct field indices). let maybe_idx = match operand { Operand::IdRef(id) => cx.specializer.int_consts.get(id), - Operand::LiteralInt32(idx) => Some(idx), + Operand::LiteralBit32(idx) => Some(idx), _ => None, }; match maybe_idx { @@ -1315,7 +1315,7 @@ impl<'a, S: Specialization> InferCx<'a, S> { TyPat::Array(pat) => simple(Op::TypeArray, pat), TyPat::Vector(pat) => simple(Op::TypeVector, pat), TyPat::Vector4(pat) => match ty_operands.operands { - [_, Operand::LiteralInt32(4)] => simple(Op::TypeVector, pat), + [_, Operand::LiteralBit32(4)] => simple(Op::TypeVector, pat), _ => Err(Unapplicable), }, TyPat::Matrix(pat) => simple(Op::TypeMatrix, pat), @@ -1722,7 +1722,7 @@ impl<'a, S: Specialization> InferCx<'a, S> { unreachable!("non-constant `OpTypeStruct` field index {}", id); }) } - &Operand::LiteralInt32(i) => i, + &Operand::LiteralBit32(i) => i, _ => { unreachable!("invalid `OpTypeStruct` field index operand {:?}", idx); } diff --git a/crates/rustc_codegen_spirv/src/linker/test.rs b/crates/rustc_codegen_spirv/src/linker/test.rs index 7e92b89f47..c32245db16 100644 --- a/crates/rustc_codegen_spirv/src/linker/test.rs +++ b/crates/rustc_codegen_spirv/src/linker/test.rs @@ -275,7 +275,7 @@ fn standard() { OpMemoryModel Logical OpenCL OpDecorate %1 LinkageAttributes "foo" Export %2 = OpTypeFloat 32 - %3 = OpConstant %2 42.0 + %3 = OpConstant %2 42 %1 = OpVariable %2 Uniform %3"#; without_header_eq(result, expect); diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 73cfdc662f..2253095163 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -162,7 +162,7 @@ impl SpirvType<'_> { result, index as u32, Decoration::Offset, - [Operand::LiteralInt32(offset.bytes() as u32)] + [Operand::LiteralBit32(offset.bytes() as u32)] .iter() .cloned(), ); @@ -188,7 +188,7 @@ impl SpirvType<'_> { emit.decorate( result, Decoration::ArrayStride, - iter::once(Operand::LiteralInt32(element_size as u32)), + iter::once(Operand::LiteralBit32(element_size as u32)), ); result } @@ -204,7 +204,7 @@ impl SpirvType<'_> { emit.decorate( result, Decoration::ArrayStride, - iter::once(Operand::LiteralInt32(element_size as u32)), + iter::once(Operand::LiteralBit32(element_size as u32)), ); result } @@ -264,7 +264,7 @@ impl SpirvType<'_> { result, 0, Decoration::Offset, - [Operand::LiteralInt32(0)].iter().cloned(), + [Operand::LiteralBit32(0)].iter().cloned(), ); result } diff --git a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs index b4e180f913..17d4fb2e7c 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type_constraints.rs @@ -494,6 +494,12 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { Op::IAddCarry | Op::ISubBorrow | Op::UMulExtended | Op::SMulExtended => sig! { (T, T) -> Struct([T, T]) }, + Op::SDot | Op::UDot | Op::SUDot | Op::SDotAccSat | Op::UDotAccSat | Op::SUDotAccSat => { + sig! { + // FIXME(eddyb) missing equality constraint between two vectors + (Vector(T), T) -> Vector(T) + } + } // 3.37.14. Bit Instructions Op::ShiftRightLogical @@ -593,6 +599,8 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { }, // Capability: Kernel Op::AtomicFlagTestAndSet | Op::AtomicFlagClear => {} + // SPV_EXT_shader_atomic_float_min_max + Op::AtomicFMinEXT | Op::AtomicFMaxEXT => sig! { (Pointer(_, T), _, _, T) -> T }, // SPV_EXT_shader_atomic_float_add Op::AtomicFAddEXT => sig! { (Pointer(_, T), _, _, T) -> T }, @@ -821,7 +829,7 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { // Instructions not present in current SPIR-V specification // SPV_INTEL_function_pointers - Op::FunctionPointerINTEL | Op::FunctionPointerCallINTEL => { + Op::ConstantFunctionPointerINTEL | Op::FunctionPointerCallINTEL => { reserved!(SPV_INTEL_function_pointers); } // SPV_INTEL_device_side_avc_motion_estimation @@ -945,6 +953,182 @@ pub fn instruction_signatures(op: Op) -> Option<&'static [InstSig<'static>]> { | Op::SubgroupAvcSicGetInterRawSadsINTEL => { reserved!(SPV_INTEL_device_side_avc_motion_estimation); } + // SPV_EXT_mesh_shader + Op::EmitMeshTasksEXT | Op::SetMeshOutputsEXT => { + // NOTE(eddyb) we actually use these despite not being in the standard yet. + // reserved!(SPV_EXT_mesh_shader) + } + // SPV_NV_ray_tracing_motion_blur + Op::TraceMotionNV | Op::TraceRayMotionNV => reserved!(SPV_NV_ray_tracing_motion_blur), + // SPV_NV_bindless_texture + Op::ConvertUToImageNV + | Op::ConvertUToSamplerNV + | Op::ConvertImageToUNV + | Op::ConvertSamplerToUNV + | Op::ConvertUToSampledImageNV + | Op::ConvertSampledImageToUNV + | Op::SamplerImageAddressingModeNV => reserved!(SPV_NV_bindless_texture), + // SPV_INTEL_inline_assembly + Op::AsmTargetINTEL | Op::AsmINTEL | Op::AsmCallINTEL => reserved!(SPV_NV_bindless_texture), + // SPV_INTEL_variable_length_array + Op::VariableLengthArrayINTEL => reserved!(SPV_INTEL_variable_length_array), + // SPV_KHR_uniform_group_instructions + Op::GroupIMulKHR + | Op::GroupFMulKHR + | Op::GroupBitwiseAndKHR + | Op::GroupBitwiseOrKHR + | Op::GroupBitwiseXorKHR + | Op::GroupLogicalAndKHR + | Op::GroupLogicalOrKHR + | Op::GroupLogicalXorKHR => reserved!(SPV_KHR_uniform_group_instructions), + // SPV_KHR_expect_assume + Op::AssumeTrueKHR | Op::ExpectKHR => reserved!(SPV_KHR_expect_assume), + // SPV_KHR_subgroup_rotate + Op::GroupNonUniformRotateKHR => reserved!(SPV_KHR_subgroup_rotate), + // SPV_NV_shader_invocation_reorder + Op::HitObjectRecordHitMotionNV + | Op::HitObjectRecordHitWithIndexMotionNV + | Op::HitObjectRecordMissMotionNV + | Op::HitObjectGetWorldToObjectNV + | Op::HitObjectGetObjectToWorldNV + | Op::HitObjectGetObjectRayDirectionNV + | Op::HitObjectGetObjectRayOriginNV + | Op::HitObjectTraceRayMotionNV + | Op::HitObjectGetShaderRecordBufferHandleNV + | Op::HitObjectGetShaderBindingTableRecordIndexNV + | Op::HitObjectRecordEmptyNV + | Op::HitObjectTraceRayNV + | Op::HitObjectRecordHitNV + | Op::HitObjectRecordHitWithIndexNV + | Op::HitObjectRecordMissNV + | Op::HitObjectExecuteShaderNV + | Op::HitObjectGetCurrentTimeNV + | Op::HitObjectGetAttributesNV + | Op::HitObjectGetHitKindNV + | Op::HitObjectGetPrimitiveIndexNV + | Op::HitObjectGetGeometryIndexNV + | Op::HitObjectGetInstanceIdNV + | Op::HitObjectGetInstanceCustomIndexNV + | Op::HitObjectGetWorldRayDirectionNV + | Op::HitObjectGetWorldRayOriginNV + | Op::HitObjectGetRayTMaxNV + | Op::HitObjectGetRayTMinNV + | Op::HitObjectIsEmptyNV + | Op::HitObjectIsHitNV + | Op::HitObjectIsMissNV + | Op::ReorderThreadWithHitObjectNV + | Op::ReorderThreadWithHintNV + | Op::TypeHitObjectNV => reserved!(SPV_NV_shader_invocation_reorder), + // SPV_INTEL_arbitrary_precision_floating_point + Op::ArbitraryFloatAddINTEL + | Op::ArbitraryFloatSubINTEL + | Op::ArbitraryFloatMulINTEL + | Op::ArbitraryFloatDivINTEL + | Op::ArbitraryFloatGTINTEL + | Op::ArbitraryFloatGEINTEL + | Op::ArbitraryFloatLTINTEL + | Op::ArbitraryFloatLEINTEL + | Op::ArbitraryFloatEQINTEL + | Op::ArbitraryFloatRecipINTEL + | Op::ArbitraryFloatRSqrtINTEL + | Op::ArbitraryFloatCbrtINTEL + | Op::ArbitraryFloatHypotINTEL + | Op::ArbitraryFloatSqrtINTEL + | Op::ArbitraryFloatLogINTEL + | Op::ArbitraryFloatLog2INTEL + | Op::ArbitraryFloatLog10INTEL + | Op::ArbitraryFloatLog1pINTEL + | Op::ArbitraryFloatExpINTEL + | Op::ArbitraryFloatExp2INTEL + | Op::ArbitraryFloatExp10INTEL + | Op::ArbitraryFloatExpm1INTEL + | Op::ArbitraryFloatSinINTEL + | Op::ArbitraryFloatCosINTEL + | Op::ArbitraryFloatSinCosINTEL + | Op::ArbitraryFloatSinPiINTEL + | Op::ArbitraryFloatCosPiINTEL + | Op::ArbitraryFloatSinCosPiINTEL + | Op::ArbitraryFloatASinINTEL + | Op::ArbitraryFloatASinPiINTEL + | Op::ArbitraryFloatACosINTEL + | Op::ArbitraryFloatACosPiINTEL + | Op::ArbitraryFloatATanINTEL + | Op::ArbitraryFloatATanPiINTEL + | Op::ArbitraryFloatATan2INTEL + | Op::ArbitraryFloatPowINTEL + | Op::ArbitraryFloatPowRINTEL + | Op::ArbitraryFloatPowNINTEL => { + reserved!(SPV_INTEL_arbitrary_precision_floating_point) + } + // TODO these instructions are outdated, and will be replaced by the ones in comments below. When updating, consider merging with the branch above. + Op::ArbitraryFloatCastINTEL + | Op::ArbitraryFloatCastFromIntINTEL + | Op::ArbitraryFloatCastToIntINTEL => { + // Op::ArbitraryFloatConvertINTEL + // | Op::ArbitraryFloatConvertFromUIntINTEL + // | Op::ArbitraryFloatConvertFromSIntINTEL + // | Op::ArbitraryFloatConvertToUIntINTEL + // | Op::ArbitraryFloatConvertToSIntINTEL + reserved!(SPV_INTEL_arbitrary_precision_floating_point) + } + // SPV_INTEL_arbitrary_precision_fixed_point + Op::FixedSqrtINTEL + | Op::FixedRecipINTEL + | Op::FixedRsqrtINTEL + | Op::FixedSinINTEL + | Op::FixedCosINTEL + | Op::FixedSinCosINTEL + | Op::FixedSinPiINTEL + | Op::FixedCosPiINTEL + | Op::FixedSinCosPiINTEL + | Op::FixedLogINTEL + | Op::FixedExpINTEL => reserved!(SPV_INTEL_arbitrary_precision_fixed_point), + // SPV_EXT_shader_tile_image + Op::ColorAttachmentReadEXT | Op::DepthAttachmentReadEXT | Op::StencilAttachmentReadEXT => { + reserved!(SPV_EXT_shader_tile_image) + } + // SPV_KHR_cooperative_matrix + Op::TypeCooperativeMatrixKHR + | Op::CooperativeMatrixLoadKHR + | Op::CooperativeMatrixStoreKHR + | Op::CooperativeMatrixMulAddKHR + | Op::CooperativeMatrixLengthKHR => reserved!(SPV_KHR_cooperative_matrix), + // SPV_QCOM_image_processing + Op::ImageSampleWeightedQCOM + | Op::ImageBoxFilterQCOM + | Op::ImageBlockMatchSSDQCOM + | Op::ImageBlockMatchSADQCOM => reserved!(SPV_QCOM_image_processing), + // SPV_AMDX_shader_enqueue + Op::FinalizeNodePayloadsAMDX + | Op::FinishWritingNodePayloadAMDX + | Op::InitializeNodePayloadsAMDX => reserved!(SPV_AMDX_shader_enqueue), + // SPV_NV_displacement_micromap + Op::FetchMicroTriangleVertexPositionNV | Op::FetchMicroTriangleVertexBarycentricNV => { + reserved!(SPV_NV_displacement_micromap) + } + // SPV_KHR_ray_tracing_position_fetch + Op::RayQueryGetIntersectionTriangleVertexPositionsKHR => { + reserved!(SPV_KHR_ray_tracing_position_fetch) + } + // SPV_INTEL_bfloat16_conversion + Op::ConvertFToBF16INTEL | Op::ConvertBF16ToFINTEL => { + reserved!(SPV_INTEL_bfloat16_conversion) + } + + // TODO unknown_extension_INTEL + Op::SaveMemoryINTEL + | Op::RestoreMemoryINTEL + | Op::AliasDomainDeclINTEL + | Op::AliasScopeDeclINTEL + | Op::AliasScopeListDeclINTEL + | Op::PtrCastToCrossWorkgroupINTEL + | Op::CrossWorkgroupCastToPtrINTEL + | Op::TypeBufferSurfaceINTEL + | Op::TypeStructContinuedINTEL + | Op::ConstantCompositeContinuedINTEL + | Op::SpecConstantCompositeContinuedINTEL + | Op::ControlBarrierArriveINTEL + | Op::ControlBarrierWaitINTEL => reserved!(unknown_extension_INTEL), } None diff --git a/crates/rustc_codegen_spirv/src/symbols.rs b/crates/rustc_codegen_spirv/src/symbols.rs index 4129187b39..1aadf02ee0 100644 --- a/crates/rustc_codegen_spirv/src/symbols.rs +++ b/crates/rustc_codegen_spirv/src/symbols.rs @@ -118,8 +118,17 @@ const BUILTINS: &[(&str, BuiltIn)] = { ("layer_per_view_nv", LayerPerViewNV), ("mesh_view_count_nv", MeshViewCountNV), ("mesh_view_indices_nv", MeshViewIndicesNV), - ("bary_coord_nv", BaryCoordNV), - ("bary_coord_no_persp_nv", BaryCoordNoPerspNV), + ("bary_coord_nv", BuiltIn::BaryCoordNV), + ("bary_coord_no_persp_nv", BuiltIn::BaryCoordNoPerspNV), + ("bary_coord", BaryCoordKHR), + ("bary_coord_no_persp", BaryCoordNoPerspKHR), + ("primitive_point_indices_ext", PrimitivePointIndicesEXT), + ("primitive_line_indices_ext", PrimitiveLineIndicesEXT), + ( + "primitive_triangle_indices_ext", + PrimitiveTriangleIndicesEXT, + ), + ("cull_primitive_ext", CullPrimitiveEXT), ("frag_size_ext", FragSizeEXT), ("frag_invocation_count_ext", FragInvocationCountEXT), ("launch_id", BuiltIn::LaunchIdKHR), @@ -169,6 +178,7 @@ const STORAGE_CLASSES: &[(&str, StorageClass)] = { ("incoming_ray_payload", StorageClass::IncomingRayPayloadKHR), ("shader_record_buffer", StorageClass::ShaderRecordBufferKHR), ("physical_storage_buffer", PhysicalStorageBuffer), + ("task_payload_workgroup_ext", TaskPayloadWorkgroupEXT), ] }; @@ -183,6 +193,8 @@ const EXECUTION_MODELS: &[(&str, ExecutionModel)] = { ("compute", GLCompute), ("task_nv", TaskNV), ("mesh_nv", MeshNV), + ("task_ext", TaskEXT), + ("mesh_ext", MeshEXT), ("ray_generation", ExecutionModel::RayGenerationKHR), ("intersection", ExecutionModel::IntersectionKHR), ("any_hit", ExecutionModel::AnyHitKHR), @@ -263,6 +275,17 @@ const EXECUTION_MODES: &[(&str, ExecutionMode, ExecutionModeExtraDim)] = { ("output_primitives_nv", OutputPrimitivesNV, Value), ("derivative_group_quads_nv", DerivativeGroupQuadsNV, None), ("output_triangles_nv", OutputTrianglesNV, None), + ("output_lines_ext", ExecutionMode::OutputLinesEXT, None), + ( + "output_triangles_ext", + ExecutionMode::OutputTrianglesEXT, + None, + ), + ( + "output_primitives_ext", + ExecutionMode::OutputPrimitivesEXT, + Value, + ), ( "pixel_interlock_ordered_ext", PixelInterlockOrderedEXT, @@ -334,6 +357,7 @@ impl Symbols { ("block", SpirvAttribute::Block), ("flat", SpirvAttribute::Flat), ("invariant", SpirvAttribute::Invariant), + ("per_primitive_ext", SpirvAttribute::PerPrimitiveExt), ( "sampled_image", SpirvAttribute::IntrinsicType(IntrinsicType::SampledImage), @@ -715,7 +739,7 @@ fn parse_entry_attrs( .execution_modes .push((origin_mode, ExecutionModeExtra::new([]))); } - GLCompute | MeshNV | TaskNV => { + GLCompute | MeshNV | TaskNV | TaskEXT | MeshEXT => { if let Some(local_size) = local_size { entry .execution_modes @@ -724,7 +748,7 @@ fn parse_entry_attrs( return Err(( arg.span(), String::from( - "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]` or `#[spirv(task_nv)]`", + "The `threads` argument must be specified when using `#[spirv(compute)]`, `#[spirv(mesh_nv)]`, `#[spirv(task_nv)]`, `#[spirv(task_ext)]` or `#[spirv(mesh_ext)]`", ), )); } diff --git a/crates/spirv-std/src/arch.rs b/crates/spirv-std/src/arch.rs index 0fa43fea0f..07dbfc15ce 100644 --- a/crates/spirv-std/src/arch.rs +++ b/crates/spirv-std/src/arch.rs @@ -17,6 +17,7 @@ mod atomics; mod barrier; mod demote_to_helper_invocation_ext; mod derivative; +mod mesh_shading; mod primitive; mod ray_tracing; @@ -24,6 +25,7 @@ pub use atomics::*; pub use barrier::*; pub use demote_to_helper_invocation_ext::*; pub use derivative::*; +pub use mesh_shading::*; pub use primitive::*; pub use ray_tracing::*; diff --git a/crates/spirv-std/src/arch/mesh_shading.rs b/crates/spirv-std/src/arch/mesh_shading.rs new file mode 100644 index 0000000000..e3806ee86d --- /dev/null +++ b/crates/spirv-std/src/arch/mesh_shading.rs @@ -0,0 +1,109 @@ +#[cfg(target_arch = "spirv")] +use core::arch::asm; + +/// Sets the actual output size of the primitives and vertices that the mesh shader +/// workgroup will emit upon completion. +/// +/// 'Vertex Count' must be a 32-bit unsigned integer value. +/// It defines the array size of per-vertex outputs. +/// +/// 'Primitive Count' must a 32-bit unsigned integer value. +/// It defines the array size of per-primitive outputs. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction no more than once and under +/// uniform control flow. +/// There must not be any control flow path to an output write that is not preceded +/// by this instruction. +/// +/// This instruction is only valid in the *MeshEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpSetMeshOutputsEXT")] +#[inline] +pub unsafe fn set_mesh_outputs_ext(vertex_count: u32, primitive_count: u32) { + asm! { + "OpSetMeshOutputsEXT {vertex_count} {primitive_count}", + vertex_count = in(reg) vertex_count, + primitive_count = in(reg) primitive_count, + } +} + +/// Defines the grid size of subsequent mesh shader workgroups to generate +/// upon completion of the task shader workgroup. +/// +/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value. +/// They configure the number of local workgroups in each respective dimensions +/// for the launch of child mesh tasks. See Vulkan API specification for more detail. +/// +/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations. +/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction exactly once and under uniform +/// control flow. +/// This instruction also serves as an *OpControlBarrier* instruction, and also +/// performs and adheres to the description and semantics of an *OpControlBarrier* +/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and +/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and +/// *AcquireRelease*. +/// Ceases all further processing: Only instructions executed before +/// *OpEmitMeshTasksEXT* have observable side effects. +/// +/// This instruction must be the last instruction in a block. +/// +/// This instruction is only valid in the *TaskEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpEmitMeshTasksEXT")] +#[inline] +pub unsafe fn emit_mesh_tasks_ext(group_count_x: u32, group_count_y: u32, group_count_z: u32) -> ! { + asm! { + "OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z}", + group_count_x = in(reg) group_count_x, + group_count_y = in(reg) group_count_y, + group_count_z = in(reg) group_count_z, + options(noreturn), + } +} + +/// Defines the grid size of subsequent mesh shader workgroups to generate +/// upon completion of the task shader workgroup. +/// +/// 'Group Count X Y Z' must each be a 32-bit unsigned integer value. +/// They configure the number of local workgroups in each respective dimensions +/// for the launch of child mesh tasks. See Vulkan API specification for more detail. +/// +/// 'Payload' is an optional pointer to the payload structure to pass to the generated mesh shader invocations. +/// 'Payload' must be the result of an *OpVariable* with a storage class of *TaskPayloadWorkgroupEXT*. +/// +/// The arguments are taken from the first invocation in each workgroup. +/// Any invocation must execute this instruction exactly once and under uniform +/// control flow. +/// This instruction also serves as an *OpControlBarrier* instruction, and also +/// performs and adheres to the description and semantics of an *OpControlBarrier* +/// instruction with the 'Execution' and 'Memory' operands set to *Workgroup* and +/// the 'Semantics' operand set to a combination of *WorkgroupMemory* and +/// *AcquireRelease*. +/// Ceases all further processing: Only instructions executed before +/// *OpEmitMeshTasksEXT* have observable side effects. +/// +/// This instruction must be the last instruction in a block. +/// +/// This instruction is only valid in the *TaskEXT* Execution Model. +#[spirv_std_macros::gpu_only] +#[doc(alias = "OpEmitMeshTasksEXT")] +#[inline] +pub unsafe fn emit_mesh_tasks_ext_payload( + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, + payload: &mut T, +) -> ! { + asm! { + "OpEmitMeshTasksEXT {group_count_x} {group_count_y} {group_count_z} {payload}", + group_count_x = in(reg) group_count_x, + group_count_y = in(reg) group_count_y, + group_count_z = in(reg) group_count_z, + payload = in(reg) payload, + options(noreturn), + } +} diff --git a/tests/ui/arch/mesh_shader_output_lines.rs b/tests/ui/arch/mesh_shader_output_lines.rs new file mode 100644 index 0000000000..e2e83694c1 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_lines.rs @@ -0,0 +1,27 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec2, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 2, + output_primitives_ext = 1, + output_lines_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 2], + #[spirv(primitive_line_indices_ext)] indices: &mut [UVec2; 1], +) { + unsafe { + set_mesh_outputs_ext(2, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + + indices[0] = UVec2::new(0, 1); +} diff --git a/tests/ui/arch/mesh_shader_output_points.rs b/tests/ui/arch/mesh_shader_output_points.rs new file mode 100644 index 0000000000..1cc7662ca9 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_points.rs @@ -0,0 +1,26 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec2, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 1, + output_primitives_ext = 1, + output_points +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 1], + #[spirv(primitive_point_indices_ext)] indices: &mut [u32; 1], +) { + unsafe { + set_mesh_outputs_ext(1, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + + indices[0] = 0; +} diff --git a/tests/ui/arch/mesh_shader_output_triangles.rs b/tests/ui/arch/mesh_shader_output_triangles.rs new file mode 100644 index 0000000000..289f1faf62 --- /dev/null +++ b/tests/ui/arch/mesh_shader_output_triangles.rs @@ -0,0 +1,28 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0); + + indices[0] = UVec3::new(0, 1, 2); +} diff --git a/tests/ui/arch/mesh_shader_payload.rs b/tests/ui/arch/mesh_shader_payload.rs new file mode 100644 index 0000000000..3fdf11f683 --- /dev/null +++ b/tests/ui/arch/mesh_shader_payload.rs @@ -0,0 +1,35 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +pub struct Payload { + pub first: f32, + pub second: f32, + pub third: f32, +} + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], + #[spirv(task_payload_workgroup_ext)] payload: &Payload, +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = payload.first * Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = payload.second * Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = payload.third * Vec4::new(0.0, -0.5, 0.0, 1.0); + + indices[0] = UVec3::new(0, 1, 2); +} diff --git a/tests/ui/arch/mesh_shader_per_primitive.rs b/tests/ui/arch/mesh_shader_per_primitive.rs new file mode 100644 index 0000000000..29a060cbc0 --- /dev/null +++ b/tests/ui/arch/mesh_shader_per_primitive.rs @@ -0,0 +1,34 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::set_mesh_outputs_ext; +use spirv_std::glam::{UVec3, Vec4}; +use spirv_std::spirv; + +#[spirv(mesh_ext( + threads(1), + output_vertices = 3, + output_primitives_ext = 1, + output_triangles_ext +))] +pub fn main( + #[spirv(position)] positions: &mut [Vec4; 3], + out_per_vertex: &mut [u32; 3], + #[spirv(per_primitive_ext)] out_per_primitive: &mut [u32; 1], + #[spirv(primitive_triangle_indices_ext)] indices: &mut [UVec3; 1], +) { + unsafe { + set_mesh_outputs_ext(3, 1); + } + + positions[0] = Vec4::new(-0.5, 0.5, 0.0, 1.0); + positions[1] = Vec4::new(0.5, 0.5, 0.0, 1.0); + positions[2] = Vec4::new(0.0, -0.5, 0.0, 1.0); + out_per_vertex[0] = 0; + out_per_vertex[1] = 1; + out_per_vertex[2] = 2; + + indices[0] = UVec3::new(0, 1, 2); + out_per_primitive[0] = 42; +} diff --git a/tests/ui/arch/task_shader.rs b/tests/ui/arch/task_shader.rs new file mode 100644 index 0000000000..ea23516869 --- /dev/null +++ b/tests/ui/arch/task_shader.rs @@ -0,0 +1,13 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext; +use spirv_std::spirv; + +#[spirv(task_ext(threads(1)))] +pub fn main() { + unsafe { + emit_mesh_tasks_ext(1, 2, 3); + } +} diff --git a/tests/ui/arch/task_shader_mispile.rs b/tests/ui/arch/task_shader_mispile.rs new file mode 100644 index 0000000000..ec012789f6 --- /dev/null +++ b/tests/ui/arch/task_shader_mispile.rs @@ -0,0 +1,14 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext; +use spirv_std::spirv; + +#[spirv(task_ext(threads(1)))] +pub fn main(#[spirv(push_constant)] push: &u32) { + let count = 20 / *push; + unsafe { + emit_mesh_tasks_ext(1, 2, 3); + } +} diff --git a/tests/ui/arch/task_shader_payload.rs b/tests/ui/arch/task_shader_payload.rs new file mode 100644 index 0000000000..30f72a094a --- /dev/null +++ b/tests/ui/arch/task_shader_payload.rs @@ -0,0 +1,21 @@ +// build-pass +// only-vulkan1.2 +// compile-flags: -Ctarget-feature=+MeshShadingEXT,+ext:SPV_EXT_mesh_shader + +use spirv_std::arch::emit_mesh_tasks_ext_payload; +use spirv_std::spirv; + +pub struct Payload { + pub first: u32, + pub second: i32, +} + +#[spirv(task_ext(threads(1)))] +pub fn main(#[spirv(task_payload_workgroup_ext)] payload: &mut Payload) { + payload.first = 1; + payload.second = 2; + + unsafe { + emit_mesh_tasks_ext_payload(3, 4, 5, payload); + } +} diff --git a/tests/ui/spirv-attr/all-builtins.rs b/tests/ui/spirv-attr/all-builtins.rs index 3f79ae484e..8ad8a988a0 100644 --- a/tests/ui/spirv-attr/all-builtins.rs +++ b/tests/ui/spirv-attr/all-builtins.rs @@ -1,6 +1,6 @@ // build-pass // only-vulkan1.1 -// compile-flags: -Ctarget-feature=+DeviceGroup,+DrawParameters,+FragmentBarycentricNV,+FragmentDensityEXT,+FragmentFullyCoveredEXT,+Geometry,+GroupNonUniform,+GroupNonUniformBallot,+MeshShadingNV,+MultiView,+MultiViewport,+RayTracingKHR,+SampleRateShading,+ShaderSMBuiltinsNV,+ShaderStereoViewNV,+StencilExportEXT,+Tessellation,+ext:SPV_AMD_shader_explicit_vertex_parameter,+ext:SPV_EXT_fragment_fully_covered,+ext:SPV_EXT_fragment_invocation_density,+ext:SPV_EXT_shader_stencil_export,+ext:SPV_KHR_ray_tracing,+ext:SPV_NV_fragment_shader_barycentric,+ext:SPV_NV_mesh_shader,+ext:SPV_NV_shader_sm_builtins,+ext:SPV_NV_stereo_view_rendering +// compile-flags: -Ctarget-feature=+DeviceGroup,+DrawParameters,+FragmentBarycentricNV,+FragmentBarycentricKHR,+FragmentDensityEXT,+FragmentFullyCoveredEXT,+Geometry,+GroupNonUniform,+GroupNonUniformBallot,+MeshShadingNV,+MultiView,+MultiViewport,+RayTracingKHR,+SampleRateShading,+ShaderSMBuiltinsNV,+ShaderStereoViewNV,+StencilExportEXT,+Tessellation,+ext:SPV_AMD_shader_explicit_vertex_parameter,+ext:SPV_EXT_fragment_fully_covered,+ext:SPV_EXT_fragment_invocation_density,+ext:SPV_EXT_shader_stencil_export,+ext:SPV_KHR_ray_tracing,+ext:SPV_NV_fragment_shader_barycentric,+ext:SPV_NV_mesh_shader,+ext:SPV_NV_shader_sm_builtins,+ext:SPV_NV_stereo_view_rendering use spirv_std::glam::*; use spirv_std::spirv; @@ -82,10 +82,12 @@ pub fn vertex( #[spirv(fragment)] pub fn fragment( + #[spirv(bary_coord_no_persp)] bary_coord_no_persp: Vec3, #[spirv(bary_coord_no_persp_amd)] bary_coord_no_persp_amd: Vec3, #[spirv(bary_coord_no_persp_centroid_amd)] bary_coord_no_persp_centroid_amd: Vec3, #[spirv(bary_coord_no_persp_nv)] bary_coord_no_persp_nv: Vec3, #[spirv(bary_coord_no_persp_sample_amd)] bary_coord_no_persp_sample_amd: Vec3, + #[spirv(bary_coord)] bary_coord: Vec3, #[spirv(bary_coord_nv)] bary_coord_nv: Vec3, #[spirv(bary_coord_pull_model_amd)] bary_coord_pull_model_amd: Vec3, #[spirv(bary_coord_smooth_amd)] bary_coord_smooth_amd: Vec3,