pass KernelOps into SharedMemoryCache constructor #1531

Open

wants to merge 10 commits into develop
11 changes: 8 additions & 3 deletions include/dslash_helper.cuh
@@ -11,6 +11,7 @@
#include <shmem_pack_helper.cuh>
#include <kernel_helper.h>
#include <tune_quda.h>
#include <kernel_ops.h>

constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE;

@@ -660,17 +661,21 @@ namespace quda
are reserved for data packing, which may include communication to
neighboring processes.
*/
template <typename Arg> struct dslash_functor {
template <typename Arg> struct dslash_functor : getKernelOps<typename Arg::D> {
const typename Arg::Arg &arg;
static constexpr int nParity = Arg::nParity;
static constexpr bool dagger = Arg::dagger;
static constexpr KernelType kernel_type = Arg::kernel_type;
static constexpr const char *filename() { return Arg::D::filename(); }
constexpr dslash_functor(const Arg &arg) : arg(arg.arg) { }
using typename getKernelOps<typename Arg::D>::KernelOpsT;
template <typename... OpsArgs>
constexpr dslash_functor(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg.arg)
{
}

__forceinline__ __device__ void operator()(int, int s, int parity)
{
typename Arg::D dslash(arg);
typename Arg::D dslash(*this);
// for full fields set parity from z thread index else use arg setting
if (nParity == 1) parity = arg.parity;

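The hunk above captures the pattern applied throughout this PR: a functor derives from its KernelOps type, forwards any extra constructor arguments to the KernelOpsT base, and then passes *this to anything that needs the shared-memory state. A minimal sketch of the pattern, assuming a hypothetical MyFunctor and Arg type (only KernelOps, SharedMemoryCache, KernelOpsT, KERNEL_FILE and target::thread_idx come from the code in this diff):

template <typename Arg> struct MyFunctor : KernelOps<SharedMemoryCache<typename Arg::real>> {
  using Ops = KernelOps<SharedMemoryCache<typename Arg::real>>;
  using typename Ops::KernelOpsT;
  const Arg &arg;
  static constexpr const char *filename() { return KERNEL_FILE; }

  // any ops arguments supplied at construction are forwarded to the KernelOps base
  template <typename... OpsArgs>
  constexpr MyFunctor(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { }

  __device__ void operator()(int, int, int)
  {
    // the cache is constructed from the functor rather than default-constructed,
    // so it picks up the shared-memory state carried by *this
    SharedMemoryCache<typename Arg::real> cache(*this);
    auto elems = cache.data();
    elems[target::thread_idx().x] = 0;
  }
};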
84 changes: 54 additions & 30 deletions include/gauge_fix_ovr_hit_devf.cuh
@@ -40,20 +40,29 @@ namespace quda {
}
}

template <int N> struct GaugeFixHitDims {
static constexpr dim3 dims(dim3 block)
{
block.y = N;
return block;
}
};
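GaugeFixHitDims<N> is the dims policy handed to SharedMemoryCache below: its dims() hook only overrides the y extent of the block used to size the cache, so every x thread (one per lattice site) gets N Float slots. A small illustrative snippet (the exact sizing rule of SharedMemoryCache is an assumption inferred from how the policy is used in this file):

dim3 launch_block(128, 1, 1);
dim3 cache_block = GaugeFixHitDims<4>::dims(launch_block); // -> (128, 4, 1)
// a SharedMemoryCache<Float, GaugeFixHitDims<4>> is then sized for 128 * 4 Float
// elements, one slot per site thread for each of the four real SU(2) parameters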

/**
* Device function to perform gauge fixing with overrelaxation.
* Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
* This implementation needs 8x more shared memory than the implementation using atomicadd
* Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
template <typename Float> using GaugeFixHit_AtomicAddOps = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>, nColor> &link, const Float relax_boost, int mu,
const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
Float *elems = cache.data();

//initialize shared memory
if (mu < 4) elems[mu * blockSize + tid] = 0.0;
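These hunks only touch the device helpers; the gauge-fixing kernel functors that call them (defined elsewhere, not shown in this diff) are expected to carry the matching ops type and pass themselves as the new ftor argument. A hedged sketch of that caller side, using a hypothetical GaugeFixSketch functor and hypothetical arg fields:

template <typename Float, int gauge_dir, int nColor, typename Arg>
struct GaugeFixSketch : GaugeFixHit_AtomicAddOps<Float> {
  using typename GaugeFixHit_AtomicAddOps<Float>::KernelOpsT;
  const Arg &arg;

  template <typename... OpsArgs>
  constexpr GaugeFixSketch(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) { }

  __device__ void operator()(int x_cb, int mu)
  {
    Matrix<complex<Float>, nColor> link; // loaded from the gauge field in the real kernel
    // pass the functor through so the helper can build its SharedMemoryCache
    // from the ops state carried by *this
    GaugeFixHit_AtomicAdd<Float, gauge_dir, nColor>(link, arg.relax_boost, mu, *this);
  }
};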
@@ -138,17 +147,20 @@ namespace quda {

/**
* Device function to perform gauge fixing with overrelaxation.
* Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
* Uses 4*8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
* This implementation needs 8x more shared memory than the implementation using atomicadd
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
template <typename Float> using GaugeFixHit_NoAtomicAddOps = KernelOps<SharedMemoryCache<array<Float, 4>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>, nColor> &link, const Float relax_boost, int mu,
const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<array<Float, 4>> cache(ftor);
Float *elems = &(*cache.data())[0];
Member comment (on the line above):
I am not liking this change. Why is this pathology necessary?

Contributor (author) reply:
Yes, there was no need to use an inner array. I've made it similar to the others now. (commit 99632a5)
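Per the reply above, the inner array was dropped in 99632a5 to match the other helpers in this file; the flat form would look roughly like the following (a sketch of the post-fix state, not part of the hunk shown here, and the dims argument is an assumption):

// sketch: flat Float cache shaped by a dims policy, as in the other helpers
SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor); // dims value assumed
Float *elems = cache.data();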
//Loop over all SU(2) subgroups of SU(N)
//#pragma unroll
@@ -228,15 +240,18 @@ namespace quda {
* Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
* This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
template <typename Float>
using GaugeFixHit_NoAtomicAdd_LessSMOps = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>, nColor> &link, const Float relax_boost,
int mu, const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
Float *elems = cache.data();

//Loop over all SU(2) subgroups of SU(N)
//#pragma unroll
@@ -323,18 +338,20 @@ namespace quda {
/**
* Device function to perform gauge fixing with overrelaxation.
* Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
* This implementation needs 8x more shared memory than the implementation using atomicadd
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1,
const Float relax_boost, int mu)
template <typename Float> using GaugeFixHit_AtomicAdd2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>, nColor> &link,
Matrix<complex<Float>, nColor> &link1, const Float relax_boost, int mu,
const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
Float *elems = cache.data();

//initialize shared memory
if (mu < 4) elems[mu * blockSize + tid] = 0.0;
@@ -408,16 +425,19 @@ namespace quda {
* Device function to perform gauge fixing with overrelaxation.
* Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1,
const Float relax_boost, int mu)
template <typename Float>
using GaugeFixHit_NoAtomicAdd2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<16>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>, nColor> &link,
Matrix<complex<Float>, nColor> &link1, const Float relax_boost, int mu,
const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<Float, GaugeFixHitDims<16>> cache(ftor);
Float *elems = cache.data();

//Loop over all SU(2) subgroups of SU(N)
//#pragma unroll
@@ -485,15 +505,19 @@ namespace quda {
* Uses 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
* This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization
*/
template <typename Float, int gauge_dir, int nColor>
inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1, const Float relax_boost, int mu)
template <typename Float>
using GaugeFixHit_NoAtomicAdd_LessSM2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
template <typename Float, int gauge_dir, int nColor, typename Ftor>
inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>, nColor> &link,
Matrix<complex<Float>, nColor> &link1, const Float relax_boost,
int mu, const Ftor &ftor)
{
auto blockSize = target::block_dim().x;
auto tid = target::thread_idx().x;

//Container for the four real parameters of SU(2) subgroup in shared memory
SharedMemoryCache<Float> cache;
auto elems = cache.data();
SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
Float *elems = cache.data();

//Loop over all SU(2) subgroups of SU(N)
//#pragma unroll
21 changes: 15 additions & 6 deletions include/kernels/block_transpose.cuh
@@ -42,11 +42,7 @@ namespace quda
}
};

template <typename Arg> struct BlockTransposeKernel {
const Arg &arg;
constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { }
static constexpr const char *filename() { return KERNEL_FILE; }

template <typename Arg> struct BlockTransposeKernelOps {
struct CacheDims {
static constexpr dim3 dims(dim3 block)
{
@@ -55,6 +51,19 @@
return block;
}
};
using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;
using CacheT = SharedMemoryCache<color_spinor_t, CacheDims>;
using Ops = KernelOps<CacheT>;
};

template <typename Arg> struct BlockTransposeKernel : BlockTransposeKernelOps<Arg>::Ops {
const Arg &arg;
using typename BlockTransposeKernelOps<Arg>::Ops::KernelOpsT;
template <typename... OpsArgs>
constexpr BlockTransposeKernel(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
{
}
static constexpr const char *filename() { return KERNEL_FILE; }

/**
@brief Transpose between the two different orders of batched colorspinor fields:
@@ -69,7 +78,7 @@
int parity = parity_color / Arg::nColor;
using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;

SharedMemoryCache<color_spinor_t, CacheDims> cache;
typename BlockTransposeKernelOps<Arg>::CacheT cache {*this};

int x_offset = target::block_dim().x * target::block_idx().x;
int v_offset = target::block_dim().y * target::block_idx().y;
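Factoring the cache into BlockTransposeKernelOps<Arg> also lets code outside the functor name the same type; a hedged sketch of a device helper written against any functor that carries these ops, mirroring the ftor parameter added to the gauge-fixing helpers (stage_site and its indexing are illustrative, not part of this PR):

template <typename Arg, typename Ftor>
inline __device__ void stage_site(const typename BlockTransposeKernelOps<Arg>::color_spinor_t &v, const Ftor &ftor)
{
  // build the block-local cache from the functor's ops state
  typename BlockTransposeKernelOps<Arg>::CacheT cache(ftor);
  cache.data()[target::thread_idx().x] = v; // illustrative indexing into the staging area
}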