Skip to content

Commit

Permalink
Merge pull request #164 from JuliaGeodynamics/adm/MKC
Browse files Browse the repository at this point in the history
Improve and fix marker chain on GPUs
  • Loading branch information
albert-de-montserrat authored Oct 28, 2024
2 parents 7c1a609 + c34a837 commit 7fc97ae
Show file tree
Hide file tree
Showing 12 changed files with 172 additions and 108 deletions.
14 changes: 14 additions & 0 deletions ext/JustPICAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,20 @@ module _2D

## MakerChain

function JustPIC._2D.init_markerchain(::Type{AMDGPUBackend}, nxcell, min_xcell, max_xcell, xv, initial_elevation)
nx = length(xv) - 1
dx = xv[2] - xv[1]
dx_chain = dx / (nxcell + 1)
px, py = ntuple(_ -> @fill(NaN, (nx,), celldims = (max_xcell,)), Val(2))
index = @fill(false, (nx,), celldims = (max_xcell,), eltype = Bool)

@parallel (1:nx) fill_markerchain_coords_index!(
px, py, index, xv, initial_elevation, dx_chain, nxcell, max_xcell
)

return MarkerChain(AMDGPUBackend, (px, py), index, xv, min_xcell, max_xcell)
end

function JustPIC._2D.advect_markerchain!(
chain::MarkerChain{AMDGPUBackend},
method::AbstractAdvectionIntegrator,
Expand Down
14 changes: 14 additions & 0 deletions ext/JustPICCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,20 @@ module _2D

## MakerChain

function JustPIC._2D.init_markerchain(::Type{CUDABackend}, nxcell, min_xcell, max_xcell, xv, initial_elevation)
nx = length(xv) - 1
dx = xv[2] - xv[1]
dx_chain = dx / (nxcell + 1)
px, py = ntuple(_ -> @fill(NaN, (nx,), celldims = (max_xcell,)), Val(2))
index = @fill(false, (nx,), celldims = (max_xcell,), eltype = Bool)

@parallel (1:nx) fill_markerchain_coords_index!(
px, py, index, xv, initial_elevation, dx_chain, nxcell, max_xcell
)

return MarkerChain(CUDABackend, (px, py), index, xv, min_xcell, max_xcell)
end

function JustPIC._2D.advect_markerchain!(
chain::MarkerChain{CUDABackend},
method::AbstractAdvectionIntegrator,
Expand Down
17 changes: 9 additions & 8 deletions src/MarkerChain/Advection/advection.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

function advect_markerchain!(
chain::MarkerChain, method::AbstractAdvectionIntegrator, V, grid_vxi, dt
)
Expand All @@ -9,16 +10,16 @@ end

# Two-step Runge-Kutta advection scheme for marker chains
function advection!(
particles::MarkerChain,
chain::MarkerChain,
method::AbstractAdvectionIntegrator,
V,
grid_vi::NTuple{N,NTuple{N,T}},
dt,
) where {N,T}
(; coords, index) = particles
(; coords, index) = chain

# compute some basic stuff
ni = size(index)
ni = size(index, 1)
dxi = compute_dx(first(grid_vi))
# Need to transpose grid_vy and Vy to reuse interpolation kernels

Expand All @@ -34,7 +35,7 @@ end
# DIMENSION AGNOSTIC KERNELS

# ParallelStencil function Runge-Kuttaadvection function for 3D staggered grids
@parallel_indices (I...) function advection_markerchain_kernel!(
@parallel_indices (i) function advection_markerchain_kernel!(
p,
method::AbstractAdvectionIntegrator,
V::NTuple{N,T},
Expand All @@ -45,17 +46,17 @@ end
dt,
) where {N,T}
for ipart in cellaxes(index)
doskip(index, ipart, I...) && continue
doskip(index, ipart, i) && continue

# skip if particle does not exist in this memory location
doskip(index, ipart, I...) && continue
doskip(index, ipart, i) && continue
# extract particle coordinates
pᵢ = get_particle_coords(p, ipart, I...)
pᵢ = get_particle_coords(p, ipart, i)
# advect particle
pᵢ_new = advect_particle_markerchain(method, pᵢ, V, grid, local_limits, dxi, dt)
# update particle coordinates
for k in 1:N
@inbounds @index p[k][ipart, I...] = pᵢ_new[k]
@inbounds @index p[k][ipart, i] = pᵢ_new[k]
end
end

Expand Down
49 changes: 32 additions & 17 deletions src/MarkerChain/init.jl
Original file line number Diff line number Diff line change
@@ -1,27 +1,42 @@
function init_markerchain(backend, nxcell, min_xcell, max_xcell, xv, initial_elevation)
@parallel_indices (i) function fill_coords_index!(
px, py, index, x, initial_elevation, dx_chain, nxcell, max_xcell
)
# lower-left corner of the cell
x0 = x[i]
# fill index array
for ip in 1:nxcell
@index px[ip, i] = x0 + dx_chain * ip
@index py[ip, i] = initial_elevation
@index index[ip, i] = true
end
return nothing
end

function init_markerchain(::Type{JustPIC.CPUBackend}, nxcell, min_xcell, max_xcell, xv, initial_elevation)
nx = length(xv) - 1
dx = xv[2] - xv[1]
dx_chain = dx / (nxcell + 1)
px, py = ntuple(_ -> @fill(NaN, (nx,), celldims = (max_xcell,)), Val(2))
index = @fill(false, (nx,), celldims = (max_xcell,), eltype = Bool)

@parallel (1:nx) fill_coords_index!(
@parallel (1:nx) fill_markerchain_coords_index!(
px, py, index, xv, initial_elevation, dx_chain, nxcell, max_xcell
)

return MarkerChain(backend, (px, py), index, xv, min_xcell, max_xcell)
return MarkerChain(JustPIC.CPUBackend, (px, py), index, xv, min_xcell, max_xcell)
end

@parallel_indices (i) function fill_markerchain_coords_index!(
px, py, index, x, initial_elevation, dx_chain, nxcell, max_xcell
)
# lower-left corner of the cell
x0 = x[i]
# fill index array
for ip in 1:nxcell
@index px[ip, i] = x0 + dx_chain * ip
@index py[ip, i] = initial_elevation
@index index[ip, i] = true
end
return nothing
end

@parallel_indices (i) function fill_markerchain_coords_index!(
px, py, index, x, initial_elevation::AbstractArray{T, 1}, dx_chain, nxcell, max_xcell
) where {T}
# lower-left corner of the cell
x0 = x[i]
initial_elevation0 = initial_elevation[i]
# fill index array
for ip in 1:nxcell
@index px[ip, i] = x0 + dx_chain * ip
@index py[ip, i] = initial_elevation0
@index index[ip, i] = true
end
return nothing
end
65 changes: 43 additions & 22 deletions src/MarkerChain/interp1.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
@inline _interp1D(xq, x0, x1, y0, y1) = fma((xq - x0), (y1 - y0) * inv(x1 - x0), y0)
@inline _interp1D(xq, x0, x1, y0, y1) = muladd((xq - x0), (y1 - y0) * inv(x1 - x0), y0)

function interp1D_extremas(xq, x, y)
last_I = findlast(!isnan, x)
last_I = 1
for i in length(x):-1:2
if !isnan(x[i])
last_I = i
break
end
end
x_lo, x_hi = x[1], x[last_I]
@inbounds for j in eachindex(x)[1:(end - 1)]
x0, x1 = x[j], x[j + 1]
Expand All @@ -25,46 +31,61 @@ function interp1D_extremas(xq, x, y)
return _interp1D(xq, x0, x1, y0, y1)
end
end
return error("xq outside domain")
# return error("xq outside domain")
return NaN
end

function interp1D_inner(xq, x, y, cell_coords, I::Integer)
last_I = findlast(!isnan, x)
function interp1D_inner(xq, x, y, coords, I::Integer)
last_I = 1
for i in length(x):-1:2
if !isnan(x[i])
last_I = i
break
end
end
x_lo, x_hi = x[1], x[last_I]
@inbounds for j in 1:last_I
@inbounds for j in 1:last_I-1
x0, x1 = x[j], x[j + 1]

# interpolation
if x0 xq x1
y0, y1 = y[j], y[j + 1]
return _interp1D(xq, x0, x1, y0, y1)
end

# interpolate using the last particle of left-neighbouring cell
if xq x_lo
x0, y0 = left_cell_right_particle(cell_coords, I)
x0, y0 = left_cell_right_particle(coords, I)
x1, y1 = x[1], y[1]
return _interp1D(xq, x0, x1, y0, y1)
end

# interpolate using the first particle of right-neighbouring cell
if xq x_hi
x0, y0 = x[last_I], y[last_I]
x1, y1 = right_cell_left_particle(cell_coords, I)
x1, y1 = right_cell_left_particle(coords, I)
return _interp1D(xq, x0, x1, y0, y1)
end

# interpolation
if x0 xq x1
y0, y1 = y[j], y[j + 1]
return _interp1D(xq, x0, x1, y0, y1)
end
end
@show x_lo, x_hi, xq, I
return error("xq outside domain")
# return error("xq outside domain")
return NaN
end

@inline right_cell_left_particle(cell_coords, I::Int) =
@index(cell_coords[1][1, I + 1]), @index(cell_coords[2][1, I + 1])
@inline right_cell_left_particle(coords, I::Int) =
@index(coords[1][1, I + 1]), @index(coords[2][1, I + 1])

@inline function left_cell_right_particle(coords, I)
px = coords[1]
# px = @cell coords[1][I - 1]
ip = 1
for i in cellnum(px):-1:2
if !isnan(@index px[i, I-1])
ip = i
break
end
end

@inline function left_cell_right_particle(cell_coords, I)
px = cell_coords[1][I - 1]
ip = findlast(!isnan, px)
return px[ip], @index(cell_coords[2][ip, I - 1])
return @index(px[ip, I-1]), @index(coords[2][ip, I - 1])
end

@inline function is_above_surface(xq, yq, coords, cell_vertices)
Expand Down
20 changes: 10 additions & 10 deletions src/MarkerChain/move.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
function move_particles!(chain::MarkerChain)
(; coords, index, cell_vertices) = chain
dxi = compute_dx(cell_vertices)
nxi = size(index)
nxi = size(index, 1)
grid = cell_vertices

@parallel (@idx nxi) move_particles_launcher!(coords, grid, dxi, index)

return nothing
end

@parallel_indices (I...) function move_particles_launcher!(coords, grid, dxi, index)
_move_particles!(coords, grid, dxi, index, I)
@parallel_indices (i) function move_particles_launcher!(coords, grid, dxi, index)
_move_particles!(coords, grid, dxi, index, i)
return nothing
end

Expand All @@ -19,11 +19,11 @@ chop(I::NTuple{3,T}) where {T} = I[1], I[2]

function _move_particles!(coords, grid, dxi, index, idx)
# coordinate of the lower-most-left coordinate of the parent cell
corner_xi = corner_coordinate(grid, chop(idx))
corner_xi = corner_coordinate(grid, idx)

# iterate over particles in child cell
for ip in cellaxes(index)
doskip(index, ip, idx...) && continue
doskip(index, ip, idx) && continue
pᵢ = cache_particle(coords, ip, idx)

# check whether the particle is
Expand All @@ -36,20 +36,20 @@ function _move_particles!(coords, grid, dxi, index, idx)
if !(any(<(1), new_cell) || any(new_cell .> length(grid)))
## THE PARTICLE DID NOT ESCAPE THE DOMAIN
# remove particle from child cell
@inbounds @index index[ip, chop(idx)] = false
@inbounds @index coords[1][ip, chop(idx)] = NaN
@inbounds @index coords[2][ip, chop(idx)] = NaN
@inbounds @index index[ip, idx] = false
@inbounds @index coords[1][ip, idx] = NaN
@inbounds @index coords[2][ip, idx] = NaN
# check whether there's empty space in parent cell
free_idx = find_free_memory(index, new_cell...)
free_idx == 0 && continue
iszero(free_idx) && continue
# move particle and its fields to the first free memory location
@inbounds @index index[free_idx, new_cell] = true
fill_particle!(coords, pᵢ, free_idx, new_cell)

else
## SOMEHOW THE PARTICLE DID ESCAPE THE DOMAIN
## => REMOVE IT
@inbounds @index index[ip, idx...] = false
@inbounds @index index[ip, idx] = false
empty_particle!(coords, ip, idx)
end
end
Expand Down
32 changes: 24 additions & 8 deletions src/MarkerChain/resample.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ function resample!(chain::MarkerChain)
nx = length(index)
dx_cells = cell_length(chain)

# sort marker chain - can't be done at the cell level because
# SA can't be sorted inside a GPU kernel
sort_chain!(chain)

# call kernel
@parallel (1:nx) resample!(coords, cell_vertices, index, min_xcell, max_xcell, dx_cells)
return nothing
Expand All @@ -22,13 +26,15 @@ function resample_cell!(
) where {T}

# cell particles coordinates
x_cell, y_cell = coords[1][I], coords[2][I]
index_I = @cell index[I]
px, py = coords[1], coords[2]
x_cell = @cell px[I]
y_cell = @cell py[I]

# lower-left corner of the cell
cell_vertex = cell_vertices[I]
# number of particles in the cell
np = count(index[I])
# number of p`articles in the cell
np = count(index_I)
# dx of the new chain
dx_chain = dx_cells / (np + 1)
# resample the cell if the number of particles is
Expand All @@ -38,11 +44,10 @@ function resample_cell!(
np_new = max(min_xcell, np)
dx_chain = dx_cells / (np_new + 1)
if do_resampling
# @show I
# fill index array
for ip in 1:np_new
# x query point
@index px[ip, I] = xq = cell_vertex + dx_chain * ip
xq = cell_vertex + dx_chain * ip
# interpolated y coordinated
yq = if 1 < I < length(index)
# inner cells; this is true (ncells-2) consecutive times
Expand All @@ -51,9 +56,7 @@ function resample_cell!(
# first and last cells
interp1D_extremas(xq, x_cell, y_cell)
end
if isnan(yq)
@show I, y_cell
end
@index px[ip, I] = xq
@index py[ip, I] = yq
@index index[ip, I] = true
end
Expand Down Expand Up @@ -92,3 +95,16 @@ function isdistorded(x_cell, dx_ideal)
end
return false
end

# sort marker chain cells
function sort_chain!(chain::MarkerChain{T}) where T

(; coords, index) = chain
# sort permutations of each cell
perms = sortperm(coords[1].data; dims=2)
coords[1].data .= @views coords[1].data[perms]
coords[2].data .= @views coords[2].data[perms]
index.data .= @views index.data[perms]

return nothing
end
Loading

0 comments on commit 7fc97ae

Please sign in to comment.