Skip to content

Commit

Permalink
More cleanup, finish renaming
Browse files Browse the repository at this point in the history
  • Loading branch information
SallySoul committed Dec 19, 2024
1 parent 0ec4600 commit 1cef274
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 69 deletions.
2 changes: 1 addition & 1 deletion examples/gen_1d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() {
let mut img = nhls::image::Image1D::new(grid_bound, n_lines as u32);
img.add_line(0, input_domain.buffer());
for t in 1..n_lines as u32 {
periodic_naive::box_solve(
direct_periodic_apply(
&stencil,
&mut input_domain,
&mut output_domain,
Expand Down
2 changes: 1 addition & 1 deletion examples/gen_2d.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ fn main() {
// Make image
nhls::image::image2d(&input_domain, "gen_2d/frame_000.png");
for t in 1..n_images as u32 {
periodic_naive::box_solve(
direct_periodic_apply(
&stencil,
&mut input_domain,
&mut output_domain,
Expand Down
26 changes: 7 additions & 19 deletions examples/heat_1d_p_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@ use nhls::domain::*;
use nhls::solver::*;
use nhls::stencil::*;
use nhls::util::*;
use rayon::prelude::*;

use nalgebra::matrix;

fn main() {
const GRID_DIMENSION: usize = 1;

// Grid size
let grid_bound = AABB::new(matrix![0, 999]);

Expand Down Expand Up @@ -45,27 +42,18 @@ fn main() {
// Fill in with IC values (use normal dist for spike in the middle)
let n_f = buffer_size as f32;
let sigma_sq: f32 = (n_f / 25.0) * (n_f / 25.0);
input_domain.par_modify_access(100).for_each(
|mut d: DomainChunk<'_, GRID_DIMENSION>| {
d.coord_iter_mut().for_each(
|(world_coord, value_mut): (
Coord<GRID_DIMENSION>,
&mut f32,
)| {
let x = (world_coord[0] as f32) - (n_f / 2.0);
//let f = ( 1.0 / (2.0 * std::f32::consts::PI * sigma_sq)).sqrt();
let exp = -x * x / (2.0 * sigma_sq);
*value_mut = exp.exp()
},
)
},
);
let ic_gen = |world_coord: Coord<1>| {
let x = (world_coord[0] as f32) - (n_f / 2.0);
let exp = -x * x / (2.0 * sigma_sq);
exp.exp()
};
input_domain.par_set_values(ic_gen, chunk_size);

// Make image
let mut img = nhls::image::Image1D::new(grid_bound, n_lines as u32);
img.add_line(0, input_domain.buffer());
for t in 1..n_lines as u32 {
periodic_naive::box_solve(
direct_periodic_apply(
&stencil,
&mut input_domain,
&mut output_domain,
Expand Down
7 changes: 3 additions & 4 deletions examples/heat_2d_p_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,7 @@ fn main() {
image2d(&input_domain, "heat_2d_direct/frame_0000.png");

// Create boundary condition
let bc = ConstantCheck::new(
0.0,
grid_bound);
let bc = ConstantCheck::new(0.0, grid_bound);

// Apply direct solver
for t in 1..n_images {
Expand All @@ -77,7 +75,8 @@ fn main() {
&mut input_domain,
&mut output_domain,
steps_per_image,
chunk_size);
chunk_size,
);

std::mem::swap(&mut input_domain, &mut output_domain);
image2d(&input_domain, &format!("heat_2d_direct/frame_{:04}.png", t));
Expand Down
8 changes: 6 additions & 2 deletions src/solver/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
pub mod fft_plan;
pub mod direct;
pub mod periodic_naive;
pub mod fft_plan;
pub mod periodic_direct;
pub mod periodic_plan;
pub mod trapezoid;

pub use direct::*;
pub use periodic_direct::*;
pub use periodic_plan::*;
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::domain::*;
use crate::par_stencil;
use crate::stencil::*;

pub fn box_solve<
pub fn direct_periodic_apply<
'a,
Operation,
const GRID_DIMENSION: usize,
Expand Down Expand Up @@ -56,7 +56,7 @@ mod unit_tests {
let mut output_buffer = vec![2.0; n_r];
let mut input_domain = Domain::new(*bound, &mut input_buffer);
let mut output_domain = Domain::new(*bound, &mut output_buffer);
box_solve(
direct_periodic_apply(
stencil,
&mut input_domain,
&mut output_domain,
Expand Down Expand Up @@ -153,7 +153,7 @@ mod unit_tests {
Domain::new(bound, output_buffer.as_slice_mut());
let chunk_size = 1;
let n = 1;
box_solve(
direct_periodic_apply(
&stencil,
&mut input_domain,
&mut output_domain,
Expand Down
54 changes: 15 additions & 39 deletions tests/base_solver_compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,13 @@ use nhls::domain::*;
use nhls::solver::*;
use nhls::stencil::*;
use nhls::util::*;
use rayon::prelude::*;

use fftw::array::*;
use float_cmp::assert_approx_eq;
use nalgebra::matrix;

#[test]
fn thermal_1d_compare() {
const GRID_DIMENSION: usize = 1;

// Grid size
let grid_bound = AABB::new(matrix![0, 999]);

Expand All @@ -38,10 +35,10 @@ fn thermal_1d_compare() {
// Create domains
let buffer_size = grid_bound.buffer_size();
let mut grid_input = vec![0.0; buffer_size];
let mut naive_input_domain = Domain::new(grid_bound, &mut grid_input);
let mut direct_input_domain = Domain::new(grid_bound, &mut grid_input);

let mut grid_output = vec![0.0; buffer_size];
let mut naive_output_domain = Domain::new(grid_bound, &mut grid_output);
let mut direct_output_domain = Domain::new(grid_bound, &mut grid_output);

let mut fft_input = AlignedVec::new(buffer_size);
let mut fft_output = AlignedVec::new(buffer_size);
Expand All @@ -51,54 +48,33 @@ fn thermal_1d_compare() {
// Fill in with IC values (use normal dist for spike in the middle)
let n_f = buffer_size as f32;
let sigma_sq: f32 = (n_f / 25.0) * (n_f / 25.0);
naive_input_domain.par_modify_access(100).for_each(
|mut d: DomainChunk<'_, GRID_DIMENSION>| {
d.coord_iter_mut().for_each(
|(world_coord, value_mut): (
Coord<GRID_DIMENSION>,
&mut f32,
)| {
let x = (world_coord[0] as f32) - (n_f / 2.0);
//let f = ( 1.0 / (2.0 * std::f32::consts::PI * sigma_sq)).sqrt();
let exp = -x * x / (2.0 * sigma_sq);
*value_mut = exp.exp()
},
)
},
);
let ic_gen = |world_coord: Coord<1>| {
let x = (world_coord[0] as f32) - (n_f / 2.0);
let exp = -x * x / (2.0 * sigma_sq);
exp.exp()
};

fft_input_domain.par_modify_access(100).for_each(
|mut d: DomainChunk<'_, GRID_DIMENSION>| {
d.coord_iter_mut().for_each(
|(world_coord, value_mut): (
Coord<GRID_DIMENSION>,
&mut f32,
)| {
let x = (world_coord[0] as f32) - (n_f / 2.0);
//let f = ( 1.0 / (2.0 * std::f32::consts::PI * sigma_sq)).sqrt();
let exp = -x * x / (2.0 * sigma_sq);
*value_mut = exp.exp()
},
)
},
);
direct_input_domain.par_set_values(ic_gen, chunk_size);

fft_input_domain.par_set_values(ic_gen, chunk_size);

let mut periodic_library =
nhls::solver::periodic_plan::PeriodicPlanLibrary::new(
&grid_bound,
&stencil,
);

periodic_library.apply(
&mut fft_input_domain,
&mut fft_output_domain,
n_steps,
chunk_size,
);

periodic_naive::box_solve(
direct_periodic_apply(
&stencil,
&mut naive_input_domain,
&mut naive_output_domain,
&mut direct_input_domain,
&mut direct_output_domain,
n_steps,
chunk_size,
);
Expand Down Expand Up @@ -138,7 +114,7 @@ fn periodic_compare() {
let mut domain_a_output = Domain::new(bound, output_a.as_slice_mut());
let mut domain_b_output = Domain::new(bound, output_b.as_slice_mut());

periodic_naive::box_solve(
direct_periodic_apply(
&stencil,
&mut domain_a_input,
&mut domain_a_output,
Expand Down

0 comments on commit 1cef274

Please sign in to comment.