Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make extendedGaugeResident up to date with gaugePrecise. #1528

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 57 additions & 43 deletions lib/interface_quda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -568,6 +568,31 @@ void freeUniqueSloppyGaugeUtility(GaugeField *&precise, GaugeField *&sloppy, Gau
void freeUniqueGaugeUtility(GaugeField *&precise, GaugeField *&sloppy, GaugeField *&precondition, GaugeField *&refinement,
GaugeField *&eigensolver, GaugeField *&extended, bool preserve_precise);

/**
 * @brief Keep the global extendedGaugeResident consistent with gaugePrecise.
 *
 * If an extended resident field already exists it is deleted and rebuilt from
 * gaugePrecise when either a new gauge has been installed, or the requested
 * reconstruct type / border radius differs from the one currently held.
 * If no extended resident field exists, one is built only when the call is
 * not the result of a new gauge (i.e. when a consumer actually needs it).
 *
 * Raises errorQuda if gaugePrecise is not set.
 *
 * @param new_gauge       true when gaugePrecise has just been replaced
 * @param R               requested border radius in each of the four dimensions
 * @param profile         profile handed through to createExtendedGauge
 * @param redundant_comms handed through to createExtendedGauge
 * @param recon           requested reconstruct type; QUDA_RECONSTRUCT_INVALID
 *                        switches off the reconstruct comparison
 */
void updateExtendedGaugeResident(bool new_gauge, const lat_dim_t &R, TimeProfile &profile, bool redundant_comms = false,
                                 QudaReconstructType recon = QUDA_RECONSTRUCT_INVALID)
{
  if (!gaugePrecise) errorQuda("No resident gauge field allocated");

  bool rebuild = false;

  if (extendedGaugeResident == nullptr) {
    // Nothing cached yet: build on demand, but never just because a new gauge arrived.
    rebuild = !new_gauge;
  } else {
    // A new gauge always invalidates the cached extended field.
    rebuild = new_gauge;

    if (!rebuild) {
      // Otherwise the cached field is stale only if its reconstruct type or border radius differs.
      rebuild = (recon != QUDA_RECONSTRUCT_INVALID && recon != extendedGaugeResident->Reconstruct());
      for (int d = 0; d < 4 && !rebuild; d++) rebuild = (R[d] != extendedGaugeResident->R()[d]);
    }

    if (rebuild) delete extendedGaugeResident;
  }

  if (rebuild) extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profile, redundant_comms, recon);
}

/**
 * @brief Replace the global extendedGaugeResident with an externally built field.
 *
 * The field currently stored in extendedGaugeResident (if any) is deleted and
 * the supplied pointer is stored in its place.
 *
 * @param extendedGauge pointer to store in extendedGaugeResident; passing the
 *                      pointer that is already resident is a no-op
 */
void updateExtendedGaugeResident(GaugeField *extendedGauge)
{
  // Guard against self-assignment: deleting the resident field and then storing
  // the same pointer again would leave extendedGaugeResident dangling.
  if (extendedGauge == extendedGaugeResident) return;

  // delete on a null pointer is a no-op, so no explicit null check is needed
  delete extendedGaugeResident;
  extendedGaugeResident = extendedGauge;
}

void loadGaugeQuda(void *h_gauge, QudaGaugeParam *param)
{
auto profile = pushProfile(profileGauge);
Expand Down Expand Up @@ -723,6 +748,7 @@ void loadGaugeQuda(void *h_gauge, QudaGaugeParam *param)
gaugeEigensolver = eigensolver;

if(param->overlap) gaugeExtended = extended;
updateExtendedGaugeResident(true, R, profileGauge);
break;
case QUDA_ASQTAD_FAT_LINKS:
gaugeFatPrecise = precise;
Expand Down Expand Up @@ -753,14 +779,6 @@ void loadGaugeQuda(void *h_gauge, QudaGaugeParam *param)
}

delete in;

if (extendedGaugeResident) {
// updated the resident gauge field if needed
QudaReconstructType recon = extendedGaugeResident->Reconstruct();
delete extendedGaugeResident;
// Use the static R (which is defined at the very beginning of lib/interface_quda.cpp) here
extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileGauge, false, recon);
}
}

void saveGaugeQuda(void *h_gauge, QudaGaugeParam *param)
Expand Down Expand Up @@ -3941,18 +3959,17 @@ int computeGaugeForceQuda(void* mom, void* siteLink, int*** input_path_buf, int
std::exchange(*gaugePrecise, cudaSiteLink);
}

if (qudaGaugeParam->make_resident_mom && !qudaGaugeParam->use_resident_mom)
std::exchange(momResident, cudaMom);
else if (!qudaGaugeParam->make_resident_mom)
momResident = GaugeField();

if (qudaGaugeParam->make_resident_gauge) {
if (extendedGaugeResident) delete extendedGaugeResident;
extendedGaugeResident = cudaGauge;
updateExtendedGaugeResident(cudaGauge);
} else {
delete cudaGauge;
}

if (qudaGaugeParam->make_resident_mom && !qudaGaugeParam->use_resident_mom)
std::exchange(momResident, cudaMom);
else if (!qudaGaugeParam->make_resident_mom)
momResident = GaugeField();

return 0;
}

Expand Down Expand Up @@ -4012,8 +4029,7 @@ int computeGaugePathQuda(void *out, void *siteLink, int ***input_path_buf, int *
}

if (qudaGaugeParam->make_resident_gauge) {
if (extendedGaugeResident) delete extendedGaugeResident;
extendedGaugeResident = cudaGauge;
updateExtendedGaugeResident(cudaGauge);
} else {
delete cudaGauge;
}
Expand Down Expand Up @@ -4065,8 +4081,9 @@ void createCloverQuda(QudaInvertParam* invertParam)
// for clover we optimize to only send depth 1 halos in y/z/t (FIXME - make work for x, make robust in general)
lat_dim_t R;
for (int d=0; d<4; d++) R[d] = (d==0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d));
GaugeField *gauge
= extendedGaugeResident ? extendedGaugeResident : createExtendedGauge(*gaugePrecise, R, getProfile(), false, recon);
// FIXME always preserve the extended gauge
updateExtendedGaugeResident(false, R, profileClover, false, recon);
GaugeField *gauge = extendedGaugeResident;

GaugeField *ex = gauge;
if (gauge->Precision() < cloverPrecise->Precision()) {
Expand All @@ -4091,9 +4108,6 @@ void createCloverQuda(QudaInvertParam* invertParam)
cloverInvert(*cloverPrecise, cloverPrecise->Reconstruct());

if (ex != gauge) delete ex;

// FIXME always preserve the extended gauge
extendedGaugeResident = gauge;
}

void* createGaugeFieldQuda(void* gauge, int geometry, QudaGaugeParam* param)
Expand Down Expand Up @@ -4608,10 +4622,9 @@ void computeCloverForceQuda(void *h_mom, double dt, void **h_x, void **, double
errorQuda("solutionResident.size() %lu does not match number of shifts %d", solutionResident.size(), nvector);

// Make sure extendedGaugeResident has the correct R
if (extendedGaugeResident) delete extendedGaugeResident;
lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d));
extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, getProfile());
updateExtendedGaugeResident(false, R, profileCloverForce);
GaugeField &gaugeEx = *extendedGaugeResident;

computeCloverForce(cudaMom, gaugeEx, *gaugePrecise, *cloverPrecise, x, x0, force_coeff, ferm_epsilon,
Expand Down Expand Up @@ -4683,10 +4696,9 @@ void computeTMCloverForceQuda(void *h_mom, void **h_x, void **h_x0, double *coef
}

// Make sure extendedGaugeResident has the correct R
if (extendedGaugeResident) delete extendedGaugeResident;
lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = (d == 0 ? 2 : 1) * (redundant_comms || commDimPartitioned(d));
extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileTMCloverForce);
updateExtendedGaugeResident(false, R, profileTMCloverForce);
GaugeField &gaugeEx = *extendedGaugeResident;

computeCloverForce(gpuMom, gaugeEx, *gaugePrecise, *cloverPrecise, x, x0, force_coeff, ferm_epsilon,
Expand Down Expand Up @@ -4744,6 +4756,7 @@ void updateGaugeFieldQuda(void *gauge, void *momentum, double dt, int conj_mom,
if (gaugePrecise) freeUniqueGaugeQuda(QUDA_WILSON_LINKS);
gaugePrecise = new GaugeField();
std::exchange(*gaugePrecise, u_out);
updateExtendedGaugeResident(true, R, profileGaugeUpdate);
}

if (param->make_resident_mom && !param->use_resident_mom)
Expand Down Expand Up @@ -4788,6 +4801,8 @@ void projectSU3Quda(void *gauge_h, double tol, QudaGaugeParam *param)
gaugePrecise = new GaugeField();
std::exchange(*gaugePrecise, cudaGauge);
}

if (param->make_resident_gauge) { updateExtendedGaugeResident(true, R, profileProject); }
}

void staggeredPhaseQuda(void *gauge_h, QudaGaugeParam *param)
Expand Down Expand Up @@ -4825,6 +4840,8 @@ void staggeredPhaseQuda(void *gauge_h, QudaGaugeParam *param)
gaugePrecise = new GaugeField();
std::exchange(*gaugePrecise, cudaGauge);
}

if (param->make_resident_gauge) { updateExtendedGaugeResident(true, R, profilePhase); }
}

// evaluate the momentum action
Expand Down Expand Up @@ -4866,10 +4883,7 @@ void gaussGaugeQuda(unsigned long long seed, double sigma)
if (!gaugePrecise) errorQuda("Cannot generate Gauss GaugeField as there is no resident gauge field");
quda::gaugeGauss(*gaugePrecise, seed, sigma);

if (extendedGaugeResident) {
extendedGaugeResident->copy(*gaugePrecise);
extendedGaugeResident->exchangeExtendedGhost(R, profileGauss, redundant_comms);
}
updateExtendedGaugeResident(true, R, profileGauss);
}

void gaussMomQuda(unsigned long long seed, double sigma)
Expand All @@ -4888,8 +4902,8 @@ void plaqQuda(double plaq[3])

if (!gaugePrecise) errorQuda("Cannot compute plaquette as there is no resident gauge field");

GaugeField *data = extendedGaugeResident ? extendedGaugeResident : createExtendedGauge(*gaugePrecise, R, profilePlaq);
extendedGaugeResident = data;
updateExtendedGaugeResident(false, R, profilePlaq);
GaugeField *data = extendedGaugeResident;

double3 plaq3 = quda::plaquette(*data);
plaq[0] = plaq3.x;
Expand All @@ -4907,6 +4921,8 @@ void polyakovLoopQuda(double ploop[2], int dir)

QudaGaugeObservableParam obsParam = newQudaGaugeObservableParam();
obsParam.compute_polyakov_loop = QUDA_BOOLEAN_TRUE;
obsParam.remove_staggered_phase
= extendedGaugeResident->StaggeredPhaseApplied() ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
gaugeObservablesQuda(&obsParam);
ploop[0] = obsParam.ploop[0];
ploop[1] = obsParam.ploop[1];
Expand All @@ -4917,12 +4933,6 @@ void computeGaugeLoopTraceQuda(double _Complex *traces, int **input_path_buf, in
{
if (!gaugePrecise) errorQuda("Cannot compute gauge loop traces as there is no resident gauge field");

if (extendedGaugeResident) delete extendedGaugeResident;
extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileGaugeObs);

// informed by gauge path code; apply / remove gauge as appropriate
if (extendedGaugeResident->StaggeredPhaseApplied()) extendedGaugeResident->removeStaggeredPhase();

QudaGaugeObservableParam obsParam = newQudaGaugeObservableParam();
obsParam.compute_gauge_loop_trace = QUDA_BOOLEAN_TRUE;
obsParam.traces = traces;
Expand All @@ -4932,6 +4942,8 @@ void computeGaugeLoopTraceQuda(double _Complex *traces, int **input_path_buf, in
obsParam.num_paths = num_paths;
obsParam.max_length = max_length;
obsParam.factor = factor;
obsParam.remove_staggered_phase
= extendedGaugeResident->StaggeredPhaseApplied() ? QUDA_BOOLEAN_TRUE : QUDA_BOOLEAN_FALSE;
gaugeObservablesQuda(&obsParam);
}

Expand All @@ -4941,8 +4953,7 @@ void computeGaugeLoopTraceQuda(double _Complex *traces, int **input_path_buf, in
// Copy the extended resident gauge field into the field pointed to by
// resident_gauge.
//
// resident_gauge is cast to GaugeField * without any check, so the caller must
// pass a pointer to a GaugeField. Raises errorQuda when gaugePrecise is not set.
void copyExtendedResidentGaugeQuda(void *resident_gauge)
{
if (!gaugePrecise) errorQuda("Cannot perform deep copy of resident gauge field as there is no resident gauge field");
// Make sure extendedGaugeResident exists before it is dereferenced below.
// NOTE(review): this text was taken from a diff view without +/- markers; the
// assignment and the updateExtendedGaugeResident call below may be the removed
// and the added form of the same step - confirm against the actual file.
extendedGaugeResident
= extendedGaugeResident ? extendedGaugeResident : createExtendedGauge(*gaugePrecise, R, profilePlaq);
updateExtendedGaugeResident(false, R, profilePlaq);
// Copy the extended resident field into the caller-supplied field.
static_cast<GaugeField *>(resident_gauge)->copy(*extendedGaugeResident);
}

Expand Down Expand Up @@ -5424,8 +5435,7 @@ int computeGaugeFixingOVRQuda(void *gauge, const unsigned int gauge_dir, const u
freeUniqueGaugeQuda(QUDA_WILSON_LINKS);
gaugePrecise = new GaugeField();
std::exchange(*gaugePrecise, cudaInGauge);
if (extendedGaugeResident) delete extendedGaugeResident;
extendedGaugeResident = cudaInGaugeEx;
updateExtendedGaugeResident(cudaInGaugeEx);
} else {
delete cudaInGaugeEx;
}
Expand Down Expand Up @@ -5463,6 +5473,7 @@ int computeGaugeFixingFFTQuda(void *gauge, const unsigned int gauge_dir, const u
freeUniqueGaugeQuda(QUDA_WILSON_LINKS);
gaugePrecise = new GaugeField();
std::exchange(*gaugePrecise, cudaInGauge);
updateExtendedGaugeResident(true, R, GaugeFixFFTQuda);
}

return 0;
Expand Down Expand Up @@ -5612,7 +5623,7 @@ void gaugeObservablesQuda(QudaGaugeObservableParam *param)

GaugeField *gauge = nullptr;
if (!gaugeSmeared) {
if (!extendedGaugeResident) extendedGaugeResident = createExtendedGauge(*gaugePrecise, R, profileGaugeObs);
updateExtendedGaugeResident(false, R, profileGaugeObs);
gauge = extendedGaugeResident;
} else {
gauge = gaugeSmeared;
Expand All @@ -5627,6 +5638,9 @@ void gaugeObservablesQuda(QudaGaugeObservableParam *param)
}

gaugeObservables(*gauge, *param);

// Restore the staggered phase
if (param->remove_staggered_phase == QUDA_BOOLEAN_TRUE) { gauge->applyStaggeredPhase(); }
}

static void check_param(double _Complex *host_sinks, void **host_quark, int n_quark, int tile_quark, void **host_evec,
Expand Down