Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
jimblandy committed Feb 27, 2025
1 parent 3229aab commit 921f696
Show file tree
Hide file tree
Showing 21 changed files with 3,411 additions and 3,481 deletions.
6 changes: 2 additions & 4 deletions naga/src/back/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1235,10 +1235,8 @@ impl BlockContext<'_> {
return Err(Error::Validation("Invalid image class"));
};
let scalar = format.into();
let pointer_type_id = self.get_type_id(LookupType::Local(LocalType::LocalPointer {
base: NumericType::Scalar(scalar),
class: spirv::StorageClass::Image,
}));
let scalar_type_id = self.get_numeric_type_id(NumericType::Scalar(scalar));
let pointer_type_id = self.get_pointer_type_id(scalar_type_id, spirv::StorageClass::Image);
let signed = scalar.kind == crate::ScalarKind::Sint;
if scalar.width == 8 {
self.writer
Expand Down
3 changes: 2 additions & 1 deletion naga/src/back/spv/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ impl BlockContext<'_> {
Some(index_id) => {
let element_type_id = match self.ir_module.types[global.ty].inner {
crate::TypeInner::BindingArray { base, size: _ } => {
let base_id = self.get_handle_type_id(base);
let class = map_storage_class(global.space);
self.get_pointer_type_id(base, class)
self.get_pointer_type_id(base_id, class)
}
_ => return Err(Error::Validation("array length expression case-5")),
};
Expand Down
72 changes: 7 additions & 65 deletions naga/src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -363,8 +363,10 @@ impl NumericType {
/// variant, is designed to help us deduplicate `OpTypeImage` instructions. See
/// its documentation for details.
///
/// `LocalType` also includes variants like `Pointer` that do not need to be
/// unique - but it is harmless to avoid the duplication.
/// SPIR-V does not require pointer types to be unique - but different
/// SPIR-V ids are considered to be distinct pointer types. Since Naga
/// uses structural type equality, we need to represent each Naga
/// equivalence class with a single SPIR-V `OpTypePointer`.
///
/// As it always must, the `Hash` implementation respects the `Eq` relation.
///
Expand All @@ -373,29 +375,15 @@ impl NumericType {
enum LocalType {
/// A numeric type.
Numeric(NumericType),
LocalPointer {
base: NumericType,
class: spirv::StorageClass,
},
Pointer {
base: Handle<crate::Type>,
base: Word,
class: spirv::StorageClass,
},
Image(LocalImageType),
SampledImage {
image_type_id: Word,
},
Sampler,
/// Equivalent to a [`LocalType::Pointer`] whose `base` is a Naga IR [`BindingArray`]. SPIR-V
/// permits duplicated `OpTypePointer` ids, so it's fine to have two different [`LocalType`]
/// representations for pointer types.
///
/// [`BindingArray`]: crate::TypeInner::BindingArray
PointerToBindingArray {
base: Handle<crate::Type>,
size: u32,
space: crate::AddressSpace,
},
BindingArray {
base: Handle<crate::Type>,
size: u32,
Expand Down Expand Up @@ -444,52 +432,6 @@ struct LookupFunctionType {
return_type_id: Word,
}

impl LocalType {
fn from_inner(inner: &crate::TypeInner) -> Option<Self> {
Some(match *inner {
crate::TypeInner::Scalar(_)
| crate::TypeInner::Atomic(_)
| crate::TypeInner::Vector { .. }
| crate::TypeInner::Matrix { .. } => {
// We expect `NumericType::from_inner` to handle all
// these cases, so unwrap.
LocalType::Numeric(NumericType::from_inner(inner).unwrap())
}
crate::TypeInner::Pointer { base, space } => LocalType::Pointer {
base,
class: helpers::map_storage_class(space),
},
crate::TypeInner::ValuePointer {
size: Some(size),
scalar,
space,
} => LocalType::LocalPointer {
base: NumericType::Vector { size, scalar },
class: helpers::map_storage_class(space),
},
crate::TypeInner::ValuePointer {
size: None,
scalar,
space,
} => LocalType::LocalPointer {
base: NumericType::Scalar(scalar),
class: helpers::map_storage_class(space),
},
crate::TypeInner::Image {
dim,
arrayed,
class,
} => LocalType::Image(LocalImageType::from_inner(dim, arrayed, class)),
crate::TypeInner::Sampler { comparison: _ } => LocalType::Sampler,
crate::TypeInner::AccelerationStructure => LocalType::AccelerationStructure,
crate::TypeInner::RayQuery => LocalType::RayQuery,
crate::TypeInner::Array { .. }
| crate::TypeInner::Struct { .. }
| crate::TypeInner::BindingArray { .. } => return None,
})
}
}

#[derive(Debug)]
enum Dimension {
Scalar,
Expand Down Expand Up @@ -760,10 +702,10 @@ impl BlockContext<'_> {

fn get_pointer_type_id(
&mut self,
handle: Handle<crate::Type>,
base: Word,
class: spirv::StorageClass,
) -> Word {
self.writer.get_pointer_type_id(handle, class)
self.writer.get_pointer_type_id(base, class)
}

fn get_numeric_type_id(&mut self, numeric: NumericType) -> Word {
Expand Down
47 changes: 6 additions & 41 deletions naga/src/back/spv/ray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,65 +21,30 @@ impl Writer {
let ray_intersection = ir_module.special_types.ray_intersection.unwrap();
let intersection_type_id = self.get_handle_type_id(ray_intersection);
let intersection_pointer_type_id =
self.get_pointer_type_id(ray_intersection, spirv::StorageClass::Function);
self.get_pointer_type_id(intersection_type_id, spirv::StorageClass::Function);

let flag_type_id = self.get_u32_type_id();
let flag_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::U32),
})
.unwrap();
let flag_pointer_type_id =
self.get_pointer_type_id(flag_type, spirv::StorageClass::Function);
self.get_pointer_type_id(flag_type_id, spirv::StorageClass::Function);

let transform_type_id = self.get_numeric_type_id(NumericType::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
});
let transform_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Matrix {
columns: crate::VectorSize::Quad,
rows: crate::VectorSize::Tri,
scalar: crate::Scalar::F32,
},
})
.unwrap();
let transform_pointer_type_id =
self.get_pointer_type_id(transform_type, spirv::StorageClass::Function);
self.get_pointer_type_id(transform_type_id, spirv::StorageClass::Function);

let barycentrics_type_id = self.get_numeric_type_id(NumericType::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
});
let barycentrics_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Vector {
size: crate::VectorSize::Bi,
scalar: crate::Scalar::F32,
},
})
.unwrap();
let barycentrics_pointer_type_id =
self.get_pointer_type_id(barycentrics_type, spirv::StorageClass::Function);
self.get_pointer_type_id(barycentrics_type_id, spirv::StorageClass::Function);

let bool_type_id = self.get_bool_type_id();
let bool_type = ir_module
.types
.get(&Type {
name: None,
inner: TypeInner::Scalar(crate::Scalar::BOOL),
})
.unwrap();
let bool_pointer_type_id =
self.get_pointer_type_id(bool_type, spirv::StorageClass::Function);
self.get_pointer_type_id(bool_type_id, spirv::StorageClass::Function);

let scalar_type_id = self.get_f32_type_id();
let float_pointer_type_id = self.get_f32_pointer_type_id(spirv::StorageClass::Function);
Expand All @@ -91,7 +56,7 @@ impl Writer {
inner: TypeInner::RayQuery,
})
.expect("ray_query type should have been populated by the variable passed into this!");
let argument_type_id = self.get_pointer_type_id(rq_ty, spirv::StorageClass::Function);
let argument_type_id = self.get_handle_pointer_type_id(rq_ty, spirv::StorageClass::Function);
let func_ty = self.get_function_type(LookupFunctionType {
parameter_type_ids: vec![argument_type_id],
return_type_id: intersection_type_id,
Expand Down
Loading

0 comments on commit 921f696

Please sign in to comment.