Skip to content

Commit

Permalink
Namespace some more things
Browse files Browse the repository at this point in the history
  • Loading branch information
noelchalmers committed Mar 5, 2022
1 parent 77aa3c9 commit adadfb2
Show file tree
Hide file tree
Showing 6 changed files with 26 additions and 18 deletions.
2 changes: 1 addition & 1 deletion include/mesh.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class mesh_t {
public:
platform_t platform;
settings_t settings;
occa::properties props;
properties_t props;

comm_t comm;
int rank, size;
Expand Down
2 changes: 1 addition & 1 deletion include/ogs/ogsBase.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class ogsBase_t {
bool unique=false;
bool gather_defined=false;

static occa::stream dataStream;
static stream_t dataStream;

ogsBase_t()=default;
virtual ~ogsBase_t()=default;
Expand Down
12 changes: 6 additions & 6 deletions include/ogs/ogsExchange.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ class ogsExchange_t {
pinnedMemory<char> h_workspace, h_sendspace;
deviceMemory<char> o_workspace, o_sendspace;

occa::stream dataStream;
static occa::kernel extractKernel[4];
stream_t dataStream;
static kernel_t extractKernel[4];

#ifdef GPU_AWARE_MPI
bool gpu_aware=true;
Expand All @@ -56,7 +56,7 @@ class ogsExchange_t {
#endif

ogsExchange_t(platform_t &_platform, comm_t _comm,
occa::stream _datastream):
stream_t _datastream):
platform(_platform),
comm(_comm),
dataStream(_datastream) {
Expand Down Expand Up @@ -118,7 +118,7 @@ class ogsAllToAll_t: public ogsExchange_t {
ogsAllToAll_t(dlong Nshared,
memory<parallelNode_t> &sharedNodes,
ogsOperator_t &gatherHalo,
occa::stream _dataStream,
stream_t _dataStream,
comm_t _comm,
platform_t &_platform);

Expand Down Expand Up @@ -199,7 +199,7 @@ class ogsPairwise_t: public ogsExchange_t {
ogsPairwise_t(dlong Nshared,
memory<parallelNode_t> &sharedNodes,
ogsOperator_t &gatherHalo,
occa::stream _dataStream,
stream_t _dataStream,
comm_t _comm,
platform_t &_platform);

Expand Down Expand Up @@ -281,7 +281,7 @@ class ogsCrystalRouter_t: public ogsExchange_t {
ogsCrystalRouter_t(dlong Nshared,
memory<parallelNode_t> &sharedNodes,
ogsOperator_t &gatherHalo,
occa::stream _dataStream,
stream_t _dataStream,
comm_t _comm,
platform_t &_platform);

Expand Down
6 changes: 3 additions & 3 deletions include/ogs/ogsOperator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,9 @@ class ogsOperator_t {

//4 types - Float, Double, Int32, Int64
//4 ops - Add, Mul, Max, Min
static occa::kernel gatherScatterKernel[4][4];
static occa::kernel gatherKernel[4][4];
static occa::kernel scatterKernel[4];
static kernel_t gatherScatterKernel[4][4];
static kernel_t gatherKernel[4][4];
static kernel_t scatterKernel[4];

friend void InitializeKernels(platform_t& platform, const Type type, const Op op);
};
Expand Down
8 changes: 8 additions & 0 deletions include/platform.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ class platform_t {
return comm.size();
}

int getDeviceCount(const std::string mode) {
return occa::getDeviceCount(mode);
}

void setCacheDir(const std::string cacheDir) {
occa::env::setOccaCacheDir(cacheDir);
}

private:
void DeviceConfig();
void DeviceProperties();
Expand Down
14 changes: 7 additions & 7 deletions libs/core/platformDeviceConfig.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ void platform_t::DeviceConfig(){
device_id = localRank;

//check for over-subscribing devices
int deviceCount = occa::getDeviceCount(mode);
int deviceCount = getDeviceCount(mode);
if (deviceCount>0 && localRank>=deviceCount) {
LIBP_FORCE_WARNING("Rank " << rank() << " oversubscribing device " << device_id%deviceCount << " on node \"" << hostname.ptr() << "\"");
device_id = device_id%deviceCount;
Expand Down Expand Up @@ -153,22 +153,22 @@ void platform_t::DeviceConfig(){

device.setup(mode);

std::string occaCacheDir;
std::string cacheDir;
char * cacheEnvVar = std::getenv("STREAMPARANUMAL_CACHE_DIR");
if (cacheEnvVar == nullptr) {
// Environment variable is not set
occaCacheDir = LIBP_DIR "/.occa";
cacheDir = LIBP_DIR "/.occa";
}
else {
// Environmet variable is set, but could be empty string
occaCacheDir = cacheEnvVar;
cacheDir = cacheEnvVar;

if (occaCacheDir.size() == 0) {
if (cacheDir.size() == 0) {
// Environment variable is set but equal to empty string
occaCacheDir = LIBP_DIR "/.occa";
cacheDir = LIBP_DIR "/.occa";
}
}
occa::env::setOccaCacheDir(occaCacheDir);
setCacheDir(cacheDir);

comm.Barrier();
}
Expand Down

0 comments on commit adadfb2

Please sign in to comment.