Skip to content

Commit

Permalink
Add support for time-varying boundary conditions (#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
SallySoul authored Feb 3, 2025
1 parent 59543fc commit 907ca91
Show file tree
Hide file tree
Showing 16 changed files with 260 additions and 49 deletions.
4 changes: 4 additions & 0 deletions examples/heat_1d_ap_direct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,19 @@ fn main() {
i.add_line(0, input_domain.buffer());
img = Some(i);
}

let mut global_time = 0;
for t in 1..args.lines as u32 {
box_apply(
&bc,
&stencil,
&mut input_domain,
&mut output_domain,
args.steps_per_line,
global_time,
args.chunk_size,
);
global_time += args.steps_per_line;
std::mem::swap(&mut input_domain, &mut output_domain);
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
Expand Down
7 changes: 5 additions & 2 deletions examples/heat_1d_ap_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@ fn main() {
let (args, output_image_path) = Args::cli_parse("heat_1d_ap_fft");

let stencil = nhls::standard_stencils::heat_1d(1.0, 1.0, 0.5);
let grid_bound = args.grid_bounds();

// Create domains
let grid_bound = args.grid_bounds();
let mut buffer_1 = OwnedDomain::new(grid_bound);
let mut buffer_2 = OwnedDomain::new(grid_bound);
let mut input_domain = buffer_1.as_slice_domain();
Expand Down Expand Up @@ -44,8 +44,11 @@ fn main() {
i.add_line(0, input_domain.buffer());
img = Some(i);
}

let mut global_time = 0;
for t in 1..args.lines as u32 {
solver.apply(&mut input_domain, &mut output_domain);
solver.apply(&mut input_domain, &mut output_domain, global_time);
global_time += args.steps_per_line;
std::mem::swap(&mut input_domain, &mut output_domain);
if let Some(i) = img.as_mut() {
i.add_line(t, input_domain.buffer());
Expand Down
11 changes: 3 additions & 8 deletions examples/heat_2d_ap_direct.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
use nhls::domain::*;
use nhls::image::*;
use nhls::image_2d_example::*;
use nhls::init;

fn main() {
let args = Args::cli_parse("heat_2d_ap_direct");
Expand All @@ -15,12 +14,6 @@ fn main() {
let mut input_domain = OwnedDomain::new(grid_bound);
let mut output_domain = OwnedDomain::new(grid_bound);

if args.rand_init {
init::rand(&mut input_domain, 1024, args.chunk_size);
} else {
init::normal_ic_2d(&mut input_domain, args.chunk_size);
}

if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}
Expand All @@ -29,16 +22,18 @@ fn main() {
let bc = ConstantCheck::new(0.0, grid_bound);

// Apply direct solver
let mut global_time = 0;
for t in 1..args.images {
nhls::solver::direct::box_apply(
&bc,
&stencil,
&mut input_domain,
&mut output_domain,
args.steps_per_image,
global_time,
args.chunk_size,
);

global_time += args.steps_per_image;
std::mem::swap(&mut input_domain, &mut output_domain);
if args.write_images {
image2d(&input_domain, &args.frame_name(t));
Expand Down
22 changes: 12 additions & 10 deletions examples/heat_2d_ap_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,7 @@ fn main() {
let args = Args::cli_parse("heat_2d_ap_fft");

let stencil = nhls::standard_stencils::heat_2d(1.0, 1.0, 1.0, 0.2, 0.2);

// Create domains
let grid_bound = args.grid_bounds();
let mut buffer_1 = OwnedDomain::new(grid_bound);
let mut buffer_2 = OwnedDomain::new(grid_bound);
let mut input_domain = buffer_1.as_slice_domain();
let mut output_domain = buffer_2.as_slice_domain();
if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}

// Create BC
let bc = ConstantCheck::new(1.0, grid_bound);
Expand Down Expand Up @@ -45,8 +36,19 @@ fn main() {
solver.scratch_descriptor_file(&d_path);
}

// Create domains
let mut buffer_1 = OwnedDomain::new(grid_bound);
let mut buffer_2 = OwnedDomain::new(grid_bound);
let mut input_domain = buffer_1.as_slice_domain();
let mut output_domain = buffer_2.as_slice_domain();
if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}

let mut global_time = 0;
for t in 1..args.images {
solver.apply(&mut input_domain, &mut output_domain);
solver.apply(&mut input_domain, &mut output_domain, global_time);
global_time += args.steps_per_image;
std::mem::swap(&mut input_domain, &mut output_domain);
if args.write_images {
image2d(&input_domain, &args.frame_name(t));
Expand Down
4 changes: 3 additions & 1 deletion examples/heat_3d_ap_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@ fn main() {
write_vtk3d(&input_domain, &args.frame_name(0));
}

let mut global_time = 0;
for t in 1..args.images {
solver.apply(&mut input_domain, &mut output_domain);
solver.apply(&mut input_domain, &mut output_domain, global_time);
global_time += args.steps_per_image;
std::mem::swap(&mut input_domain, &mut output_domain);
if args.write_images {
write_vtk3d(&input_domain, &args.frame_name(t));
Expand Down
117 changes: 117 additions & 0 deletions examples/time_varying_2d.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use nhls::domain::*;
use nhls::fft_solver::*;
use nhls::image::*;
use nhls::image_2d_example::*;
use nhls::stencil::*;
use nhls::util::*;

pub struct PulseBC<const GRID_DIMENSION: usize> {
rate: f64,
n_f: f64,
sigma_sq: f64,
aabb: AABB<GRID_DIMENSION>,
}

impl<const GRID_DIMENSION: usize> PulseBC<GRID_DIMENSION> {
pub fn new(rate: f64, aabb: AABB<GRID_DIMENSION>) -> Self {
let n_f = aabb.exclusive_bounds()[1] as f64;
let sigma_sq: f64 = (n_f / 10.0) * (n_f / 10.0);
PulseBC {
rate,
aabb,
n_f,
sigma_sq,
}
}
}

impl<const GRID_DIMENSION: usize> BCCheck<GRID_DIMENSION>
for PulseBC<GRID_DIMENSION>
{
fn check(
&self,
coord: &Coord<GRID_DIMENSION>,
global_time: usize,
) -> Option<f64> {
if self.aabb.contains(coord) {
None
} else if coord[0] < self.aabb.bounds[(0, 0)] {
let x = (coord[1] as f64) - (self.n_f / 2.0);
let exp = -x * x / (2.0 * self.sigma_sq);
let normal_value = exp.exp();
let sin_mod = 0.5 * ((global_time as f64 * self.rate).sin() + 1.0);
Some(sin_mod * normal_value)
} else {
Some(0.0)
}
}
}

fn main() {
let args = Args::cli_parse("time_varying_2d");

let stencil = Stencil::new(
[[0, 0], [-1, 0], [1, 0], [0, -1], [0, 1]],
move |args: &[f64; 5]| {
let middle = args[0];
let left = args[1];
let right = args[2];
let bottom = args[3];
let top = args[4];
0.1 * middle + 0.4 * left + 0.1 * right + 0.2 * top + 0.2 * bottom
},
);

let w = stencil.extract_weights();
let s: f64 = w.iter().sum();
println!("{:?}, {}", w, s);
let grid_bound = args.grid_bounds();

// Create BC
let bc = PulseBC::new((2.0 * std::f64::consts::PI) / 1000.0, grid_bound);

// Create AP Solver
let planner_params = PlannerParameters {
plan_type: args.plan_type,
cutoff: 40,
ratio: 0.5,
chunk_size: args.chunk_size,
};
let solver = APSolver::new(
&bc,
&stencil,
grid_bound,
args.steps_per_image,
&planner_params,
);
if args.write_dot {
let mut dot_path = args.output_dir.clone();
dot_path.push("plan.dot");
solver.to_dot_file(&dot_path);

let mut d_path = args.output_dir.clone();
d_path.push("scratch.txt");
solver.scratch_descriptor_file(&d_path);
}

// Create domains
let mut buffer_1 = OwnedDomain::new(grid_bound);
let mut buffer_2 = OwnedDomain::new(grid_bound);
let mut input_domain = buffer_1.as_slice_domain();
let mut output_domain = buffer_2.as_slice_domain();
if args.write_images {
image2d(&input_domain, &args.frame_name(0));
}

let mut global_time = 0;
for t in 1..args.images {
solver.apply(&mut input_domain, &mut output_domain, global_time);
global_time += args.steps_per_image;
std::mem::swap(&mut input_domain, &mut output_domain);
if args.write_images {
image2d(&input_domain, &args.frame_name(t));
}
}

args.save_wisdom();
}
12 changes: 8 additions & 4 deletions src/domain/bc/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ impl<const GRID_DIMENSION: usize> ConstantCheck<GRID_DIMENSION> {
impl<const GRID_DIMENSION: usize> BCCheck<GRID_DIMENSION>
for ConstantCheck<GRID_DIMENSION>
{
fn check(&self, coord: &Coord<GRID_DIMENSION>) -> Option<f64> {
fn check(
&self,
coord: &Coord<GRID_DIMENSION>,
_global_time: usize,
) -> Option<f64> {
if self.bound.contains(coord) {
return None;
}
Expand All @@ -39,18 +43,18 @@ mod unit_tests {
}
let bc = ConstantCheck::new(-1.0, bound);
for i in 0..n_r {
let v = bc.check(&vector![i as i32]);
let v = bc.check(&vector![i as i32], 0);
assert_eq!(v, None);
}

{
let v = bc.check(&vector![-1]);
let v = bc.check(&vector![-1], 1);
assert!(v.is_some());
assert_approx_eq!(f64, v.unwrap(), -1.0);
}

{
let v = bc.check(&vector![11]);
let v = bc.check(&vector![11], 2);
assert!(v.is_some());
assert_approx_eq!(f64, v.unwrap(), -1.0);
}
Expand Down
6 changes: 5 additions & 1 deletion src/domain/bc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,9 @@ pub use periodic::*;
use crate::util::*;

pub trait BCCheck<const GRID_DIMENSION: usize>: Sync {
fn check(&self, world_coord: &Coord<GRID_DIMENSION>) -> Option<f64>;
fn check(
&self,
world_coord: &Coord<GRID_DIMENSION>,
global_time: usize,
) -> Option<f64>;
}
12 changes: 8 additions & 4 deletions src/domain/bc/periodic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,11 @@ impl<
impl<const GRID_DIMENSION: usize, DomainType: DomainView<GRID_DIMENSION>>
BCCheck<GRID_DIMENSION> for PeriodicCheck<'_, GRID_DIMENSION, DomainType>
{
fn check(&self, world_coord: &Coord<GRID_DIMENSION>) -> Option<f64> {
fn check(
&self,
world_coord: &Coord<GRID_DIMENSION>,
_global_time: usize,
) -> Option<f64> {
let p_coord = &self.domain.aabb().periodic_coord(world_coord);
if p_coord != world_coord {
return Some(self.domain.view(p_coord));
Expand All @@ -47,18 +51,18 @@ mod unit_tests {
domain.par_set_values(|coord| coord[0] as f64, 1);
let bc = PeriodicCheck::new(&domain);
for (i, _) in domain.buffer().iter().enumerate() {
let v = bc.check(&vector![i as i32]);
let v = bc.check(&vector![i as i32], 0);
assert_eq!(v, None);
}

{
let v = bc.check(&vector![-1]);
let v = bc.check(&vector![-1], 1);
assert!(v.is_some());
assert_approx_eq!(f64, v.unwrap(), 10.0);
}

{
let v = bc.check(&vector![11]);
let v = bc.check(&vector![11], 2);
assert!(v.is_some());
assert_approx_eq!(f64, v.unwrap(), 0.0);
}
Expand Down
7 changes: 4 additions & 3 deletions src/domain/gather_args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ pub fn gather_args<
bc: &BC,
input: &DomainType,
world_coord: &Coord<GRID_DIMENSION>,
global_time: usize,
) -> [f64; NEIGHBORHOOD_SIZE]
where
Operation: StencilOperation<f64, NEIGHBORHOOD_SIZE>,
Expand All @@ -22,7 +23,7 @@ where
for (i, n_i) in stencil.offsets().iter().enumerate() {
let n_world_coord = world_coord + n_i;
result[i] = bc
.check(&n_world_coord)
.check(&n_world_coord, global_time)
.unwrap_or_else(|| input.view(&n_world_coord));
}
result
Expand All @@ -47,7 +48,7 @@ mod unit_tests {
[[0, -1], [0, 1], [1, 0], [-1, 0], [0, 0]],
|_: &[f64; 5]| -1.0,
);
let r = gather_args(&stencil, &bc, &domain, &vector![9, 9]);
let r = gather_args(&stencil, &bc, &domain, &vector![9, 9], 0);
let e = [
(9 + 3 * 8) as f64,
-4.0,
Expand Down Expand Up @@ -76,7 +77,7 @@ mod unit_tests {
[[0, -1], [0, 1], [1, 0], [-1, 0], [0, 0]],
|_: &[f64; 5]| -1.0,
);
let r = gather_args(&stencil, &bc, &domain, &vector![9, 9]);
let r = gather_args(&stencil, &bc, &domain, &vector![9, 9], 11);
let e = [
(9 + 3 * 8) as f64,
9.0,
Expand Down
Loading

0 comments on commit 907ca91

Please sign in to comment.