Skip to content

Commit

Permalink
Merge pull request #1457 from lattice/hotfix/staggered-eig-parity-inflate
Browse files Browse the repository at this point in the history

Bugfix for single-parity vector inflation in `VectorIO`
  • Loading branch information
maddyscientist authored Apr 26, 2024
2 parents 1f88479 + f09576e commit c70220c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 42 deletions.
16 changes: 12 additions & 4 deletions lib/vector_io.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,16 @@ namespace quda
for (int j = 0; j < Ls; j++) { V[i * Ls + j] = v.data<char *>() + j * stride; }
}

// if we're loading inflated vectors, we need to grab the spinor size from `tmp` instead of `v0`.
auto spinor_X = (create_tmp && parity_inflate) ? tmp[0].X() : v0.X();
auto spinor_site_subset = (create_tmp && parity_inflate) ? tmp[0].SiteSubset() : v0.SiteSubset();

// time loading
quda::host_timer_t host_timer;
host_timer.start(); // start the timer

read_spinor_field(filename.c_str(), V.data(), v0.Precision(), v0.X(), v0.SiteSubset(),
spinor_parity, v0.Ncolor(), v0.Nspin(), Nvec * Ls, 0, nullptr);
read_spinor_field(filename.c_str(), V.data(), v0.Precision(), spinor_X, spinor_site_subset, spinor_parity,
v0.Ncolor(), v0.Nspin(), Nvec * Ls, 0, nullptr);

host_timer.stop(); // stop the timer
logQuda(QUDA_SUMMARIZE, "Time spent loading vectors from %s = %g secs\n", filename.c_str(), host_timer.last());
Expand Down Expand Up @@ -140,12 +144,16 @@ namespace quda
for (int j = 0; j < Ls; j++) { V[i * Ls + j] = v.data<const char *>() + j * stride; }
}

// if we performed parity inflation, we need to grab the spinor size from `tmp` instead of `v0`.
auto spinor_X = (create_tmp && parity_inflate) ? tmp[0].X() : v0.X();
auto spinor_site_subset = (create_tmp && parity_inflate) ? tmp[0].SiteSubset() : v0.SiteSubset();

// time saving
quda::host_timer_t host_timer;
host_timer.start(); // start the timer

write_spinor_field(filename.c_str(), V.data(), save_prec, v0.X(), v0.SiteSubset(), spinor_parity, v0.Ncolor(),
v0.Nspin(), Nvec * Ls, 0, nullptr, partfile);
write_spinor_field(filename.c_str(), V.data(), save_prec, spinor_X, spinor_site_subset, spinor_parity,
v0.Ncolor(), v0.Nspin(), Nvec * Ls, 0, nullptr, partfile);

host_timer.stop(); // stop the timer
logQuda(QUDA_SUMMARIZE, "Time spent saving vectors to %s = %g secs\n", filename.c_str(), host_timer.last());
Expand Down
128 changes: 90 additions & 38 deletions tests/io_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,15 @@ TEST_P(GaugeIOTest, verify)
for (int dir = 0; dir < 4; dir++) { host_free(gauge[dir]); }
}

using cs_test_t = ::testing::tuple<QudaSiteSubset, bool, QudaPrecision, QudaPrecision, int, bool, QudaFieldLocation>;
using cs_test_t = ::testing::tuple<QudaSiteSubset, QudaSiteSubset, QudaParity, bool, QudaPrecision, QudaPrecision, int,
bool, QudaFieldLocation>;

class ColorSpinorIOTest : public ::testing::TestWithParam<cs_test_t>
{
protected:
QudaSiteSubset site_subset;
QudaSiteSubset site_subset_saved;
QudaSiteSubset site_subset_loaded;
QudaParity parity;
bool inflate;
QudaPrecision prec;
QudaPrecision prec_io;
Expand All @@ -86,13 +89,15 @@ class ColorSpinorIOTest : public ::testing::TestWithParam<cs_test_t>

public:
ColorSpinorIOTest() :
site_subset(::testing::get<0>(GetParam())),
inflate(::testing::get<1>(GetParam())),
prec(::testing::get<2>(GetParam())),
prec_io(::testing::get<3>(GetParam())),
nSpin(::testing::get<4>(GetParam())),
partfile(::testing::get<5>(GetParam())),
location(::testing::get<6>(GetParam()))
site_subset_saved(::testing::get<0>(GetParam())),
site_subset_loaded(::testing::get<1>(GetParam())),
parity(::testing::get<2>(GetParam())),
inflate(::testing::get<3>(GetParam())),
prec(::testing::get<4>(GetParam())),
prec_io(::testing::get<5>(GetParam())),
nSpin(::testing::get<6>(GetParam())),
partfile(::testing::get<7>(GetParam())),
location(::testing::get<8>(GetParam()))
{
}
};
Expand All @@ -117,37 +122,79 @@ TEST_P(ColorSpinorIOTest, verify)
|| (prec < QUDA_SINGLE_PRECISION && nSpin == 2))
GTEST_SKIP();

// There's no meaningful way to save a full parity spinor field and load a
// single-parity field
if (site_subset_saved == QUDA_FULL_SITE_SUBSET && site_subset_loaded == QUDA_PARITY_SITE_SUBSET) GTEST_SKIP();

// Site subsets must be the same except for inflation tests
if (site_subset_saved != site_subset_loaded && !inflate) GTEST_SKIP();

// Unique test: saving an inflated single-parity field and loading the full-parity field
bool mixed_inflated_full_test
= (site_subset_saved == QUDA_PARITY_SITE_SUBSET && site_subset_loaded == QUDA_FULL_SITE_SUBSET && inflate);

QudaGaugeParam gauge_param = newQudaGaugeParam();
QudaInvertParam inv_param = newQudaInvertParam();
ColorSpinorParam param;
setWilsonGaugeParam(gauge_param);
setInvertParam(inv_param);
constructWilsonTestSpinorParam(&param, &inv_param, &gauge_param);
param.siteSubset = site_subset;
param.suggested_parity = QUDA_EVEN_PARITY;
param.nSpin = nSpin;
param.setPrecision(prec, prec, true); // change order to native order
param.location = location;
if (location == QUDA_CPU_FIELD_LOCATION) param.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
param.create = QUDA_NULL_FIELD_CREATE;

// We create a separate parameter struct for saved and loaded ColorSpinorFields
ColorSpinorParam param_save, param_load;

// Saved
constructWilsonTestSpinorParam(&param_save, &inv_param, &gauge_param);
param_save.siteSubset = site_subset_saved;
param_save.suggested_parity = parity; // ignored for full parity
param_save.nSpin = nSpin;
param_save.setPrecision(prec, prec, true); // change order to native order
param_save.location = location;
if (location == QUDA_CPU_FIELD_LOCATION) param_save.fieldOrder = QUDA_SPACE_SPIN_COLOR_FIELD_ORDER;
param_save.create = QUDA_NULL_FIELD_CREATE;

// Loaded
param_load = param_save;
param_load.siteSubset = site_subset_loaded;

// Update ColorSpinorField size appropriately depending on the site subset
if (site_subset_saved == QUDA_PARITY_SITE_SUBSET) param_save.x[0] /= 2;
if (site_subset_loaded == QUDA_PARITY_SITE_SUBSET) param_load.x[0] /= 2;

// create some random vectors
auto n_vector = 1;
std::vector<ColorSpinorField> v(n_vector, param);
std::vector<ColorSpinorField> u(n_vector, param);
std::vector<ColorSpinorField> v(n_vector, param_save);
std::vector<ColorSpinorField> u(n_vector, param_load);

RNG rng(v[0], 1234);
for (auto &vi : v) spinorNoise(vi, rng, QUDA_NOISE_GAUSS);

auto file = "dummy.cs";

VectorIO io(file, inflate, partfile);

io.save({v.begin(), v.end()}, prec_io, n_vector);
io.load(u);
// Create separate VectorIO instances for saving and loading.
// This lets us test saving a single-parity field with inflation
// and then loading the full-parity field.
VectorIO io_save(file, inflate, partfile);
VectorIO io_load(file, inflate, partfile);

io_save.save({v.begin(), v.end()}, prec_io, n_vector);
io_load.load(u);

// Helper returning, by const reference, the part of `u[i]` that is compared against `v[i]`.
// In the mixed test (an inflated single-parity save followed by a full-parity load) this is
// the even subset of `u[i]` when `parity` is QUDA_EVEN_PARITY and the odd subset otherwise;
// in every other case it is `u[i]` itself.
auto get_u_subset = [&](int i) -> const ColorSpinorField & {
  // every case other than the mixed test compares the loaded field as-is
  if (!mixed_inflated_full_test) { return u[i]; }
  // `this->` only makes the member access explicit; the `[&]` default capture already covers `this`
  if (this->parity == QUDA_EVEN_PARITY) { return u[i].Even(); }
  return u[i].Odd();
};

for (auto i = 0u; i < v.size(); i++) {
auto dev = blas::max_deviation(u[i], v[i]);
// grab the correct subset of u if we're doing mixed parity/full tests
const ColorSpinorField &u_subset = get_u_subset(i);

auto dev = blas::max_deviation(u_subset, v[i]);
if (prec == prec_io)
EXPECT_EQ(dev[0], 0.0);
else
Expand Down Expand Up @@ -185,33 +232,38 @@ INSTANTIATE_TEST_SUITE_P(Gauge, GaugeIOTest, Combine(Values(QUDA_DOUBLE_PRECISIO

// colorspinor full field IO test
INSTANTIATE_TEST_SUITE_P(Full, ColorSpinorIOTest,
Combine(Values(QUDA_FULL_SITE_SUBSET), Values(false),
Combine(Values(QUDA_FULL_SITE_SUBSET), Values(QUDA_FULL_SITE_SUBSET),
Values(QUDA_INVALID_PARITY), Values(false),
Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION, QUDA_HALF_PRECISION),
Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION), Values(1, 2, 4),
Values(false, true), Values(QUDA_CUDA_FIELD_LOCATION, QUDA_CPU_FIELD_LOCATION)),
[](testing::TestParamInfo<cs_test_t> param) {
std::string name;
name += get_prec_str(::testing::get<2>(param.param)) + std::string("_");
name += get_prec_str(::testing::get<3>(param.param)) + std::string("_");
name += std::string("spin") + std::to_string(::testing::get<4>(param.param));
name += ::testing::get<5>(param.param) ? "_singlefile" : "_partfile";
name += ::testing::get<6>(param.param) == QUDA_CUDA_FIELD_LOCATION ? "_device" : "_host";
name += get_prec_str(::testing::get<4>(param.param)) + std::string("_");
name += get_prec_str(::testing::get<5>(param.param)) + std::string("_");
name += std::string("spin") + std::to_string(::testing::get<6>(param.param));
name += ::testing::get<7>(param.param) ? "_singlefile" : "_partfile";
name += ::testing::get<8>(param.param) == QUDA_CUDA_FIELD_LOCATION ? "_device" : "_host";
return name;
});

// colorspinor parity field IO test
INSTANTIATE_TEST_SUITE_P(Parity, ColorSpinorIOTest,
Combine(Values(QUDA_PARITY_SITE_SUBSET), Values(false, true),
Combine(Values(QUDA_PARITY_SITE_SUBSET), Values(QUDA_PARITY_SITE_SUBSET, QUDA_FULL_SITE_SUBSET),
Values(QUDA_EVEN_PARITY, QUDA_ODD_PARITY), Values(false, true),
Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION, QUDA_HALF_PRECISION),
Values(QUDA_DOUBLE_PRECISION, QUDA_SINGLE_PRECISION), Values(1, 2, 4),
Values(false, true), Values(QUDA_CUDA_FIELD_LOCATION, QUDA_CPU_FIELD_LOCATION)),
[](testing::TestParamInfo<cs_test_t> param) {
std::string name;
if (::testing::get<1>(param.param)) name += std::string("inflate_");
name += get_prec_str(::testing::get<2>(param.param)) + std::string("_");
name += get_prec_str(::testing::get<3>(param.param)) + std::string("_");
name += std::string("spin") + std::to_string(::testing::get<4>(param.param));
name += ::testing::get<5>(param.param) ? "_singlefile" : "_partfile";
name += ::testing::get<6>(param.param) == QUDA_CUDA_FIELD_LOCATION ? "_device" : "_host";
name += ::testing::get<1>(param.param) == QUDA_PARITY_SITE_SUBSET ? "load_parity_" :
"load_full_";
name += ::testing::get<2>(param.param) == QUDA_EVEN_PARITY ? "even_" : "odd_";
if (::testing::get<3>(param.param)) name += std::string("inflate_");
name += get_prec_str(::testing::get<4>(param.param)) + std::string("_");
name += get_prec_str(::testing::get<5>(param.param)) + std::string("_");
name += std::string("spin") + std::to_string(::testing::get<6>(param.param));
name += ::testing::get<7>(param.param) ? "_singlefile" : "_partfile";
name += ::testing::get<8>(param.param) == QUDA_CUDA_FIELD_LOCATION ? "_device" : "_host";
return name;
});

0 comments on commit c70220c

Please sign in to comment.