Skip to content

Commit

Permalink
Merge branch 'develop' into feature/sycl
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Jan 23, 2025
2 parents fcd7bdb + 18bf43e commit 599b4ae
Show file tree
Hide file tree
Showing 77 changed files with 4,291 additions and 1,318 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/cuda_githubactions_build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ jobs:
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.0-1_all.deb
sudo dpkg -i cuda-keyring_1.0-1_all.deb
sudo apt-get update -y
sudo apt-get install -y --no-install-recommends ninja-build cmake libopenmpi-dev gfortran
sudo apt-get install -y --no-install-recommends ninja-build cmake libopenmpi-dev gfortran clang-14
- uses: awalsh128/cache-apt-pkgs-action@latest
with:
packages: cuda-compiler-12-1 cuda-libraries-dev-12-1 cuda-nvml-dev-12-1
execute_install_scripts: true

- uses: actions/checkout@v3
- uses: actions/checkout@v4

- name: Ccache for gh actions
uses: hendrikmuhs/ccache-action@v1.2.9
uses: hendrikmuhs/ccache-action@v1.2.16
with:
key: ${{ github.job }}-${{ matrix.compiler }}
max-size: 2000M
Expand Down
5 changes: 5 additions & 0 deletions include/color_spinor_field.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@ namespace quda
int nColor = 0; // Number of colors of the field
int nSpin = 0; // =1 for staggered, =2 for coarse Dslash, =4 for 4d spinor
int nVec = 1; // number of packed vectors (for multigrid transfer operator)
int nVec_actual = 1; // The actual number of packed vectors (that are not zero padded)

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID; // used by twisted mass
QudaSiteOrder siteOrder = QUDA_INVALID_SITE_ORDER; // defined for full fields
Expand Down Expand Up @@ -241,6 +242,7 @@ namespace quda
nColor(cpuParam.nColor),
nSpin(cpuParam.nSpin),
nVec(cpuParam.nVec),
nVec_actual(cpuParam.nVec_actual),
twistFlavor(cpuParam.twistFlavor),
siteOrder(QUDA_EVEN_ODD_SITE_ORDER),
fieldOrder(QUDA_INVALID_FIELD_ORDER),
Expand Down Expand Up @@ -324,6 +326,7 @@ namespace quda
int nColor = 0;
int nSpin = 0;
int nVec = 0;
mutable int nVec_actual = 0;

QudaTwistFlavorType twistFlavor = QUDA_TWIST_INVALID;

Expand Down Expand Up @@ -461,6 +464,8 @@ namespace quda
int Ncolor() const { return nColor; }
int Nspin() const { return nSpin; }
int Nvec() const { return nVec; }
/** @brief Return the currently stored nVec_actual value */
int Nvec_actual() const { return nVec_actual; }
/**
   @brief Overwrite the stored nVec_actual value
   @param nVec_actual The new value (shadows the member, hence the explicit this->)
   @note This setter is const-qualified, so the assignment only compiles if the member
   nVec_actual is declared mutable
*/
void Nvec_actual(int nVec_actual) const { this->nVec_actual = nVec_actual; }
QudaTwistFlavorType TwistFlavor() const { return twistFlavor; }
int Ndim() const { return nDim; }
const int *X() const { return x.data; }
Expand Down
5 changes: 3 additions & 2 deletions include/color_spinor_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ namespace quda
*/
template <typename Float, typename storeFloat, bool block_float_, typename norm_t> struct fieldorder_wrapper {
using value_type = Float; /**< Compute type */
using store_type = storeFloat; /**< Storage type */
using store_t = storeFloat; /**< Storage type */
complex<storeFloat> *v; /**< Field memory address this wrapper encompasses */
const int idx; /**< Index into field */
private:
Expand Down Expand Up @@ -586,7 +586,6 @@ namespace quda
*/
__device__ __host__ inline auto get_scale() const
{
static_assert(block_float == false, "Orders with block_float == true should not call the get_scale method.");
return block_float ? static_cast<Float>(1) / norm[norm_idx] : scale;
}

Expand Down Expand Up @@ -863,6 +862,8 @@ namespace quda
static constexpr int nSpin = nSpin_;
static constexpr int nColor = nColor_;

using store_t = storeFloat;

field<Float, storeFloat, fixed, block_float> v;
unsigned int volumeCB = 0;

Expand Down
204 changes: 204 additions & 0 deletions include/expand_list.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
#include <tune_quda.h>
#include <int_factor_array.hpp>

namespace quda
{

/**
  @brief This helper class turns the run-time values stored in tp.aux into the compile-time template
  arguments of `Callable::launch_mma<Bx, By, Bz, Bw>`, using the following mapping:
    tp.aux.x -> Bx in x_atom_size * [factors of (x + x_atom_size - 1) / x_atom_size];
    tp.aux.y -> By in y_atom_size * [factors of (y + y_atom_size - 1) / y_atom_size];
    tp.aux.z -> Bz in z_atom_size * [factors of (z + z_atom_size - 1) / z_atom_size];
    tp.aux.w -> Bw in w_atom_size * [factors of (w + w_atom_size - 1) / w_atom_size].
  Each tp.aux component holds the selected value itself (e.g. tp.aux.x == Bx), not an index into
  the list of candidates.
  See `void expand(TuneParam &tp, const qudaStream_t &stream)`
  @tparam Callable Type that provides `launch_mma<Bx, By, Bz, Bw>(tp, stream)` and `set_mma_param(tp)`
  @tparam x, y, z, w Values whose rounded-up quotients by the matching atom size are factorized
  @tparam x_atom_size, y_atom_size, z_atom_size, w_atom_size Multipliers applied to every factor;
  these are also the initial (smallest) aux values
*/
template <class Callable, int x, int x_atom_size, int y, int y_atom_size, int z, int z_atom_size, int w, int w_atom_size>
class expand_aux_t
{

  Callable &_callable;

  // Candidate lists: operator[] of each array yields atom_size * factor
  static constexpr IntFactorArray<(x + x_atom_size - 1) / x_atom_size, x_atom_size> x_factors {};
  static constexpr IntFactorArray<(y + y_atom_size - 1) / y_atom_size, y_atom_size> y_factors {};
  static constexpr IntFactorArray<(z + z_atom_size - 1) / z_atom_size, z_atom_size> z_factors {};
  static constexpr IntFactorArray<(w + w_atom_size - 1) / w_atom_size, w_atom_size> w_factors {};

  /**
    @brief Innermost stage: match tp.aux.w against the w candidates one at a time and launch on a
    match; errorQuda is called if no candidate matches
  */
  template <int Bx, int By, int Bz, size_t W, size_t... Ws>
  void span_w(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<W, Ws...>)
  {
    constexpr int Bw = w_factors[W];
    if (tp.aux.w == Bw) {
      _callable.template launch_mma<Bx, By, Bz, Bw>(tp, stream);
    } else {
      if constexpr (sizeof...(Ws) > 0) {
        // try the next candidate
        span_w<Bx, By, Bz>(tp, stream, std::index_sequence<Ws...>());
      } else {
        errorQuda("Invalid tp.aux.w(=%d)", tp.aux.w);
      }
    }
  }

  /**
    @brief Match tp.aux.z against the z candidates; on a match continue with the w candidates
  */
  template <int Bx, int By, size_t Z, size_t... Zs>
  void span_z(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<Z, Zs...>)
  {
    constexpr int Bz = z_factors[Z];
    if (tp.aux.z == Bz) {
      std::make_index_sequence<w_factors.size()> w_indices;
      span_w<Bx, By, Bz>(tp, stream, w_indices);
    } else {
      if constexpr (sizeof...(Zs) > 0) {
        span_z<Bx, By>(tp, stream, std::index_sequence<Zs...>());
      } else {
        errorQuda("Invalid tp.aux.z(=%d)", tp.aux.z);
      }
    }
  }

  /**
    @brief Match tp.aux.y against the y candidates; on a match continue with the z candidates
  */
  template <int Bx, size_t Y, size_t... Ys>
  void span_y(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<Y, Ys...>)
  {
    constexpr int By = y_factors[Y];
    if (tp.aux.y == By) {
      std::make_index_sequence<z_factors.size()> z_indices;
      span_z<Bx, By>(tp, stream, z_indices);
    } else {
      if constexpr (sizeof...(Ys) > 0) {
        span_y<Bx>(tp, stream, std::index_sequence<Ys...>());
      } else {
        errorQuda("Invalid tp.aux.y(=%d)", tp.aux.y);
      }
    }
  }

  /**
    @brief Outermost stage: match tp.aux.x against the x candidates; on a match continue with the
    y candidates
  */
  template <size_t X, size_t... Xs>
  void span_x(TuneParam &tp, const qudaStream_t &stream, std::index_sequence<X, Xs...>)
  {
    constexpr int Bx = x_factors[X];
    if (tp.aux.x == Bx) {
      std::make_index_sequence<y_factors.size()> y_indices;
      span_y<Bx>(tp, stream, y_indices);
    } else {
      if constexpr (sizeof...(Xs) > 0) {
        span_x(tp, stream, std::index_sequence<Xs...>());
      } else {
        errorQuda("Invalid tp.aux.x(=%d)", tp.aux.x);
      }
    }
  }

public:
  /**
    @brief invoke `_callable.template launch_mma<Bx, By, Bz, Bw>(tp, stream);` based on the tp.aux values
      tp.aux.x -> Bx in x_atom_size * [factors of (x + x_atom_size - 1) / x_atom_size];
      tp.aux.y -> By in y_atom_size * [factors of (y + y_atom_size - 1) / y_atom_size];
      tp.aux.z -> Bz in z_atom_size * [factors of (z + z_atom_size - 1) / z_atom_size];
      tp.aux.w -> Bw in w_atom_size * [factors of (w + w_atom_size - 1) / w_atom_size].
    For example, if x_atom_size = 8, x = 48, then Bx can take values in [8, 16, 24, 48], and
    tp.aux.x stores that value directly: tp.aux.x == 8 selects Bx = 8, tp.aux.x == 24 selects
    Bx = 24, and so on.  Any value that is not in the list triggers errorQuda.
    @param tp The TuneParam parameter
    @param stream The stream parameter
  */
  void expand(TuneParam &tp, const qudaStream_t &stream)
  {
    std::make_index_sequence<x_factors.size()> x_indices;
    span_x(tp, stream, x_indices);
  }

  /**
    @brief Constructor; only a reference to the callable is kept, so it must outlive this object
    @param callable The object whose launch_mma / set_mma_param members are invoked
  */
  expand_aux_t(Callable &callable) : _callable(callable) { }

  /**
    @brief Get the Bx value, i.e. tp.aux.x after checking that it is one of the x candidates
    (errorQuda is called otherwise)
    @param tp The TuneParam parameter
  */
  int get_x(const TuneParam &tp) const
  {
    if (x_factors.get_index(tp.aux.x) >= x_factors.size()) { errorQuda("Invalid tp.aux.x = %d\n", tp.aux.x); }
    return tp.aux.x;
  }

  /**
    @brief Get the By value, i.e. tp.aux.y after checking that it is one of the y candidates
    (errorQuda is called otherwise)
    @param tp The TuneParam parameter
  */
  int get_y(const TuneParam &tp) const
  {
    if (y_factors.get_index(tp.aux.y) >= y_factors.size()) { errorQuda("Invalid tp.aux.y = %d\n", tp.aux.y); }
    return tp.aux.y;
  }

  /**
    @brief Get the Bz value, i.e. tp.aux.z after checking that it is one of the z candidates
    (errorQuda is called otherwise)
    @param tp The TuneParam parameter
  */
  int get_z(const TuneParam &tp) const
  {
    if (z_factors.get_index(tp.aux.z) >= z_factors.size()) { errorQuda("Invalid tp.aux.z = %d\n", tp.aux.z); }
    return tp.aux.z;
  }

  /**
    @brief Get the Bw value, i.e. tp.aux.w after checking that it is one of the w candidates
    (errorQuda is called otherwise)
    @param tp The TuneParam parameter
  */
  int get_w(const TuneParam &tp) const
  {
    if (w_factors.get_index(tp.aux.w) >= w_factors.size()) { errorQuda("Invalid tp.aux.w = %d\n", tp.aux.w); }
    return tp.aux.w;
  }

  /**
    @brief Move one aux component to the next candidate in its list, if there is one
    @param aux The aux component to advance (updated in place)
    @param param The TuneParam that is handed to `_callable.set_mma_param`
    @param factors The candidate list that aux is drawn from
    @return the result of `_callable.set_mma_param(param)` if aux was advanced, false if aux
    already was the last candidate (or is not a candidate at all)
  */
  template <unsigned int Int, unsigned int Multiple>
  bool advancer(int &aux, TuneParam &param, const IntFactorArray<Int, Multiple> &factors) const
  {
    // Look the position up once.  Writing the bound as index + 1 < size (instead of
    // index < size - 1) avoids wrap-around should size() be an unsigned zero.
    const auto index = factors.get_index(aux);
    if (index + 1 < factors.size()) {
      aux = factors[index + 1];
      return _callable.set_mma_param(param);
    }
    return false;
  }

  /**
    @brief Advance to the next possible aux value and return true; return false once we have gone
    past the last possible value.  The components are tried in the order x, y, z, w: whenever a
    component cannot be advanced it is reset to its atom size and the next component is tried.
    @param param The TuneParam parameter
    @return whether or not an advance is performed
  */
  bool advance_aux(TuneParam &param) const
  {
    if (advancer(param.aux.x, param, x_factors)) { return true; }
    param.aux.x = x_atom_size;

    if (advancer(param.aux.y, param, y_factors)) { return true; }
    param.aux.y = y_atom_size;

    if (advancer(param.aux.z, param, z_factors)) { return true; }
    param.aux.z = z_atom_size;

    if (advancer(param.aux.w, param, w_factors)) { return true; }
    param.aux.w = w_atom_size;

    // every component is back at its initial value: the search space is exhausted
    return false;
  }

  /**
    @brief Initialize aux: every component is set to its atom size
    @param param The TuneParam parameter
  */
  void init_aux(TuneParam &param) const
  {
    param.aux.x = x_atom_size;
    param.aux.y = y_atom_size;
    param.aux.z = z_atom_size;
    param.aux.w = w_atom_size;
  }
};

} // namespace quda
2 changes: 1 addition & 1 deletion include/gauge_field_order.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ namespace quda {
template <typename Float, typename storeFloat>
struct fieldorder_wrapper {
using value_type = Float;
using store_type = storeFloat;
using store_t = storeFloat;
complex<storeFloat> *v;
const unsigned int idx;

Expand Down
40 changes: 12 additions & 28 deletions include/int_factor_array.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,32 +5,6 @@
namespace quda
{

/**
 * @brief Count, at run time, the positive divisors of an integer
 * @param Int the integer to examine
 * @return the number of values d in [1, Int] that divide Int (zero when Int is 0)
 */
inline unsigned int numFactors(unsigned int Int)
{
  unsigned int count = 0;
  unsigned int candidate = 1u;
  while (candidate <= Int) {
    const bool divides = (Int % candidate == 0);
    if (divides) { count += 1; }
    ++candidate;
  }
  return count;
}

/**
 * @brief Build, at run time, the list of positive factors of an integer
 * @param Int the integer to factorize
 * @return a std::vector holding every d in [1, Int] with Int % d == 0, in
 * increasing order; the vector is empty when Int is 0
 */
inline auto get_int_factor_array(unsigned int Int)
{
  std::vector<unsigned int> factors;
  // Single pass: collect the factors directly instead of counting them first.
  // The candidate != 0 test stops the loop if the counter wraps around, which
  // would otherwise happen (followed by a modulo by zero) for Int == UINT_MAX.
  for (unsigned int candidate = 1u; candidate != 0u && candidate <= Int; candidate++) {
    if (Int % candidate == 0) { factors.push_back(candidate); }
  }
  return factors;
}

/**
* @brief compute number of factors of an integer
*
Expand All @@ -48,7 +22,7 @@ namespace quda
* @brief A struct containing a compile time generated array
* containing factors of an integer.
*/
template <unsigned int Int> struct IntFactorArray {
template <unsigned int Int, unsigned int Multiple> struct IntFactorArray {

array<unsigned int, numFactors<Int>()> data_;

Expand All @@ -72,7 +46,17 @@ namespace quda
* @brief read only constant index operator[]
* @param i the index to look up
*/
constexpr unsigned int operator[](int i) const noexcept { return data_[i]; }
constexpr unsigned int operator[](int i) const noexcept { return Multiple * data_[i]; }

constexpr unsigned int get_index(unsigned int value) const noexcept
{
unsigned int i = 0;
for (; i < numFactors<Int>(); i++) {
if (Multiple * data_[i] == static_cast<unsigned int>(value)) { return i; }
}
return i;
}

}; // end struct

} // namespace quda
12 changes: 12 additions & 0 deletions include/int_list.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

namespace quda
{

/**
  @brief Empty tag type whose only content is the compile-time list of
  integers given as its template parameter pack
  @tparam Ints The integers carried by the type
 */
template <int... Ints> struct IntList {};

} // namespace quda
3 changes: 2 additions & 1 deletion include/kernel_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,9 @@ namespace quda

enum class use_kernel_arg_p { FALSE, TRUE, ALWAYS };

template <use_kernel_arg_p use_kernel_arg_ = use_kernel_arg_p::TRUE> struct kernel_param {
template <use_kernel_arg_p use_kernel_arg_ = use_kernel_arg_p::TRUE, bool check_bounds_ = true> struct kernel_param {
static constexpr use_kernel_arg_p use_kernel_arg = use_kernel_arg_;
static constexpr bool check_bounds = check_bounds_;
dim3 threads; /** number of active threads required */
int comms_rank; /** per process value of comm_rank() */
int comms_rank_global; /** per process value comm_rank_global() */
Expand Down
Loading

0 comments on commit 599b4ae

Please sign in to comment.