diff --git a/src/correlation.rs b/src/correlation.rs
index 3452d76..12a3f3e 100644
--- a/src/correlation.rs
+++ b/src/correlation.rs
@@ -4,9 +4,9 @@ mod metal;
 mod vk;
 
 #[cfg(target_os = "macos")]
-use metal as gpu;
+type GpuContext = metal::GpuContext;
 #[cfg(not(target_os = "macos"))]
-use vk as gpu;
+type GpuContext = vk::GpuContext;
 
 use crate::data::{Grid, Point2D};
 use nalgebra::{Matrix3, Vector3};
@@ -67,10 +67,11 @@ pub struct PointCorrelations {
     correlation_threshold: f32,
     corridor_extend_range: f64,
     fundamental_matrix: Matrix3<f64>,
-    gpu_context: Option<gpu::GpuContext>,
+    gpu_context: Option<GpuContext>,
     selected_hardware: String,
 }
 
+#[derive(Copy, Clone)]
 enum CorrelationDirection {
     Forward,
     Reverse,
@@ -141,7 +142,7 @@ impl PointCorrelations {
         let gpu_context = if matches!(hardware_mode, HardwareMode::Gpu | HardwareMode::GpuLowPower) {
             let low_power = matches!(hardware_mode, HardwareMode::GpuLowPower);
-            match gpu::GpuContext::new(
+            match GpuContext::new(
                 (img1_dimensions.0 as usize, img1_dimensions.1 as usize),
                 (img2_dimensions.0 as usize, img2_dimensions.1 as usize),
                 projection_mode,
@@ -224,8 +225,8 @@ impl PointCorrelations {
             CorrelationDirection::Reverse,
         )?;
 
-        self.cross_check_filter(scale, CorrelationDirection::Forward);
-        self.cross_check_filter(scale, CorrelationDirection::Reverse);
+        self.cross_check_filter(scale, CorrelationDirection::Forward)?;
+        self.cross_check_filter(scale, CorrelationDirection::Reverse)?;
 
         self.first_pass = false;
 
@@ -543,10 +544,14 @@ impl PointCorrelations {
             .floor() as usize
     }
 
-    fn cross_check_filter(&mut self, scale: f32, dir: CorrelationDirection) {
+    fn cross_check_filter(
+        &mut self,
+        scale: f32,
+        dir: CorrelationDirection,
+    ) -> Result<(), Box<dyn error::Error>> {
         if let Some(gpu_context) = &mut self.gpu_context {
-            gpu_context.cross_check_filter(scale, dir);
-            return;
+            gpu_context.cross_check_filter(scale, dir)?;
+            return Ok(());
         };
 
         let (correlated_points, correlated_points_reverse) = match dir {
@@ -572,6 +577,7 @@
                 }
             }
         });
+        Ok(())
     }
 
     #[inline]
diff --git a/src/correlation/cross_check_filter.comp.glsl b/src/correlation/cross_check_filter.comp.glsl
new file mode 100644
index 0000000..8f8fca5
--- /dev/null
+++ b/src/correlation/cross_check_filter.comp.glsl
@@ -0,0 +1,41 @@
+#version 450
+#pragma shader_stage(compute)
+
+layout(std430, push_constant) uniform readonly Parameters
+{
+    uint img1_width;
+    uint img1_height;
+    uint img2_width;
+    uint img2_height;
+    uint out_width;
+    uint out_height;
+    float scale;
+    uint iteration_pass;
+    mat3x3 fundamental_matrix;
+    int corridor_offset;
+    uint corridor_start;
+    uint corridor_end;
+    uint kernel_size;
+    float threshold;
+    float min_stdev;
+    uint neighbor_distance;
+    float extend_range;
+    float min_range;
+};
+layout(std430, set = 1, binding = 0) buffer Img1
+{
+    ivec2 img1[];
+};
+layout(std430, set = 1, binding = 1) buffer readonly Img2
+{
+    ivec2 img2[];
+};
+layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in;
+
+void main() {
+    const uint x = gl_GlobalInvocationID.x;
+    const uint y = gl_GlobalInvocationID.y;
+
+    // TODO: remove this debug code
+    img1[0] = ivec2(img1_width, img1_height);
+}
diff --git a/src/correlation/correlation.comp.glsl b/src/correlation/cross_correlate.comp.glsl
similarity index 100%
rename from src/correlation/correlation.comp.glsl
rename to src/correlation/cross_correlate.comp.glsl
diff --git a/src/correlation/init_out_data.comp.glsl b/src/correlation/init_out_data.comp.glsl
new file mode 100644
index 0000000..138da8d
--- /dev/null
+++ b/src/correlation/init_out_data.comp.glsl
@@ -0,0 +1,72 @@ +#version 450 +#pragma shader_stage(compute) + +layout(std430, push_constant) uniform readonly Parameters +{ + uint img1_width; + uint img1_height; + uint img2_width; + uint img2_height; + uint out_width; + uint out_height; + float scale; + uint iteration_pass; + mat3x3 fundamental_matrix; + int corridor_offset; + uint corridor_start; + uint corridor_end; + uint kernel_size; + float threshold; + float min_stdev; + uint neighbor_distance; + float extend_range; + float min_range; +}; +layout(std430, set = 0, binding = 0) buffer readonly Images +{ + // Layout: + // img1; img2 + float images[]; +}; +layout(std430, set = 0, binding = 1) buffer Internals_Img1 +{ + // Layout: + // For searchdata: contains [min_corridor, stdev] for image1 + // For cross_correlate: contains [avg, stdev] for image1 + float internals_img1[]; +}; +layout(std430, set = 0, binding = 2) buffer Internals_Img2 +{ + // Layout: + // Contains [avg, stdev] for image 2 + float internals_img2[]; +}; +layout(std430, set = 0, binding = 3) buffer Internals_Int +{ + // Layout: + // Contains [min, max, neighbor_count] for the corridor range + int internals_int[]; +}; +layout(std430, set = 0, binding = 4) buffer Result_Matches +{ + ivec2 result_matches[]; +}; +layout(std430, set = 0, binding = 5) buffer Result_Corr +{ + float result_corr[]; +}; +layout (local_size_x = 16, local_size_y = 16, local_size_z = 1) in; + +void main() { + const uint x = gl_GlobalInvocationID.x; + const uint y = gl_GlobalInvocationID.y; + /* + if (x < out_width && y < out_height) { + result_matches[out_width*y+x] = ivec2(-1, -1); + result_corr[out_width*y+x] = -1.0; + } + */ + + // TODO: remove this debug code + result_corr[0] = threshold; +} diff --git a/src/correlation/prepare_initialdata_correlation.comp.glsl b/src/correlation/prepare_initialdata_correlation.comp.glsl new file mode 100644 index 0000000..06f3ba6 --- /dev/null +++ b/src/correlation/prepare_initialdata_correlation.comp.glsl @@ -0,0 +1,5 @@ +#version 450 +#pragma shader_stage(compute) + +void main() { +} diff --git a/src/correlation/prepare_initialdata_searchdata.comp.glsl b/src/correlation/prepare_initialdata_searchdata.comp.glsl new file mode 100644 index 0000000..06f3ba6 --- /dev/null +++ b/src/correlation/prepare_initialdata_searchdata.comp.glsl @@ -0,0 +1,5 @@ +#version 450 +#pragma shader_stage(compute) + +void main() { +} diff --git a/src/correlation/prepare_searchdata.comp.glsl b/src/correlation/prepare_searchdata.comp.glsl new file mode 100644 index 0000000..06f3ba6 --- /dev/null +++ b/src/correlation/prepare_searchdata.comp.glsl @@ -0,0 +1,5 @@ +#version 450 +#pragma shader_stage(compute) + +void main() { +} diff --git a/src/correlation/vk.rs b/src/correlation/vk.rs index 76de2a6..b7c80ed 100644 --- a/src/correlation/vk.rs +++ b/src/correlation/vk.rs @@ -1,12 +1,8 @@ -use std::{ - borrow::Cow, - error, - ffi::{CStr, CString}, - fmt, -}; +use std::{collections::HashMap, error, ffi::CStr, fmt, slice, time::SystemTime}; use ash::{prelude::VkResult, vk}; use nalgebra::Matrix3; +use rayon::iter::ParallelIterator; use crate::data::{Grid, Point2D}; @@ -15,6 +11,10 @@ use super::{ CROSS_CHECK_SEARCH_AREA, KERNEL_SIZE, NEIGHBOR_DISTANCE, }; +// This should be supported by most modern GPUs, even old/budget ones like Celeron N3350. +// Based on https://www.vulkan.gpuinfo.org, only old integrated GPUs like 4th gen Core i5 have a +// limit of 4. +// But storing everything in the same buffer will likely have other issues like memory limits. 
 const MAX_BINDINGS: u32 = 6;
 // Decrease when using a low-powered GPU
 const CORRIDOR_SEGMENT_LENGTH_HIGHPERFORMANCE: usize = 512;
@@ -23,7 +23,7 @@ const CORRIDOR_SEGMENT_LENGTH_LOWPOWER: usize = 8;
 const SEARCH_AREA_SEGMENT_LENGTH_LOWPOWER: usize = 128;
 
 #[repr(C)]
-#[derive(Copy, Clone, Debug, PartialEq)]
+#[derive(Copy, Clone)]
 struct ShaderParams {
     img1_width: u32,
     img1_height: u32,
@@ -45,12 +45,19 @@ struct ShaderParams {
     min_range: f32,
 }
 
+const _: () = {
+    // Validate that ShaderParams fits into the minimum guaranteed push constants size.
+    // Exceeding this size would require refactoring code (e.g. switch to uniform buffers), or risk
+    // dropping support for some basic devices (like integrated Intel GPUs).
+    assert!(std::mem::size_of::<ShaderParams>() < 128);
+};
+
 pub struct GpuContext {
     min_stdev: f32,
     correlation_threshold: f32,
     fundamental_matrix: Matrix3<f64>,
-    img1_shape: (usize, usize),
-    img2_shape: (usize, usize),
+    img1_dimensions: (usize, usize),
+    img2_dimensions: (usize, usize),
 
     correlation_values: Grid<Option<f32>>,
 
@@ -59,28 +66,76 @@ pub struct GpuContext {
     corridor_size: usize,
     corridor_extend_range: f64,
     device: Device,
-    /*
-    device: wgpu::Device,
-    queue: wgpu::Queue,
-    shader_module: wgpu::ShaderModule,
-    buffer_img: wgpu::Buffer,
-    buffer_internal_img1: wgpu::Buffer,
-    buffer_internal_img2: wgpu::Buffer,
-    buffer_internal_int: wgpu::Buffer,
-    buffer_out: wgpu::Buffer,
-    buffer_out_reverse: wgpu::Buffer,
-    buffer_out_corr: wgpu::Buffer,
-
-    pipeline_configs: HashMap,
-    */
 }
 
 struct Device {
-    name: String,
-    entry: ash::Entry,
+    _entry: ash::Entry,
     instance: ash::Instance,
-    physical_device: vk::PhysicalDevice,
-    //device: vk::Device,
+    name: String,
+    device: ash::Device,
+    memory_properties: vk::PhysicalDeviceMemoryProperties,
+    buffers: DeviceBuffers,
+    descriptor_sets: DescriptorSets,
+    pipelines: HashMap<ShaderModuleType, ShaderPipeline>,
+    control: Control,
+}
+
+#[derive(Copy, Clone)]
+struct Buffer {
+    buffer: vk::Buffer,
+    buffer_memory: vk::DeviceMemory,
+    host_visible: bool,
+    host_coherent: bool,
+    size: u64,
+}
+
+enum BufferType {
+    GpuOnly,
+    GpuDestination,
+    GpuSource,
+    HostSource,
+    HostDestination,
+}
+
+struct DeviceBuffers {
+    buffer_img: Buffer,
+    buffer_internal_img1: Buffer,
+    buffer_internal_img2: Buffer,
+    buffer_internal_int: Buffer,
+    buffer_out: Buffer,
+    buffer_out_reverse: Buffer,
+    buffer_out_corr: Buffer,
+}
+
+struct DescriptorSets {
+    direction: CorrelationDirection,
+    descriptor_pool: vk::DescriptorPool,
+    regular_layout: vk::DescriptorSetLayout,
+    cross_check_layout: vk::DescriptorSetLayout,
+    pipeline_layout: vk::PipelineLayout,
+    descriptor_sets: Vec<vk::DescriptorSet>,
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
+enum ShaderModuleType {
+    InitOutData,
+    PrepareInitialdataSearchdata,
+    PrepareInitialdataCorrelation,
+    PrepareSearchdata,
+    CrossCorrelate,
+    CrossCheckFilter,
+}
+
+struct ShaderPipeline {
+    shader_module: vk::ShaderModule,
+    pipeline: vk::Pipeline,
+}
+
+struct Control {
+    queue: vk::Queue,
+    command_pool: vk::CommandPool,
+    fence: vk::Fence,
+    command_buffer: vk::CommandBuffer,
 }
 
 impl GpuContext {
@@ -91,12 +146,8 @@ impl GpuContext {
         fundamental_matrix: Matrix3<f64>,
         low_power: bool,
     ) -> Result<GpuContext, Box<dyn error::Error>> {
-        let img1_shape = (img1_dimensions.0, img1_dimensions.1);
-        let img2_shape = (img2_dimensions.0, img2_dimensions.1);
-
         let img1_pixels = img1_dimensions.0 * img1_dimensions.1;
         let img2_pixels = img2_dimensions.0 * img2_dimensions.1;
-        let max_pixels = img1_pixels.max(img2_pixels);
 
         let (search_area_segment_length, corridor_segment_length) = if low_power {
             (
                 SEARCH_AREA_SEGMENT_LENGTH_LOWPOWER,
                 CORRIDOR_SEGMENT_LENGTH_LOWPOWER,
             )
         } else {
             (
                 SEARCH_AREA_SEGMENT_LENGTH_HIGHPERFORMANCE,
@@ -110,62 +161,13 @@ impl
GpuContext { ) }; - // Ensure there's enough memory for the largest buffer. - let max_buffer_size = max_pixels * 4 * std::mem::size_of::(); - let device = Device::new(max_buffer_size)?; - - /* - let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor { - label: None, - source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(include_str!("correlation.wgsl"))), - }); + let start_time = SystemTime::now(); + let device = Device::new(img1_pixels, img2_pixels)?; - // Init buffers. - let buffer_img = init_buffer( - &device, - (img1_pixels + img2_pixels) * std::mem::size_of::(), - false, - true, - ); - let buffer_internal_img1 = init_buffer( - &device, - (img1_pixels * 2) * std::mem::size_of::(), - true, - false, - ); - let buffer_internal_img2 = init_buffer( - &device, - (img2_pixels * 2) * std::mem::size_of::(), - true, - false, - ); - let buffer_internal_int = init_buffer( - &device, - max_pixels * 4 * std::mem::size_of::(), - true, - false, - ); - let buffer_out = init_buffer( - &device, - max_pixels * 2 * std::mem::size_of::(), - false, - false, - ); - let buffer_out_reverse = init_buffer( - &device, - img2_pixels * 2 * std::mem::size_of::(), - true, - false, - ); - let buffer_out_corr = init_buffer( - &device, - max_pixels * std::mem::size_of::(), - false, - false, - ); - */ - - let correlation_values = Grid::new(img1_shape.0, img1_shape.1, None); + if let Ok(t) = start_time.elapsed() { + println!("Initialized device in {:.3} seconds", t.as_secs_f32()); + } + let correlation_values = Grid::new(img1_dimensions.0, img1_dimensions.1, None); let params = CorrelationParameters::for_projection(&projection_mode); let result = GpuContext { @@ -174,25 +176,12 @@ impl GpuContext { corridor_size: params.corridor_size, corridor_extend_range: params.corridor_extend_range, fundamental_matrix, - img1_shape, - img2_shape, + img1_dimensions, + img2_dimensions, correlation_values, corridor_segment_length, search_area_segment_length, device, - /* - device, - queue, - shader_module, - buffer_img, - buffer_internal_img1, - buffer_internal_img2, - buffer_internal_int, - buffer_out, - buffer_out_reverse, - buffer_out_corr, - pipeline_configs: HashMap::new(), - */ }; Ok(result) } @@ -201,78 +190,735 @@ impl GpuContext { &self.device.name.as_str() } - pub fn cross_check_filter(&mut self, _: f32, _: CorrelationDirection) {} + pub fn cross_check_filter( + &mut self, + scale: f32, + dir: CorrelationDirection, + ) -> Result<(), Box> { + // TODO: recreate descriptor sets if direction changes + let (out_dimensions, out_dimensions_reverse) = match dir { + CorrelationDirection::Forward => (self.img1_dimensions, self.img2_dimensions), + CorrelationDirection::Reverse => (self.img2_dimensions, self.img1_dimensions), + }; + + let search_area = CROSS_CHECK_SEARCH_AREA * (1.0 / scale).round() as usize; + + // Reuse/repurpose ShaderParams. 
+        let params = ShaderParams {
+            img1_width: out_dimensions.0 as u32,
+            img1_height: out_dimensions.1 as u32,
+            img2_width: out_dimensions_reverse.0 as u32,
+            img2_height: out_dimensions_reverse.1 as u32,
+            out_width: 0,
+            out_height: 0,
+            fundamental_matrix: [0.0; 3 * 4],
+            scale: 0.0,
+            iteration_pass: 0,
+            corridor_offset: 0,
+            corridor_start: 0,
+            corridor_end: 0,
+            kernel_size: 0,
+            threshold: 0.0,
+            min_stdev: 0.0,
+            neighbor_distance: search_area as u32,
+            extend_range: 0.0,
+            min_range: 0.0,
+        };
+        unsafe {
+            self.device
+                .run_shader(out_dimensions, ShaderModuleType::CrossCheckFilter, params)?;
+        }
+        Ok(())
+    }
 
     pub fn complete_process(
         &mut self,
     ) -> Result<Grid<Option<(Point2D<u32>, f32)>>, Box<dyn error::Error>> {
-        todo!()
+        let mut out_image = Grid::new(self.img1_dimensions.0, self.img1_dimensions.1, None);
+        unsafe {
+            self.device
+                .save_result(&mut out_image, &self.correlation_values)?;
+        }
+        Ok(out_image)
     }
 
     pub fn correlate_images<PL: ProgressListener>(
         &mut self,
-        _: &Grid<u8>,
-        _: &Grid<u8>,
-        _: f32,
-        _: bool,
-        _: Option<&PL>,
-        _: CorrelationDirection,
+        img1: &Grid<u8>,
+        img2: &Grid<u8>,
+        scale: f32,
+        first_pass: bool,
+        progress_listener: Option<&PL>,
+        dir: CorrelationDirection,
     ) -> Result<(), Box<dyn error::Error>> {
-        todo!()
+        if matches!(dir, CorrelationDirection::Reverse) {
+            // TODO: fix this
+            return Ok(());
+        }
+        // TODO: recreate descriptor sets if direction changes
+        let max_width = img1.width().max(img2.width());
+        let max_height = img1.height().max(img2.height());
+        let max_dimensions = (max_width, max_height);
+        let img1_dimensions = (img1.width(), img1.height());
+        let out_dimensions = match dir {
+            CorrelationDirection::Forward => self.img1_dimensions,
+            CorrelationDirection::Reverse => self.img2_dimensions,
+        };
+
+        let mut progressbar_completed_percentage = 0.02;
+        let send_progress = |value| {
+            let value = match dir {
+                CorrelationDirection::Forward => value * 0.98 / 2.0,
+                CorrelationDirection::Reverse => 0.51 + value * 0.98 / 2.0,
+            };
+            if let Some(pl) = progress_listener {
+                pl.report_status(value);
+            }
+        };
+
+        let mut params = ShaderParams {
+            img1_width: img1.width() as u32,
+            img1_height: img1.height() as u32,
+            img2_width: img2.width() as u32,
+            img2_height: img2.height() as u32,
+            out_width: out_dimensions.0 as u32,
+            out_height: out_dimensions.1 as u32,
+            fundamental_matrix: self.convert_fundamental_matrix(&dir),
+            scale,
+            iteration_pass: 0,
+            corridor_offset: 0,
+            corridor_start: 0,
+            corridor_end: 0,
+            kernel_size: KERNEL_SIZE as u32,
+            threshold: self.correlation_threshold,
+            min_stdev: self.min_stdev,
+            neighbor_distance: NEIGHBOR_DISTANCE as u32,
+            extend_range: self.corridor_extend_range as f32,
+            min_range: CORRIDOR_MIN_RANGE as f32,
+        };
+
+        println!("correlate images");
+        unsafe { self.device.transfer_in_images(img1, img2)? };
+
+        if first_pass {
+            // TODO: remove this test/debug code
+ params.threshold = 3.14; + unsafe { + self.device + .run_shader(out_dimensions, ShaderModuleType::InitOutData, params)?; + } + } else { + unsafe { + self.device.run_shader( + out_dimensions, + ShaderModuleType::PrepareInitialdataSearchdata, + params, + )?; + } + progressbar_completed_percentage = 0.02; + send_progress(progressbar_completed_percentage); + + let segment_length = self.search_area_segment_length; + let neighbor_width = (NEIGHBOR_DISTANCE as f32 / scale).ceil() as usize * 2 + 1; + let neighbor_pixels = neighbor_width * neighbor_width; + let neighbor_segments = neighbor_pixels / segment_length + 1; + + params.iteration_pass = 0; + for l in 0u32..neighbor_segments as u32 { + params.corridor_start = l * segment_length as u32; + params.corridor_end = (l + 1) * segment_length as u32; + if params.corridor_end > neighbor_pixels as u32 { + params.corridor_end = neighbor_pixels as u32; + } + unsafe { + self.device.run_shader( + img1_dimensions, + ShaderModuleType::PrepareSearchdata, + params, + )?; + } + + let percent_complete = + progressbar_completed_percentage + 0.09 * (l as f32 / neighbor_segments as f32); + send_progress(percent_complete); + } + progressbar_completed_percentage = 0.11; + send_progress(progressbar_completed_percentage); + + params.iteration_pass = 1; + for l in 0u32..neighbor_segments as u32 { + params.corridor_start = l * segment_length as u32; + params.corridor_end = (l + 1) * segment_length as u32; + if params.corridor_end > neighbor_pixels as u32 { + params.corridor_end = neighbor_pixels as u32; + } + unsafe { + self.device.run_shader( + img1_dimensions, + ShaderModuleType::PrepareSearchdata, + params, + )?; + } + + let percent_complete = + progressbar_completed_percentage + 0.09 * (l as f32 / neighbor_segments as f32); + send_progress(percent_complete); + } + + progressbar_completed_percentage = 0.20; + } + send_progress(progressbar_completed_percentage); + params.iteration_pass = if first_pass { 0 } else { 1 }; + + unsafe { + self.device.run_shader( + max_dimensions, + ShaderModuleType::PrepareInitialdataCorrelation, + params, + )?; + } + + let corridor_size = self.corridor_size; + let corridor_stripes = 2 * corridor_size + 1; + let max_length = img2.width().max(img2.height()); + let segment_length = self.corridor_segment_length; + let corridor_length = max_length - (KERNEL_SIZE * 2); + let corridor_segments = corridor_length / segment_length + 1; + for corridor_offset in -(corridor_size as i32)..=corridor_size as i32 { + for l in 0u32..corridor_segments as u32 { + params.corridor_offset = corridor_offset; + params.corridor_start = l * segment_length as u32; + params.corridor_end = (l + 1) * segment_length as u32; + if params.corridor_end > corridor_length as u32 { + params.corridor_end = corridor_length as u32; + } + unsafe { + self.device.run_shader( + img1_dimensions, + ShaderModuleType::CrossCorrelate, + params, + )?; + } + + let corridor_complete = params.corridor_end as f32 / corridor_length as f32; + let percent_complete = progressbar_completed_percentage + + (1.0 - progressbar_completed_percentage) + * (corridor_offset as f32 + corridor_size as f32 + corridor_complete) + / corridor_stripes as f32; + send_progress(percent_complete); + } + } + + if matches!(dir, CorrelationDirection::Forward) { + unsafe { + self.device + .save_corr(&mut self.correlation_values, self.correlation_threshold)? 
+ }; + } + Ok(()) + } + + fn convert_fundamental_matrix(&self, dir: &CorrelationDirection) -> [f32; 3 * 4] { + let fundamental_matrix = match dir { + CorrelationDirection::Forward => self.fundamental_matrix, + CorrelationDirection::Reverse => self.fundamental_matrix.transpose(), + }; + let mut f = [0f32; 3 * 4]; + for row in 0..3 { + for col in 0..3 { + f[col * 4 + row] = fundamental_matrix[(row, col)] as f32; + } + } + f } } impl Device { - fn new(max_buffer_size: usize) -> Result> { + fn new(img1_pixels: usize, img2_pixels: usize) -> Result> { + // Ensure there's enough memory for the largest buffer. + let max_pixels = img1_pixels.max(img2_pixels); + let max_buffer_size = max_pixels * 4 * std::mem::size_of::(); + // Init adapter. let entry = unsafe { ash::Entry::load()? }; - let instance = Device::init_vk(&entry)?; - let (physical_device, name) = unsafe { Device::find_device(&instance, max_buffer_size)? }; + let instance = unsafe { Device::init_vk(&entry)? }; + let cleanup_err = |err| unsafe { + // TODO: look at refactoring this to follow RAII patterns (like ashpan/scopeguard). + instance.destroy_instance(None); + err + }; + let (physical_device, name, compute_queue_index) = + unsafe { Device::find_device(&instance, max_buffer_size).map_err(cleanup_err)? }; let name = name.to_string(); - - //let device = instance.create_device(physical_device, None, None)?; + let device = unsafe { + match Device::create_device(&instance, physical_device, compute_queue_index) { + Ok(dev) => dev, + Err(err) => { + instance.destroy_instance(None); + return Err(err); + } + } + }; + let cleanup_err = |err| unsafe { + device.destroy_device(None); + instance.destroy_instance(None); + err + }; + // Init buffers. + let memory_properties = + unsafe { instance.get_physical_device_memory_properties(physical_device) }; + let buffers = unsafe { + Device::create_buffers(&device, &memory_properties, img1_pixels, img2_pixels) + .map_err(cleanup_err)? + }; + let cleanup_err = |err| unsafe { + buffers.destroy(&device); + device.destroy_device(None); + instance.destroy_instance(None); + err + }; + // Init pipelines and shaders. + let descriptor_sets = unsafe { + Device::create_descriptor_sets(&device, &buffers, CorrelationDirection::Forward) + .map_err(cleanup_err)? + }; + let cleanup_err = |err| unsafe { + descriptor_sets.destroy(&device); + buffers.destroy(&device); + device.destroy_device(None); + instance.destroy_instance(None); + err + }; + let pipelines = + unsafe { Device::create_pipelines(&device, &descriptor_sets).map_err(cleanup_err)? }; + let cleanup_err = |err| unsafe { + destroy_pipelines(&device, &pipelines); + descriptor_sets.destroy(&device); + buffers.destroy(&device); + device.destroy_device(None); + instance.destroy_instance(None); + err + }; + // Init control struct - queues, fences, command buffer. + let control = + unsafe { Device::create_control(&device, compute_queue_index).map_err(cleanup_err)? 
};
 
         let result = Device {
-            entry,
+            _entry: entry,
             instance,
-            physical_device,
             name,
+            device,
+            memory_properties,
+            buffers,
+            descriptor_sets,
+            pipelines,
+            control,
         };
         Ok(result)
     }
 
+    unsafe fn run_shader(
+        &mut self,
+        dimensions: (usize, usize),
+        shader_type: ShaderModuleType,
+        shader_params: ShaderParams,
+    ) -> VkResult<()> {
+        let pipeline_config = self.pipelines.get(&shader_type).unwrap();
+        let command_buffer = self.control.command_buffer;
+
+        let workgroup_size = ((dimensions.0 + 15) / 16, (dimensions.1 + 15) / 16);
+        self.device.reset_fences(&[self.control.fence])?;
+        self.device
+            .reset_command_buffer(command_buffer, vk::CommandBufferResetFlags::empty())?;
+        let info = vk::CommandBufferBeginInfo::builder()
+            .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT);
+        self.device.begin_command_buffer(command_buffer, &info)?;
+
+        self.device.cmd_bind_pipeline(
+            command_buffer,
+            vk::PipelineBindPoint::COMPUTE,
+            pipeline_config.pipeline,
+        );
+        // It's way easier to map all descriptor sets identically, instead of ensuring that every
+        // kernel gets to use descriptor set = 0.
+        // The cross check filter kernel will need to switch to descriptor set = 1.
+        self.device.cmd_bind_descriptor_sets(
+            command_buffer,
+            vk::PipelineBindPoint::COMPUTE,
+            self.descriptor_sets.pipeline_layout,
+            0,
+            &self.descriptor_sets.descriptor_sets,
+            &[],
+        );
+
+        let push_constants_data = slice::from_raw_parts(
+            &shader_params as *const ShaderParams as *const u8,
+            std::mem::size_of::<ShaderParams>(),
+        );
+        self.device.cmd_push_constants(
+            command_buffer,
+            self.descriptor_sets.pipeline_layout,
+            vk::ShaderStageFlags::COMPUTE,
+            0,
+            push_constants_data,
+        );
+        self.device.cmd_dispatch(
+            command_buffer,
+            workgroup_size.0 as u32,
+            workgroup_size.1 as u32,
+            1,
+        );
+
+        self.device.end_command_buffer(command_buffer)?;
+
+        let command_buffers = [command_buffer];
+        let submit_info = vk::SubmitInfo::builder().command_buffers(&command_buffers);
+        self.device.queue_submit(
+            self.control.queue,
+            &[submit_info.build()],
+            self.control.fence,
+        )?;
+
+        self.device
+            .wait_for_fences(&[self.control.fence], true, u64::MAX)
+    }
+
+    unsafe fn transfer_in_images(
+        &self,
+        img1: &Grid<u8>,
+        img2: &Grid<u8>,
+    ) -> Result<(), Box<dyn error::Error>> {
+        let img2_offset = img1.width() * img1.height();
+        let size = img1.width() * img1.height() + img2.width() * img2.height();
+        let size_bytes = size * std::mem::size_of::<f32>();
+        let copy_image = |buffer: &Buffer| -> VkResult<()> {
+            let memory = self.device.map_memory(
+                buffer.buffer_memory,
+                0,
+                size_bytes as u64,
+                vk::MemoryMapFlags::empty(),
+            )?;
+            {
+                let img_slice = slice::from_raw_parts_mut(memory as *mut f32, size);
+                img1.iter()
+                    .for_each(|(x, y, val)| img_slice[y * img1.width() + x] = *val as f32);
+                img2.iter().for_each(|(x, y, val)| {
+                    img_slice[img2_offset + y * img2.width() + x] = *val as f32
+                });
+            }
+
+            if !buffer.host_coherent {
+                let flush_memory_ranges = vk::MappedMemoryRange::builder()
+                    .memory(buffer.buffer_memory)
+                    .offset(0)
+                    .size(size_bytes as u64);
+                self.device
+                    .flush_mapped_memory_ranges(&[flush_memory_ranges.build()])?;
+            }
+            self.device.unmap_memory(buffer.buffer_memory);
+            Ok(())
+        };
+
+        if self.buffers.buffer_img.host_visible {
+            // If memory is available to the host, copy data directly to the buffer.
+ copy_image(&self.buffers.buffer_img)?; + return Ok(()); + } + + let temp_buffer = Device::create_buffer( + &self.device, + &self.memory_properties, + size_bytes, + BufferType::HostSource, + )?; + let cleanup_err = |err| { + temp_buffer.destroy(&self.device); + err + }; + copy_image(&temp_buffer).map_err(cleanup_err)?; + + let command_buffer = self.control.command_buffer; + self.device.reset_fences(&[self.control.fence])?; + self.device + .reset_command_buffer(command_buffer, vk::CommandBufferResetFlags::empty()) + .map_err(cleanup_err)?; + + let info = vk::CommandBufferBeginInfo::builder() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(command_buffer, &info) + .map_err(cleanup_err)?; + let regions = vk::BufferCopy::builder().size(size_bytes as u64); + self.device.cmd_copy_buffer( + command_buffer, + temp_buffer.buffer, + self.buffers.buffer_img.buffer, + &[regions.build()], + ); + self.device + .end_command_buffer(command_buffer) + .map_err(cleanup_err)?; + + let command_buffers = [command_buffer]; + let submit_info = vk::SubmitInfo::builder().command_buffers(&command_buffers); + self.device + .queue_submit( + self.control.queue, + &[submit_info.build()], + self.control.fence, + ) + .map_err(cleanup_err)?; + self.device + .wait_for_fences(&[self.control.fence], true, u64::MAX) + .map_err(cleanup_err)?; + + temp_buffer.destroy(&self.device); + + Ok(()) + } + + unsafe fn save_corr( + &self, + correlation_values: &mut Grid>, + correlation_threshold: f32, + ) -> Result<(), Box> { + let size = correlation_values.width() * correlation_values.height(); + let size_bytes = size * std::mem::size_of::(); + let width = correlation_values.width(); + let copy_corr_data = + |buffer: &Buffer, correlation_values: &mut Grid>| -> VkResult<()> { + let memory = self.device.map_memory( + buffer.buffer_memory, + 0, + size_bytes as u64, + vk::MemoryMapFlags::empty(), + )?; + { + let out_corr = slice::from_raw_parts_mut(memory as *mut f32, size); + correlation_values + .par_iter_mut() + .for_each(|(x, y, out_point)| { + let corr = out_corr[y * width + x]; + if corr > correlation_threshold { + *out_point = Some(corr); + } + }); + } + // TODO: remove this debug code + println!("Corr check = {:?}", correlation_values.val(0, 0)); + + if !buffer.host_coherent { + let flush_memory_ranges = vk::MappedMemoryRange::builder() + .memory(buffer.buffer_memory) + .offset(0) + .size(size_bytes as u64); + self.device + .invalidate_mapped_memory_ranges(&[flush_memory_ranges.build()])?; + } + self.device.unmap_memory(buffer.buffer_memory); + Ok(()) + }; + + if self.buffers.buffer_out_corr.host_visible { + // If memory is available to the host, copy data directly to the buffer. 
+ copy_corr_data(&self.buffers.buffer_out_corr, correlation_values)?; + return Ok(()); + } + + let temp_buffer = Device::create_buffer( + &self.device, + &self.memory_properties, + size_bytes, + BufferType::HostDestination, + )?; + let cleanup_err = |err| { + temp_buffer.destroy(&self.device); + err + }; + + let command_buffer = self.control.command_buffer; + self.device.reset_fences(&[self.control.fence])?; + self.device + .reset_command_buffer(command_buffer, vk::CommandBufferResetFlags::empty()) + .map_err(cleanup_err)?; + + let info = vk::CommandBufferBeginInfo::builder() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(command_buffer, &info) + .map_err(cleanup_err)?; + let regions = vk::BufferCopy::builder().size(size_bytes as u64); + self.device.cmd_copy_buffer( + command_buffer, + self.buffers.buffer_out_corr.buffer, + temp_buffer.buffer, + &[regions.build()], + ); + self.device + .end_command_buffer(command_buffer) + .map_err(cleanup_err)?; + + let command_buffers = [command_buffer]; + let submit_info = vk::SubmitInfo::builder().command_buffers(&command_buffers); + self.device + .queue_submit( + self.control.queue, + &[submit_info.build()], + self.control.fence, + ) + .map_err(cleanup_err)?; + self.device + .wait_for_fences(&[self.control.fence], true, u64::MAX) + .map_err(cleanup_err)?; + + copy_corr_data(&temp_buffer, correlation_values).map_err(cleanup_err)?; + temp_buffer.destroy(&self.device); + + Ok(()) + } + + unsafe fn save_result( + &self, + out_image: &mut Grid>, + correlation_values: &Grid>, + ) -> Result<(), Box> { + // TODO: combine this with save_result + let size = out_image.width() * out_image.height() * 2; + let size_bytes = size * std::mem::size_of::(); + let width = out_image.width(); + let copy_out_image = + |buffer: &Buffer, out_image: &mut Grid>| -> VkResult<()> { + let memory = self.device.map_memory( + buffer.buffer_memory, + 0, + size_bytes as u64, + vk::MemoryMapFlags::empty(), + )?; + { + let out_data = slice::from_raw_parts_mut(memory as *mut u32, size); + out_image.par_iter_mut().for_each(|(x, y, out_point)| { + let pos = 2 * (y * width + x); + let (match_x, match_y) = (out_data[pos], out_data[pos + 1]); + if let Some(corr) = correlation_values.val(x, y) { + *out_point = if match_x > 0 && match_y > 0 { + let point_match = Point2D::new(match_x as u32, match_y as u32); + Some((point_match, *corr)) + } else { + None + }; + } else { + *out_point = None; + }; + }); + } + // TODO: remove this debug code + println!("Corr check = {:?}", out_image.val(0, 0)); + + if !buffer.host_coherent { + let flush_memory_ranges = vk::MappedMemoryRange::builder() + .memory(buffer.buffer_memory) + .offset(0) + .size(size_bytes as u64); + self.device + .invalidate_mapped_memory_ranges(&[flush_memory_ranges.build()])?; + } + self.device.unmap_memory(buffer.buffer_memory); + Ok(()) + }; + + if self.buffers.buffer_out.host_visible { + // If memory is available to the host, copy data directly to the buffer. 
+ copy_out_image(&self.buffers.buffer_out, out_image)?; + return Ok(()); + } + + let temp_buffer = Device::create_buffer( + &self.device, + &self.memory_properties, + size_bytes, + BufferType::HostDestination, + )?; + let cleanup_err = |err| { + temp_buffer.destroy(&self.device); + err + }; + + let command_buffer = self.control.command_buffer; + self.device.reset_fences(&[self.control.fence])?; + self.device + .reset_command_buffer(command_buffer, vk::CommandBufferResetFlags::empty()) + .map_err(cleanup_err)?; + + let info = vk::CommandBufferBeginInfo::builder() + .flags(vk::CommandBufferUsageFlags::ONE_TIME_SUBMIT); + + self.device + .begin_command_buffer(command_buffer, &info) + .map_err(cleanup_err)?; + let regions = vk::BufferCopy::builder().size(size_bytes as u64); + self.device.cmd_copy_buffer( + command_buffer, + self.buffers.buffer_out.buffer, + temp_buffer.buffer, + &[regions.build()], + ); + self.device + .end_command_buffer(command_buffer) + .map_err(cleanup_err)?; + + let command_buffers = [command_buffer]; + let submit_info = vk::SubmitInfo::builder().command_buffers(&command_buffers); + self.device + .queue_submit( + self.control.queue, + &[submit_info.build()], + self.control.fence, + ) + .map_err(cleanup_err)?; + self.device + .wait_for_fences(&[self.control.fence], true, u64::MAX) + .map_err(cleanup_err)?; + + copy_out_image(&temp_buffer, out_image).map_err(cleanup_err)?; + temp_buffer.destroy(&self.device); + + Ok(()) + } + unsafe fn init_vk(entry: &ash::Entry) -> VkResult { + let app_name = CStr::from_bytes_with_nul_unchecked(b"Cybervision\0"); + let engine_name = CStr::from_bytes_with_nul_unchecked(b"cybervision\0"); let appinfo = vk::ApplicationInfo::builder() - .application_name(app_name.as_c_str()) + .application_name(app_name) .application_version(0) - .engine_name(engine_name.as_c_str()) + .engine_name(engine_name) .engine_version(0) .api_version(vk::make_api_version(0, 1, 0, 0)); let create_flags = vk::InstanceCreateFlags::default(); - let create_info = vk::InstanceCreateInfo::builder() .application_info(&appinfo) .flags(create_flags); - unsafe { entry.create_instance(&create_info, None) } + entry.create_instance(&create_info, None) } unsafe fn find_device( instance: &ash::Instance, max_buffer_size: usize, - ) -> Result<(vk::PhysicalDevice, &'static str), Box> { + ) -> Result<(vk::PhysicalDevice, &'static str, u32), Box> { let devices = instance.enumerate_physical_devices()?; let device = devices .iter() .filter_map(|device| { - let props = instance.get_physical_device_properties(*device); + let device = *device; + let props = instance.get_physical_device_properties(device); if props.limits.max_push_constants_size < std::mem::size_of::() as u32 || props.limits.max_per_stage_descriptor_storage_buffers < MAX_BINDINGS || props.limits.max_storage_buffer_range < max_buffer_size as u32 { return None; } + let queue_index = Device::find_compute_queue(instance, device)?; let device_name = CStr::from_ptr(props.device_name.as_ptr()); let device_name = device_name.to_str().unwrap(); @@ -283,22 +929,574 @@ impl Device { vk::PhysicalDeviceType::INTEGRATED_GPU => 1, _ => 0, }; - Some((device.to_owned(), device_name, score)) + Some((device, device_name, queue_index, score)) }) - .max_by_key(|(_device, _name, score)| *score); - let (device, name) = if let Some((device, name, _score)) = device { - (device, name) + .max_by_key(|(_device, _name, _queue_index, score)| *score); + let (device, name, queue_index) = if let Some((device, name, queue_index, _score)) = device + { + (device, 
name, queue_index) } else { return Err(GpuError::new("Device not found").into()); }; - Ok((device, name)) + Ok((device, name, queue_index)) + } + + unsafe fn find_compute_queue( + instance: &ash::Instance, + device: vk::PhysicalDevice, + ) -> Option { + instance + .get_physical_device_queue_family_properties(device) + .iter() + .enumerate() + .flat_map(|(index, queue)| { + println!("queue {} has flags {}", index, queue.queue_flags.as_raw()); + if queue + .queue_flags + .contains(vk::QueueFlags::COMPUTE | vk::QueueFlags::TRANSFER) + && queue.queue_count > 0 + { + Some(index as u32) + } else { + None + } + }) + .next() + } + + unsafe fn create_device( + instance: &ash::Instance, + physical_device: vk::PhysicalDevice, + compute_queue_index: u32, + ) -> Result> { + let queue_info = vk::DeviceQueueCreateInfo::builder() + .queue_family_index(compute_queue_index) + .queue_priorities(&[1.0f32]); + let device_create_info = + vk::DeviceCreateInfo::builder().queue_create_infos(std::slice::from_ref(&queue_info)); + match instance.create_device(physical_device, &device_create_info, None) { + Ok(device) => Ok(device), + Err(err) => Err(err.into()), + } + } + + unsafe fn create_buffers( + device: &ash::Device, + memory_properties: &vk::PhysicalDeviceMemoryProperties, + img1_pixels: usize, + img2_pixels: usize, + ) -> Result> { + let max_pixels = img1_pixels.max(img2_pixels); + let mut buffers: Vec = vec![]; + let cleanup_err = |buffers: &[Buffer], err| { + buffers.iter().for_each(|buffer| { + device.free_memory(buffer.buffer_memory, None); + device.destroy_buffer(buffer.buffer, None) + }); + err + }; + let buffer_img = Device::create_buffer( + device, + &memory_properties, + (img1_pixels + img2_pixels) * std::mem::size_of::(), + BufferType::GpuDestination, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_img); + + let buffer_internal_img1 = Device::create_buffer( + device, + &memory_properties, + (img1_pixels * 2) * std::mem::size_of::(), + BufferType::GpuOnly, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_internal_img1); + + let buffer_internal_img2 = Device::create_buffer( + device, + &memory_properties, + (img2_pixels * 2) * std::mem::size_of::(), + BufferType::GpuOnly, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_internal_img2); + + let buffer_internal_int = Device::create_buffer( + device, + &memory_properties, + max_pixels * 4 * std::mem::size_of::(), + BufferType::GpuOnly, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_internal_int); + + let buffer_out = Device::create_buffer( + device, + &memory_properties, + max_pixels * 2 * std::mem::size_of::(), + BufferType::GpuSource, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_out); + + let buffer_out_reverse = Device::create_buffer( + device, + &memory_properties, + img2_pixels * 2 * std::mem::size_of::(), + BufferType::GpuOnly, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + buffers.push(buffer_out_reverse); + + let buffer_out_corr = Device::create_buffer( + device, + &memory_properties, + max_pixels * std::mem::size_of::(), + BufferType::GpuSource, + ) + .map_err(|err| cleanup_err(buffers.as_slice(), err))?; + + Ok(DeviceBuffers { + buffer_img, + buffer_internal_img1, + buffer_internal_img2, + buffer_internal_int, + buffer_out, + buffer_out_reverse, + buffer_out_corr, + }) + } + + unsafe fn create_buffer( + device: &ash::Device, + memory_properties: 
&vk::PhysicalDeviceMemoryProperties, + size: usize, + buffer_type: BufferType, + ) -> Result> { + let size = size as u64; + let gpu_local = match buffer_type { + BufferType::GpuOnly | BufferType::GpuDestination | BufferType::GpuSource => true, + BufferType::HostSource | BufferType::HostDestination => false, + }; + let host_visible = match buffer_type { + BufferType::HostSource | BufferType::HostDestination => true, + BufferType::GpuOnly | BufferType::GpuDestination | BufferType::GpuSource => false, + }; + let extra_usage_flags = match buffer_type { + BufferType::HostSource => vk::BufferUsageFlags::TRANSFER_SRC, + BufferType::HostDestination => vk::BufferUsageFlags::TRANSFER_DST, + BufferType::GpuOnly => vk::BufferUsageFlags::STORAGE_BUFFER, + BufferType::GpuSource => { + vk::BufferUsageFlags::TRANSFER_SRC | vk::BufferUsageFlags::STORAGE_BUFFER + } + BufferType::GpuDestination => { + vk::BufferUsageFlags::TRANSFER_DST | vk::BufferUsageFlags::STORAGE_BUFFER + } + }; + let buffer_create_info = vk::BufferCreateInfo { + size, + usage: extra_usage_flags, + sharing_mode: vk::SharingMode::EXCLUSIVE, + ..Default::default() + }; + let buffer = device.create_buffer(&buffer_create_info, None)?; + let memory_requirements = device.get_buffer_memory_requirements(buffer); + let memory_type_index = memory_properties.memory_types + [..memory_properties.memory_type_count as usize] + .iter() + .enumerate() + .find(|(memory_type_index, memory_type)| { + if memory_properties.memory_heaps[memory_type.heap_index as usize].size < size { + return false; + }; + if (1 << memory_type_index) & memory_requirements.memory_type_bits == 0 { + return false; + } + + if gpu_local + && memory_type + .property_flags + .contains(vk::MemoryPropertyFlags::DEVICE_LOCAL) + { + return true; + } + if host_visible + && memory_type + .property_flags + .contains(vk::MemoryPropertyFlags::HOST_VISIBLE) + { + return true; + } + false + }); + let memory_type_index = if let Some((index, _)) = memory_type_index { + index as u32 + } else { + return Err(GpuError::new("Cannot find suitable memory").into()); + }; + let property_flags = + memory_properties.memory_types[memory_type_index as usize].property_flags; + let host_visible = property_flags.contains(vk::MemoryPropertyFlags::HOST_VISIBLE); + let host_coherent = property_flags.contains(vk::MemoryPropertyFlags::HOST_COHERENT); + let allocate_info = vk::MemoryAllocateInfo { + allocation_size: size, + memory_type_index, + ..Default::default() + }; + let buffer_memory = device.allocate_memory(&allocate_info, None); + let buffer_memory = match buffer_memory { + Ok(mem) => mem, + Err(err) => { + device.destroy_buffer(buffer, None); + return Err(err.into()); + } + }; + let result = Buffer { + buffer, + buffer_memory, + host_visible, + host_coherent, + size, + }; + device.bind_buffer_memory(buffer, buffer_memory, 0)?; + Ok(result) + } + + unsafe fn create_descriptor_sets( + device: &ash::Device, + buffers: &DeviceBuffers, + direction: CorrelationDirection, + ) -> Result> { + let create_layout_bindings = |count| { + let bindings = (0..count) + .map(|i| { + vk::DescriptorSetLayoutBinding::builder() + .binding(i as u32) + .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) + .descriptor_count(1) + .stage_flags(vk::ShaderStageFlags::COMPUTE) + .build() + }) + .collect::>(); + let layout_info = + vk::DescriptorSetLayoutCreateInfo::builder().bindings(bindings.as_slice()); + device.create_descriptor_set_layout(&layout_info, None) + }; + let descriptor_pool_size = [vk::DescriptorPoolSize::builder() + 
.ty(vk::DescriptorType::STORAGE_BUFFER) + .descriptor_count(2) + .build()]; + let descriptor_pool_info = vk::DescriptorPoolCreateInfo::builder() + .max_sets(1) + .pool_sizes(&descriptor_pool_size); + let descriptor_pool = device.create_descriptor_pool(&descriptor_pool_info, None)?; + let cleanup_err = |err| { + device.destroy_descriptor_pool(descriptor_pool, None); + err + }; + let regular_layout = create_layout_bindings(6).map_err(cleanup_err)?; + let cleanup_err = |err| { + device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_pool(descriptor_pool, None); + err + }; + let cross_check_layout = create_layout_bindings(2).map_err(cleanup_err)?; + let cleanup_err = |err| { + device.destroy_descriptor_set_layout(cross_check_layout, None); + device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_pool(descriptor_pool, None); + err + }; + let layouts = [regular_layout, cross_check_layout]; + let push_constant_ranges = vk::PushConstantRange::builder() + .offset(0) + .size(std::mem::size_of::() as u32) + .stage_flags(vk::ShaderStageFlags::COMPUTE) + .build(); + let pipeline_layout = device + .create_pipeline_layout( + &vk::PipelineLayoutCreateInfo::builder() + .set_layouts(&layouts) + .push_constant_ranges(&[push_constant_ranges]), + None, + ) + .map_err(cleanup_err)?; + let cleanup_err = |err| { + device.destroy_pipeline_layout(pipeline_layout, None); + device.destroy_descriptor_set_layout(cross_check_layout, None); + device.destroy_descriptor_set_layout(regular_layout, None); + device.destroy_descriptor_pool(descriptor_pool, None); + err + }; + let descriptor_set_allocate_info = vk::DescriptorSetAllocateInfo::builder() + .descriptor_pool(descriptor_pool) + .set_layouts(&layouts); + let descriptor_sets = device + .allocate_descriptor_sets(&descriptor_set_allocate_info) + .map_err(cleanup_err)?; + + // TODO: extract this to allow switching direction on the fly. 
+ let create_buffer_infos = |buffers: &[Buffer]| { + buffers + .iter() + .map(|buf| { + vk::DescriptorBufferInfo::builder() + .buffer(buf.buffer) + .offset(0) + .range(vk::WHOLE_SIZE) + .build() + }) + .collect::>() + }; + let create_write_descriptor = |i: usize, buffer_infos: &[vk::DescriptorBufferInfo]| { + vk::WriteDescriptorSet::builder() + .dst_set(descriptor_sets[i]) + .dst_binding(0) + .descriptor_type(vk::DescriptorType::STORAGE_BUFFER) + .buffer_info(buffer_infos) + .build() + }; + let (buffer_out, buffer_out_reverse) = match direction { + CorrelationDirection::Forward => (buffers.buffer_out, buffers.buffer_out_reverse), + CorrelationDirection::Reverse => (buffers.buffer_out_reverse, buffers.buffer_out), + }; + let regular_buffer_infos = create_buffer_infos(&[ + buffers.buffer_img, + buffers.buffer_internal_img1, + buffers.buffer_internal_img2, + buffers.buffer_internal_int, + buffer_out, + buffers.buffer_out_corr, + ]); + let cross_check_buffer_infos = create_buffer_infos(&[buffer_out, buffer_out_reverse]); + let write_descriptors = [ + create_write_descriptor(0, regular_buffer_infos.as_slice()), + create_write_descriptor(1, cross_check_buffer_infos.as_slice()), + ]; + device.update_descriptor_sets(&write_descriptors, &[]); + + Ok(DescriptorSets { + direction: direction.to_owned(), + descriptor_pool, + regular_layout, + cross_check_layout, + pipeline_layout, + descriptor_sets, + }) + } + + unsafe fn load_shaders( + device: &ash::Device, + ) -> Result, Box> { + let mut result = vec![]; + let cleanup_err = |result: &mut Vec<(ShaderModuleType, vk::ShaderModule)>, err| { + result + .iter() + .for_each(|(_type, shader)| device.destroy_shader_module(*shader, None)); + err + }; + for module_type in ShaderModuleType::VALUES { + let shader = module_type + .load(device) + .map_err(|err| cleanup_err(&mut result, err))?; + result.push((module_type, shader)); + } + Ok(result) + } + + unsafe fn create_pipelines( + device: &ash::Device, + descriptor_sets: &DescriptorSets, + ) -> Result, Box> { + let shader_modules = Device::load_shaders(device)?; + + let main_module_name = CStr::from_bytes_with_nul_unchecked(b"main\0"); + + let pipeline_create_info = shader_modules + .iter() + .map(|(_shader_type, module)| { + let stage_create_info = vk::PipelineShaderStageCreateInfo::builder() + .module(*module) + .name(&main_module_name) + .stage(vk::ShaderStageFlags::COMPUTE); + vk::ComputePipelineCreateInfo::builder() + .stage(stage_create_info.build()) + .layout(descriptor_sets.pipeline_layout) + .build() + }) + .collect::>(); + let pipelines = match device.create_compute_pipelines( + vk::PipelineCache::null(), + pipeline_create_info.as_slice(), + None, + ) { + Ok(pipelines) => pipelines, + Err(err) => { + err.0 + .iter() + .for_each(|pipeline| device.destroy_pipeline(*pipeline, None)); + shader_modules + .iter() + .for_each(|(_shader_type, module)| device.destroy_shader_module(*module, None)); + return Err(err.1.into()); + } + }; + let mut result = HashMap::new(); + shader_modules + .iter() + .zip(pipelines.iter()) + .for_each(|(shader_module, pipeline)| { + result.insert( + shader_module.0, + ShaderPipeline { + shader_module: shader_module.1, + pipeline: *pipeline, + }, + ); + }); + Ok(result) + } + + unsafe fn create_control( + device: &ash::Device, + queue_family_index: u32, + ) -> Result> { + let queue = device.get_device_queue(queue_family_index, 0); + let command_pool_info = vk::CommandPoolCreateInfo::builder() + .queue_family_index(queue_family_index) + 
.flags(vk::CommandPoolCreateFlags::RESET_COMMAND_BUFFER);
+        let command_pool = device.create_command_pool(&command_pool_info, None)?;
+        let cleanup_err = |err| {
+            device.destroy_command_pool(command_pool, None);
+            err
+        };
+        let fence_create_info = vk::FenceCreateInfo::builder();
+        let fence = device
+            .create_fence(&fence_create_info, None)
+            .map_err(cleanup_err)?;
+        let cleanup_err = |err| {
+            device.destroy_command_pool(command_pool, None);
+            device.destroy_fence(fence, None);
+            err
+        };
+        let command_buffers_info = vk::CommandBufferAllocateInfo::builder()
+            .command_buffer_count(1)
+            .command_pool(command_pool)
+            .level(vk::CommandBufferLevel::PRIMARY);
+        let command_buffer = device
+            .allocate_command_buffers(&command_buffers_info)
+            .map_err(cleanup_err)?[0];
+        Ok(Control {
+            queue,
+            command_pool,
+            fence,
+            command_buffer,
+        })
+    }
+}
+
+impl Buffer {
+    unsafe fn destroy(&self, device: &ash::Device) {
+        device.free_memory(self.buffer_memory, None);
+        device.destroy_buffer(self.buffer, None);
+    }
+}
+
+impl DeviceBuffers {
+    unsafe fn destroy(&self, device: &ash::Device) {
+        [
+            self.buffer_img,
+            self.buffer_internal_img1,
+            self.buffer_internal_img2,
+            self.buffer_internal_int,
+            self.buffer_out,
+            self.buffer_out_reverse,
+            self.buffer_out_corr,
+        ]
+        .iter()
+        .for_each(|buffer| {
+            buffer.destroy(device);
+        });
+    }
+}
+
+impl DescriptorSets {
+    unsafe fn destroy(&self, device: &ash::Device) {
+        device.destroy_pipeline_layout(self.pipeline_layout, None);
+        device.destroy_descriptor_set_layout(self.cross_check_layout, None);
+        device.destroy_descriptor_set_layout(self.regular_layout, None);
+        device.destroy_descriptor_pool(self.descriptor_pool, None);
+    }
+}
+
+unsafe fn destroy_pipelines(
+    device: &ash::Device,
+    pipelines: &HashMap<ShaderModuleType, ShaderPipeline>,
+) {
+    pipelines.values().for_each(|shader_pipeline| {
+        device.destroy_shader_module(shader_pipeline.shader_module, None);
+        device.destroy_pipeline(shader_pipeline.pipeline, None);
+    });
+}
+
+impl ShaderModuleType {
+    const VALUES: [Self; 6] = [
+        Self::InitOutData,
+        Self::PrepareInitialdataSearchdata,
+        Self::PrepareInitialdataCorrelation,
+        Self::PrepareSearchdata,
+        Self::CrossCorrelate,
+        Self::CrossCheckFilter,
+    ];
+
+    unsafe fn load(
+        &self,
+        device: &ash::Device,
+    ) -> Result<vk::ShaderModule, Box<dyn error::Error>> {
+        const SHADER_INIT_OUT_DATA: &[u8] = include_bytes!("init_out_data.spv");
+        const SHADER_PREPARE_INITIALDATA_SEARCHDATA: &[u8] =
+            include_bytes!("prepare_initialdata_searchdata.spv");
+        const SHADER_PREPARE_INITIALDATA_CORRELATION: &[u8] =
+            include_bytes!("prepare_initialdata_correlation.spv");
+        const SHADER_PREPARE_SEARCHDATA: &[u8] = include_bytes!("prepare_searchdata.spv");
+        const SHADER_CROSS_CORRELATE: &[u8] = include_bytes!("cross_correlate.spv");
+        const SHADER_CROSS_CHECK_FILTER: &[u8] = include_bytes!("cross_check_filter.spv");
+
+        let shader_module_spv = match self {
+            ShaderModuleType::InitOutData => SHADER_INIT_OUT_DATA,
+            ShaderModuleType::PrepareInitialdataSearchdata => SHADER_PREPARE_INITIALDATA_SEARCHDATA,
+            ShaderModuleType::PrepareInitialdataCorrelation => {
+                SHADER_PREPARE_INITIALDATA_CORRELATION
+            }
+            ShaderModuleType::PrepareSearchdata => SHADER_PREPARE_SEARCHDATA,
+            ShaderModuleType::CrossCorrelate => SHADER_CROSS_CORRELATE,
+            ShaderModuleType::CrossCheckFilter => SHADER_CROSS_CHECK_FILTER,
+        };
+        let shader_code = ash::util::read_spv(&mut std::io::Cursor::new(shader_module_spv))?;
+        let shader_module = device.create_shader_module(
+            &vk::ShaderModuleCreateInfo::builder().code(shader_code.as_slice()),
+            None,
+        )?;
+
+        Ok(shader_module)
+    }
+}
+
+impl Control {
+    unsafe fn destroy(&self, device: &ash::Device) {
+        device.free_command_buffers(self.command_pool, &[self.command_buffer]);
+        device.destroy_command_pool(self.command_pool, None);
+        device.destroy_fence(self.fence, None);
+    }
+}
 
 impl Drop for Device {
     fn drop(&mut self) {
         unsafe {
-            //self.device.destroy_device(None);
+            self.control.destroy(&self.device);
+            destroy_pipelines(&self.device, &self.pipelines);
+            self.descriptor_sets.destroy(&self.device);
+            self.buffers.destroy(&self.device);
+            self.device.destroy_device(None);
             self.instance.destroy_instance(None);
         }
     }
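
---

Reviewer notes. The sketches below are editor-added illustrations, not part of the patch; any names they introduce are hypothetical.

1) Device::new carries a TODO about replacing the chained cleanup_err closures with RAII (ashpan/scopeguard). A minimal sketch of that pattern using only std; Guard, acquire and release are illustrative names:

// Hypothetical drop-guard: runs the cleanup closure unless disarmed.
struct Guard<F: FnMut()> {
    cleanup: F,
    armed: bool,
}

impl<F: FnMut()> Guard<F> {
    fn new(cleanup: F) -> Self {
        Guard { cleanup, armed: true }
    }
    // Call after all fallible initialization steps have succeeded.
    fn disarm(&mut self) {
        self.armed = false;
    }
}

impl<F: FnMut()> Drop for Guard<F> {
    fn drop(&mut self) {
        if self.armed {
            (self.cleanup)();
        }
    }
}

// Hypothetical resource, for illustration only.
fn acquire() -> Result<u32, Box<dyn std::error::Error>> {
    Ok(42)
}
fn release(handle: u32) {
    println!("released {handle}");
}

fn init() -> Result<u32, Box<dyn std::error::Error>> {
    let handle = acquire()?;
    let mut guard = Guard::new(move || release(handle));
    // ... further fallible steps; an early `?` return triggers `release` ...
    guard.disarm(); // success: the caller now owns the handle
    Ok(handle)
}

One guard per resource replaces the growing cleanup_err chains, and each early return unwinds them in reverse order automatically.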
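
2) convert_fundamental_matrix returns [f32; 3 * 4] rather than [f32; 9] because each column of a GLSL mat3x3 in the std430 push-constant block is padded to vec4 (16-byte) alignment. The same packing as a standalone sketch (pack_mat3_std430 is a hypothetical name):

// Each mat3x3 column occupies 4 floats on the host side; the 4th float of
// each column is padding the shader never reads.
fn pack_mat3_std430(m: &[[f64; 3]; 3]) -> [f32; 12] {
    let mut packed = [0.0f32; 12];
    for col in 0..3 {
        for row in 0..3 {
            // Column-major, with a column stride of 4 floats instead of 3.
            packed[col * 4 + row] = m[row][col] as f32;
        }
    }
    packed
}

fn main() {
    let identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
    assert_eq!(pack_mat3_std430(&identity)[5], 1.0); // column 1, row 1
}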
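
3) transfer_in_images flushes mapped memory after host writes, while save_corr/save_result invalidate for host reads; both steps are skipped for HOST_COHERENT memory. One thing worth double-checking: copy_corr_data and copy_out_image invalidate after the host reads, whereas the spec wants the invalidate before the reads so that GPU writes are visible. A sketch of a helper that keeps the ordering straight (with_mapped is a hypothetical name; ash calls match the builder style used in the patch):

use ash::{prelude::VkResult, vk};

// Maps `memory`, hands the pointer to `f`, and performs non-coherent
// synchronization in the expected order: invalidate before host reads,
// flush after host writes.
unsafe fn with_mapped<T>(
    device: &ash::Device,
    memory: vk::DeviceMemory,
    size: u64,
    host_coherent: bool,
    reading_back: bool,
    f: impl FnOnce(*mut std::ffi::c_void) -> T,
) -> VkResult<T> {
    let ptr = device.map_memory(memory, 0, size, vk::MemoryMapFlags::empty())?;
    let range = vk::MappedMemoryRange::builder()
        .memory(memory)
        .offset(0)
        .size(size)
        .build();
    if reading_back && !host_coherent {
        // Make GPU writes visible to the host before reading.
        device.invalidate_mapped_memory_ranges(&[range])?;
    }
    let result = f(ptr);
    if !reading_back && !host_coherent {
        // Make host writes visible to the GPU after writing.
        device.flush_mapped_memory_ranges(&[range])?;
    }
    device.unmap_memory(memory);
    Ok(result)
}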
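
4) create_descriptor_sets allocates two sets (the 6-binding regular layout plus the 2-binding cross-check layout) from a pool created with max_sets(1) and a STORAGE_BUFFER pool size of 2; drivers that enforce pool limits may fail that allocation with VK_ERROR_OUT_OF_POOL_MEMORY. A hedged sketch of deriving the pool size from the layouts it serves (create_pool is a hypothetical helper):

use ash::{prelude::VkResult, vk};

// Reserve one set per layout and one STORAGE_BUFFER descriptor per binding:
// for the layouts in this patch that is max_sets = 2 and
// descriptor_count = 6 + 2 = 8.
unsafe fn create_pool(
    device: &ash::Device,
    bindings_per_set: &[u32],
) -> VkResult<vk::DescriptorPool> {
    let total_bindings: u32 = bindings_per_set.iter().sum();
    let pool_sizes = [vk::DescriptorPoolSize::builder()
        .ty(vk::DescriptorType::STORAGE_BUFFER)
        .descriptor_count(total_bindings)
        .build()];
    let info = vk::DescriptorPoolCreateInfo::builder()
        .max_sets(bindings_per_set.len() as u32)
        .pool_sizes(&pool_sizes);
    device.create_descriptor_pool(&info, None)
}

// Usage for this patch's layouts: create_pool(&device, &[6, 2])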
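
5) run_shader reinterprets &ShaderParams as a byte slice for cmd_push_constants. This is sound because ShaderParams is #[repr(C)] and Copy, every field is a 4-byte scalar or an array of f32 (so there is no padding), and the const assert bounds it to the 128-byte guaranteed push-constant budget. The same cast as a tiny reusable helper (as_bytes is a hypothetical name):

// Reinterpret a plain-old-data value as its raw bytes.
// Safety: T must be #[repr(C)] with no padding or pointers, which holds for
// ShaderParams (all fields are 4-byte f32/u32/i32 scalars or f32 arrays).
unsafe fn as_bytes<T: Copy>(value: &T) -> &[u8] {
    std::slice::from_raw_parts(value as *const T as *const u8, std::mem::size_of::<T>())
}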
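
6) vk.rs embeds the kernels as .spv blobs via include_bytes!, but the patch only adds .comp.glsl sources, so the SPIR-V is presumably produced at build time. A hypothetical build.rs using shaderc's glslc (assumed to be on PATH; outputs are written next to vk.rs so include_bytes! resolves; the real project may compile shaders differently):

use std::process::Command;

fn main() {
    let shaders = [
        "init_out_data",
        "prepare_initialdata_searchdata",
        "prepare_initialdata_correlation",
        "prepare_searchdata",
        "cross_correlate",
        "cross_check_filter",
    ];
    for name in shaders {
        let src = format!("src/correlation/{name}.comp.glsl");
        let dst = format!("src/correlation/{name}.spv");
        println!("cargo:rerun-if-changed={src}");
        // -fshader-stage=compute matches the #pragma shader_stage(compute)
        // declared in each source file.
        let status = Command::new("glslc")
            .args(["-fshader-stage=compute", &src, "-o", &dst])
            .status()
            .expect("failed to run glslc; is the Vulkan SDK installed?");
        assert!(status.success(), "glslc failed for {src}");
    }
}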