diff --git a/include/dirac_quda.h b/include/dirac_quda.h index f68b235fde..a5ba6e0966 100644 --- a/include/dirac_quda.h +++ b/include/dirac_quda.h @@ -205,6 +205,66 @@ namespace quda { */ virtual bool isCoarse() const { return false; } + /** + @brief static function that returns if a Dirac type is Wilson-type depending on a QudaDiracType + */ + static bool is_wilson_type(QudaDiracType); + + /** + @brief static function that returns if a Dslash type is Wilson-type depending on a QudaDslashType + */ + static bool is_wilson_type(QudaDslashType); + + /** + @brief return if the operator is a Wilson-type 4-d operator + */ + bool isWilsonType() const { return Dirac::is_wilson_type(getDiracType()); } + + /** + @brief static function that returns if a Dirac type is staggered-type depending on a QudaDiracType + */ + static bool is_staggered_type(QudaDiracType); + + /** + @brief static function that returns if a Dslash type is staggered-type depending on a QudaDslashType + */ + static bool is_staggered_type(QudaDslashType); + + /** + @brief return if the operator is a staggered operator + */ + bool isStaggered() const { return Dirac::is_staggered_type(getDiracType()); } + + /** + @brief static function that returns if a Dirac type is asqtad depending on a QudaDiracType + */ + static bool is_asqtad(QudaDiracType); + + /** + @brief static function that returns if a Dslash type is asqtad depending on a QudaDslashType + */ + static bool is_asqtad(QudaDslashType); + + /** + @brief return if the operator is an asqtad operator + */ + bool isAsqtad() const { return Dirac::is_asqtad(getDiracType()); } + + /** + @brief static function that returns if a Dirac type is a domain wall operator (5-dimensional) depending on a QudaDiracType + */ + static bool is_dwf(QudaDiracType); + + /** + @brief static function that returns if a Dslash type is a domain wall operator (5-dimensional) depending on a QudaDslashType + */ + static bool is_dwf(QudaDslashType); + + /** + @brief return if the 
operator is a domain wall operator, that is, 5-dimensional + */ + bool isDwf() const { return Dirac::is_dwf(getDiracType()); } + /** @brief Check parity spinors are usable (check geometry ?) */ @@ -368,9 +428,11 @@ namespace quda { QudaMatPCType getMatPCType() const { return matpcType; } /** - @brief I have no idea what this does + @brief returns the number of stencil applications per dslash application; 1 for operators with + a single hopping term (generally full operators), 2 for composite operators + that consist of two hopping terms (generally PC operators) */ - int getStencilSteps() const; + virtual int getStencilSteps() const = 0; /** @brief sets whether operator is daggered or not @@ -393,6 +455,17 @@ namespace quda { */ virtual QudaDiracType getDiracType() const = 0; + /** @brief returns the Dslash type + + @return Dslash type + */ + QudaDslashType getDslashType() const { return dirac_to_dslash_type(getDiracType()); } + + /** + @brief static function that returns the QudaDslashType corresponding to a QudaDiracType + */ + static QudaDslashType dirac_to_dslash_type(QudaDiracType); + /** @brief Return the one-hop field for staggered operators for MG setup @@ -473,21 +546,22 @@ namespace quda { DiracWilson& operator=(const DiracWilson &dirac); virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; + QudaParity parity, cvector_ref &x, double k) const override; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const 
override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_WILSON_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_WILSON_DIRAC; } /** * @brief Create the coarse Wilson operator. @@ -507,7 +581,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for Wilson operator */ virtual void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., - double mu = 0., double mu_factor = 0., bool allow_truncation = false) const; + double mu = 0., double mu_factor = 0., bool allow_truncation = false) const override; }; // Even-odd preconditioned Wilson @@ -521,16 +595,17 @@ namespace quda { virtual ~DiracWilsonPC(); DiracWilsonPC& operator=(const DiracWilsonPC &dirac); - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_WILSONPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_WILSONPC_DIRAC; } }; // Full clover @@ -538,7 +613,7 @@ namespace quda { protected: CloverField *clover; - void checkParitySpinor(cvector_ref &, cvector_ref &) const; + void checkParitySpinor(cvector_ref &, cvector_ref 
&) const override; void initConstants(); public: @@ -551,18 +626,19 @@ namespace quda { void Clover(cvector_ref &out, cvector_ref &in, QudaParity parity) const; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; + QudaParity parity, cvector_ref &x, double k) const override; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_CLOVER_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_CLOVER_DIRAC; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. 
@@ -573,7 +649,7 @@ namespace quda { * @param long_gauge_in Updated long links * @param clover_in Updated clover field */ - virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *clover_in) + virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *clover_in) override { DiracWilson::updateFields(gauge_in, nullptr, nullptr, nullptr); clover = clover_in; @@ -597,7 +673,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover operator */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -606,7 +682,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; // Even-odd preconditioned clover @@ -625,26 +701,27 @@ namespace quda { // Dslash is redefined as A_pp^{-1} D_p\bar{p} virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; // out = x + k A_pp^{-1} D_p\bar{p} virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; + QudaParity parity, cvector_ref &x, double k) const override; // Can implement: M as e.g. 
: i) tmp_e = A^{-1}_ee D_eo in_o (Dslash) // ii) out_o = in_o + A_oo^{-1} D_oe tmp_e (AXPY) - void M(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; // squared op - void MdagM(cvector_ref &out, cvector_ref &in) const; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_CLOVERPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_CLOVERPC_DIRAC; } /** * @brief Create the coarse even-odd preconditioned clover @@ -660,7 +737,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover operator */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -671,7 +748,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; // Full clover with Hasenbusch Twist @@ -693,10 +770,15 @@ namespace quda { virtual ~DiracCloverHasenbuschTwist(); DiracCloverHasenbuschTwist &operator=(const 
DiracCloverHasenbuschTwist &dirac); - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; - virtual QudaDiracType getDiracType() const { return QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC; } + virtual int getStencilSteps() const override + { + // implemented as separate even, odd D_{eo} D_{oe} + return 2; + } + virtual QudaDiracType getDiracType() const override { return QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC; } /** * @brief Create the coarse clover operator @@ -709,7 +791,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; }; // Even-odd preconditioned clover @@ -742,12 +824,13 @@ namespace quda { // Can implement: M as e.g. 
: i) tmp_e = A^{-1}_ee D_eo in_o (Dslash) // ii) out_o = in_o + A_oo^{-1} D_oe tmp_e (AXPY) - void M(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; // squared op - void MdagM(cvector_ref &out, cvector_ref &in) const; + void MdagM(cvector_ref &out, cvector_ref &in) const override; - virtual QudaDiracType getDiracType() const { return QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC; } /** * @brief Create the coarse even-odd preconditioned clover @@ -763,7 +846,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for clover hasenbusch */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass = 0., double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; }; // Full domain wall @@ -787,21 +870,23 @@ namespace quda { virtual ~DiracDomainWall(); DiracDomainWall& operator=(const DiracDomainWall &dirac); - void Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const; + void Dslash(cvector_ref &out, cvector_ref &in, + QudaParity parity) const override; void DslashXpay(cvector_ref &out, cvector_ref &in, QudaParity parity, - cvector_ref &x, double k) const; + cvector_ref &x, double k) const override; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void 
reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_DOMAIN_WALL_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_DOMAIN_WALL_DIRAC; } }; // 5d Even-odd preconditioned domain wall @@ -815,16 +900,17 @@ namespace quda { virtual ~DiracDomainWallPC(); DiracDomainWallPC& operator=(const DiracDomainWallPC &dirac); - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_DOMAIN_WALLPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_DOMAIN_WALLPC_DIRAC; } }; // Full domain wall, but with 4-d parity ordered fields @@ -837,23 +923,25 @@ namespace quda { virtual ~DiracDomainWall4D(); DiracDomainWall4D &operator=(const DiracDomainWall4D &dirac); - void Dslash4(cvector_ref &out, cvector_ref &in, QudaParity parity) const; + void Dslash4(cvector_ref &out, cvector_ref &in, + QudaParity parity) const override; void Dslash5(cvector_ref &out, cvector_ref &in) const; void Dslash4Xpay(cvector_ref &out, cvector_ref &in, QudaParity parity, cvector_ref &x, double k) const; void Dslash5Xpay(cvector_ref &out, cvector_ref &in, cvector_ref &x, double k) const; - void M(cvector_ref &out, cvector_ref &in) const; - void 
MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_DOMAIN_WALL_4D_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_DOMAIN_WALL_4D_DIRAC; } }; // 4d Even-odd preconditioned domain wall @@ -870,16 +958,17 @@ namespace quda { void M5invXpay(cvector_ref &out, cvector_ref &in, cvector_ref &x, double k) const; - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_DOMAIN_WALL_4DPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_DOMAIN_WALL_4DPC_DIRAC; } }; // Full Mobius @@ -887,47 +976,49 @@ namespace quda { protected: //Mobius coefficients - Complex b_5[QUDA_MAX_DWF_LS]; - Complex c_5[QUDA_MAX_DWF_LS]; - - /** - Whether we are using classical Mobius with constant real-valued - b and c coefficients, or zMobius with complex-valued variable - 
coefficients - */ - bool zMobius; - - double mobius_kappa_b; - double mobius_kappa_c; - double mobius_kappa; - - public: - DiracMobius(const DiracParam ¶m); - // DiracMobius(const DiracMobius &dirac); - // virtual ~DiracMobius(); - // DiracMobius& operator=(const DiracMobius &dirac); - - void Dslash4(cvector_ref &out, cvector_ref &in, QudaParity parity) const; - void Dslash4pre(cvector_ref &out, cvector_ref &in) const; - void Dslash5(cvector_ref &out, cvector_ref &in) const; - - void Dslash4Xpay(cvector_ref &out, cvector_ref &in, QudaParity parity, - cvector_ref &x, double k) const; - void Dslash4preXpay(cvector_ref &out, cvector_ref &in, - cvector_ref &x, double k) const; - void Dslash5Xpay(cvector_ref &out, cvector_ref &in, - cvector_ref &x, double k) const; - - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; - - virtual void prepare(cvector_ref &out, cvector_ref &in, - cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; - virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; - - virtual QudaDiracType getDiracType() const { return QUDA_MOBIUS_DOMAIN_WALL_DIRAC; } + Complex b_5[QUDA_MAX_DWF_LS]; + Complex c_5[QUDA_MAX_DWF_LS]; + + /** + Whether we are using classical Mobius with constant real-valued + b and c coefficients, or zMobius with complex-valued variable + coefficients + */ + bool zMobius; + + double mobius_kappa_b; + double mobius_kappa_c; + double mobius_kappa; + + public: + DiracMobius(const DiracParam ¶m); + // DiracMobius(const DiracMobius &dirac); + // virtual ~DiracMobius(); + // DiracMobius& operator=(const DiracMobius &dirac); + + void Dslash4(cvector_ref &out, cvector_ref &in, + QudaParity parity) const override; + void Dslash4pre(cvector_ref &out, cvector_ref &in) const; + void Dslash5(cvector_ref &out, cvector_ref &in) const; + + void Dslash4Xpay(cvector_ref &out, cvector_ref &in, QudaParity parity, + 
cvector_ref &x, double k) const; + void Dslash4preXpay(cvector_ref &out, cvector_ref &in, + cvector_ref &x, double k) const; + void Dslash5Xpay(cvector_ref &out, cvector_ref &in, + cvector_ref &x, double k) const; + + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; + + virtual void prepare(cvector_ref &out, cvector_ref &in, + cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const override; + virtual void reconstruct(cvector_ref &x, cvector_ref &b, + const QudaSolutionType solType) const override; + + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_MOBIUS_DOMAIN_WALL_DIRAC; } }; // 4d even-odd preconditioned Mobius domain wall @@ -963,20 +1054,21 @@ namespace quda { QudaParity parity, cvector_ref &x, double a, cvector_ref &y) const; - void MdagMLocal(cvector_ref &out, cvector_ref &in) const; + void MdagMLocal(cvector_ref &out, cvector_ref &in) const override; - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; // this needs to be specialized for Mobius since we have a fused MdagM kernel - void MMdag(cvector_ref &out, cvector_ref &in) const; + void MMdag(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() 
const override { return QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC; } }; // Full Mobius EOFA @@ -1003,16 +1095,17 @@ namespace quda { void m5_eofa_xpay(cvector_ref &out, cvector_ref &in, cvector_ref &x, double a = -1.) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC; } }; // 4d Even-odd preconditioned Mobius domain wall with EOFA @@ -1026,19 +1119,20 @@ namespace quda { void m5inv_eofa_xpay(cvector_ref &out, cvector_ref &in, cvector_ref &x, double a = -1.) 
const; - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; // ye = Mee * xe + Meo * xo, yo = Moo * xo + Moe * xe void full_dslash(cvector_ref &out, cvector_ref &in) const; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC; } }; void gamma5(cvector_ref &out, cvector_ref &in); @@ -1060,9 +1154,9 @@ namespace quda { void twistedApply(cvector_ref &out, cvector_ref &in, const QudaTwistGamma5Type twistType) const; virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; + QudaParity parity, cvector_ref &x, double k) const override; public: DiracTwistedMass(const DiracTwistedMass &dirac); @@ -1072,18 +1166,19 @@ namespace quda { void Twist(cvector_ref &out, cvector_ref &in) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const 
override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_MASS_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_TWISTED_MASS_DIRAC; } - double Mu() const { return mu; } + double Mu() const override { return mu; } /** * @brief Create the coarse twisted-mass operator @@ -1106,7 +1201,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted mass */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_trunation = false) const; + double mu_factor = 0., bool allow_trunation = false) const override; }; // Even-odd preconditioned twisted mass @@ -1117,24 +1212,25 @@ namespace quda { DiracTwistedMassPC(const DiracParam ¶m, const int nDim); virtual ~DiracTwistedMassPC(); - DiracTwistedMassPC& operator=(const DiracTwistedMassPC &dirac); + DiracTwistedMassPC &operator=(const DiracTwistedMassPC &dirac); void TwistInv(cvector_ref &out, cvector_ref &in) const; virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void 
reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_MASSPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_TWISTED_MASSPC_DIRAC; } /** * @brief Create the coarse even-odd preconditioned twisted-mass @@ -1150,7 +1246,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted mass */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; }; // Full twisted mass with a clover term @@ -1161,7 +1257,7 @@ namespace quda { double epsilon; double tm_rho; CloverField *clover; - void checkParitySpinor(cvector_ref &, cvector_ref &) const; + void checkParitySpinor(cvector_ref &, cvector_ref &) const override; void twistedCloverApply(cvector_ref &out, cvector_ref &in, QudaTwistGamma5Type twistType, QudaParity parity) const; @@ -1169,26 +1265,27 @@ namespace quda { DiracTwistedClover(const DiracTwistedClover &dirac); DiracTwistedClover(const DiracParam ¶m, const int nDim); virtual ~DiracTwistedClover(); - DiracTwistedClover& operator=(const DiracTwistedClover &dirac); + DiracTwistedClover &operator=(const DiracTwistedClover &dirac); void TwistClover(cvector_ref &out, cvector_ref &in, QudaParity parity) const; virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, 
cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_CLOVER_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_TWISTED_CLOVER_DIRAC; } - double Mu() const { return mu; } + double Mu() const override { return mu; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. @@ -1199,7 +1296,7 @@ namespace quda { * @param long_gauge_in Updated long links * @param clover_in Updated clover field */ - virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *clover_in) + virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *clover_in) override { DiracWilson::updateFields(gauge_in, nullptr, nullptr, nullptr); clover = clover_in; @@ -1226,7 +1323,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted clover */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1235,7 +1332,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void 
prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; // Even-odd preconditioned twisted mass with a clover term @@ -1266,19 +1363,20 @@ namespace quda { QudaParity parity, cvector_ref &x, double k) const; virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_TWISTED_CLOVERPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_TWISTED_CLOVERPC_DIRAC; } /** * @brief Create the coarse even-odd preconditioned twisted-clover @@ -1296,7 +1394,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for twisted clover */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch 
@@ -1307,14 +1405,14 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; // Full staggered - class DiracStaggered : public Dirac { + class DiracStaggered : public Dirac + { protected: - public: DiracStaggered(const DiracParam ¶m); DiracStaggered(const DiracStaggered &dirac); @@ -1322,26 +1420,27 @@ namespace quda { DiracStaggered& operator=(const DiracStaggered &dirac); virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_STAGGERED_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_STAGGERED_DIRAC; } /** @brief Return the one-hop field for staggered operators for MG setup @return Gauge field */ - virtual GaugeField *getStaggeredShortLinkField() const { return gauge; } + virtual 
GaugeField *getStaggeredShortLinkField() const override { return gauge; } /** * @brief Create the coarse staggered operator. @@ -1365,7 +1464,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** * @brief Create two-link staggered quark smearing operator @@ -1378,11 +1477,12 @@ namespace quda { * @param[in] parity Parity flag */ void SmearOp(cvector_ref &out, cvector_ref &in, double a, double b, - int t0, QudaParity parity) const; + int t0, QudaParity parity) const override; }; // Even-odd preconditioned staggered - class DiracStaggeredPC : public DiracStaggered { + class DiracStaggeredPC : public DiracStaggered + { protected: @@ -1392,18 +1492,19 @@ namespace quda { virtual ~DiracStaggeredPC(); DiracStaggeredPC& operator=(const DiracStaggeredPC &dirac); - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_STAGGEREDPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_STAGGEREDPC_DIRAC; } - virtual bool hermitian() const { return true; } + virtual bool hermitian() 
const override { return true; } /** * @brief Create the coarse staggered operator. @@ -1427,7 +1528,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; }; // Kahler-Dirac preconditioned staggered @@ -1445,31 +1546,32 @@ namespace quda { virtual ~DiracStaggeredKD(); DiracStaggeredKD &operator=(const DiracStaggeredKD &dirac); - virtual bool hasDslash() const { return false; } + virtual bool hasDslash() const override { return false; } virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; void KahlerDiracInv(cvector_ref &out, cvector_ref &in) const; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void prepareSpecialMG(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstructSpecialMG(cvector_ref &x, cvector_ref &b, - const QudaSolutionType 
solType) const; + const QudaSolutionType solType) const override; - virtual bool hasSpecialMG() const { return true; } + virtual bool hasSpecialMG() const override { return true; } - virtual QudaDiracType getDiracType() const { return QUDA_STAGGEREDKD_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_STAGGEREDKD_DIRAC; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. @@ -1481,7 +1583,7 @@ namespace quda { * @param clover_in Updated clover field */ virtual void updateFields(GaugeField *gauge_in, GaugeField *fat_gauge_in, GaugeField *long_gauge_in, - CloverField *clover_in); + CloverField *clover_in) override; /** * @brief Create the coarse staggered KD operator. @@ -1503,7 +1605,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for staggered */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu = 0., - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1514,7 +1616,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; // Full staggered @@ -1531,33 +1633,34 @@ namespace quda { DiracImprovedStaggered& operator=(const DiracImprovedStaggered &dirac); virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity 
parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_ASQTAD_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_ASQTAD_DIRAC; } /** @brief Return the one-hop field for staggered operators for MG setup @return fat link field */ - virtual GaugeField *getStaggeredShortLinkField() const { return fatGauge; } + virtual GaugeField *getStaggeredShortLinkField() const override { return fatGauge; } /** @brief return the long link field for staggered operators for MG setup @return long link field */ - virtual GaugeField *getStaggeredLongLinkField() const { return longGauge; } + virtual GaugeField *getStaggeredLongLinkField() const override { return longGauge; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. 
@@ -1568,7 +1671,7 @@ namespace quda { * @param long_gauge_in Updated long links * @param clover_in Updated clover field */ - virtual void updateFields(GaugeField *, GaugeField *fat_gauge_in, GaugeField *long_gauge_in, CloverField *) + virtual void updateFields(GaugeField *, GaugeField *fat_gauge_in, GaugeField *long_gauge_in, CloverField *) override { Dirac::updateFields(fat_gauge_in, nullptr, nullptr, nullptr); fatGauge = fat_gauge_in; @@ -1598,7 +1701,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long links here */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + double mu_factor, bool allow_truncation) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1607,7 +1710,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; /** * @brief Create two-link staggered quark smearing operator @@ -1620,7 +1723,7 @@ namespace quda { * @param[in] parity Parity flag */ void SmearOp(cvector_ref &out, cvector_ref &in, double a, double b, - int t0, QudaParity parity) const; + int t0, QudaParity parity) const override; }; // Even-odd preconditioned staggered @@ -1634,18 +1737,19 @@ namespace quda { virtual ~DiracImprovedStaggeredPC(); DiracImprovedStaggeredPC& operator=(const DiracImprovedStaggeredPC &dirac); - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const 
override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_ASQTADPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_ASQTADPC_DIRAC; } - virtual bool hermitian() const { return true; } + virtual bool hermitian() const override { return true; } /** * @brief Create the coarse staggered operator. @@ -1669,7 +1773,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long links here */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + double mu_factor, bool allow_truncation) const override; }; // Kahler-Dirac preconditioned staggered @@ -1686,31 +1790,32 @@ namespace quda { virtual ~DiracImprovedStaggeredKD(); DiracImprovedStaggeredKD &operator=(const DiracImprovedStaggeredKD &dirac); - virtual bool hasDslash() const { return false; } + virtual bool hasDslash() const override { return false; } virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; void KahlerDiracInv(cvector_ref &out, 
cvector_ref &in) const; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void prepareSpecialMG(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstructSpecialMG(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual bool hasSpecialMG() const { return true; } + virtual bool hasSpecialMG() const override { return true; } - virtual QudaDiracType getDiracType() const { return QUDA_ASQTADKD_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_ASQTADKD_DIRAC; } /** * @brief Update the internal gauge, fat gauge, long gauge, clover field pointer as appropriate. @@ -1722,7 +1827,7 @@ namespace quda { * @param clover_in Updated clover field */ virtual void updateFields(GaugeField *gauge_in, GaugeField *fat_gauge_in, GaugeField *long_gauge_in, - CloverField *clover_in); + CloverField *clover_in) override; /** * @brief Create the coarse improved staggered KD operator. 
@@ -1744,7 +1849,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, dropping long for asqtad */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor, bool allow_truncation) const; + double mu_factor, bool allow_truncation) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -1755,7 +1860,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; /** @@ -1824,10 +1929,10 @@ namespace quda { void createYhat(bool gpu = true) const; public: - double Mass() const { return mass; } - double Mu() const { return mu; } - double MuFactor() const { return mu_factor; } - bool AllowTruncation() const { return allow_truncation; } + double Mass() const override { return mass; } + double Mu() const override { return mu; } + double MuFactor() const override { return mu_factor; } + bool AllowTruncation() const override { return allow_truncation; } /** @param[in] param Parameters defining this operator @@ -1859,7 +1964,7 @@ namespace quda { DiracCoarse(const DiracCoarse &dirac, const DiracParam ¶m); virtual ~DiracCoarse(); - virtual bool isCoarse() const { return true; } + virtual bool isCoarse() const override { return true; } /** @brief Apply the coarse clover operator @@ -1884,7 +1989,7 @@ namespace quda { @param[parity] parity Parity which we are applying the operator to */ virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; /** @brief Apply DslashXpay out = (D * in + A * x) @@ -1895,31 +2000,32 @@ namespace quda { @param[in] k scalar 
multiplier */ virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; + QudaParity parity, cvector_ref &x, double k) const override; /** @brief Apply the full operator @param[out] out output vector, out = M * in @param[in] in input vector */ - virtual void M(cvector_ref &out, cvector_ref &in) const; + virtual void M(cvector_ref &out, cvector_ref &in) const override; /** @brief Apply the normal full operator @param[out] out output vector, out = M * in @param[in] in input vector */ - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_COARSE_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_COARSE_DIRAC; } - virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *) + virtual void updateFields(GaugeField *gauge_in, GaugeField *, GaugeField *, CloverField *) override { Dirac::updateFields(gauge_in, nullptr, nullptr, nullptr); warningQuda("Coarse gauge links cannot be trivially updated for DiracCoarse(PC). 
Perform an MG update instead."); @@ -1938,7 +2044,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for coarse op */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** * @brief Create the precondtioned coarse operator @@ -1957,7 +2063,7 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; /** @brief If use_mma and the batch size is larger than 1, actually apply coarse dslash with MMA @@ -2010,7 +2116,8 @@ namespace quda { @param[in] in Input field @param[parity] parity Parity which we are applying the operator to */ - void Dslash(cvector_ref &out, cvector_ref &in, QudaParity parity) const; + void Dslash(cvector_ref &out, cvector_ref &in, + QudaParity parity) const override; /** @brief Apply preconditioned DslashXpay out = (x + k * D * in) @@ -2021,29 +2128,30 @@ namespace quda { @param[in] k scalar multiplier */ void DslashXpay(cvector_ref &out, cvector_ref &in, QudaParity parity, - cvector_ref &x, double k) const; + cvector_ref &x, double k) const override; /** @brief Apply the preconditioned operator @param[out] out output vector, out = M * in @param[in] in input vector */ - void M(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; /** @brief Apply the preconditioned full operator @param[out] out output vector, out = M * in @param[in] in input vector */ - void MdagM(cvector_ref &out, cvector_ref &in) const; + void MdagM(cvector_ref 
&out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; - virtual QudaDiracType getDiracType() const { return QUDA_COARSEPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_COARSEPC_DIRAC; } /** * @brief Create the coarse even-odd preconditioned coarse @@ -2061,7 +2169,7 @@ namespace quda { * @param allow_truncation [in] whether or not we let coarsening drop improvements, none available for coarse op */ void createCoarseOp(GaugeField &Y, GaugeField &X, const Transfer &T, double kappa, double mass, double mu, - double mu_factor = 0., bool allow_truncation = false) const; + double mu_factor = 0., bool allow_truncation = false) const override; /** @brief If managed memory and prefetch is enabled, prefetch @@ -2070,10 +2178,9 @@ namespace quda { @param[in] mem_space Memory space we are prefetching to @param[in] stream Which stream to run the prefetch in (default 0) */ - virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const; + virtual void prefetch(QudaFieldLocation mem_space, qudaStream_t stream = device::get_default_stream()) const override; }; - /** @brief Full Gauge Laplace operator. 
Although not a Dirac operator per se, it's a linear operator so it's conventient to @@ -2089,20 +2196,21 @@ namespace quda { GaugeLaplace& operator=(const GaugeLaplace &laplace); virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; - virtual bool hermitian() const { return true; } + const QudaSolutionType solType) const override; + virtual bool hermitian() const override { return true; } - virtual QudaDiracType getDiracType() const { return QUDA_GAUGE_LAPLACE_DIRAC; } + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_GAUGE_LAPLACE_DIRAC; } }; /** @@ -2116,17 +2224,18 @@ namespace quda { virtual ~GaugeLaplacePC(); GaugeLaplacePC& operator=(const GaugeLaplacePC &laplace); - void M(cvector_ref &out, cvector_ref &in) const; - void MdagM(cvector_ref &out, cvector_ref &in) const; + void M(cvector_ref &out, cvector_ref &in) const override; + void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, 
cvector_ref &b, - const QudaSolutionType solType) const; - virtual bool hermitian() const { return true; } + const QudaSolutionType solType) const override; + virtual bool hermitian() const override { return true; } - virtual QudaDiracType getDiracType() const { return QUDA_GAUGE_LAPLACEPC_DIRAC; } + virtual int getStencilSteps() const override { return 2; } + virtual QudaDiracType getDiracType() const override { return QUDA_GAUGE_LAPLACEPC_DIRAC; } }; /** @@ -2134,7 +2243,8 @@ namespace quda { operator per se, it's a linear operator so it's conventient to put in the Dirac operator abstraction. */ - class GaugeCovDev : public Dirac { + class GaugeCovDev : public Dirac + { protected: int covdev_mu; @@ -2152,18 +2262,20 @@ namespace quda { virtual void MdagMCD(cvector_ref &out, cvector_ref &in, const int mu) const; virtual void Dslash(cvector_ref &out, cvector_ref &in, - QudaParity parity) const; + QudaParity parity) const override; virtual void DslashXpay(cvector_ref &out, cvector_ref &in, - QudaParity parity, cvector_ref &x, double k) const; - virtual void M(cvector_ref &out, cvector_ref &in) const; - virtual void MdagM(cvector_ref &out, cvector_ref &in) const; + QudaParity parity, cvector_ref &x, double k) const override; + virtual void M(cvector_ref &out, cvector_ref &in) const override; + virtual void MdagM(cvector_ref &out, cvector_ref &in) const override; virtual void prepare(cvector_ref &out, cvector_ref &in, cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; + const QudaSolutionType solType) const override; virtual void reconstruct(cvector_ref &x, cvector_ref &b, - const QudaSolutionType solType) const; - virtual QudaDiracType getDiracType() const { return QUDA_GAUGE_COVDEV_DIRAC; } + const QudaSolutionType solType) const override; + + virtual int getStencilSteps() const override { return 1; } + virtual QudaDiracType getDiracType() const override { return QUDA_GAUGE_COVDEV_DIRAC; } }; // Functor base class for applying a given Dirac 
matrix (M, MdagM, etc.) @@ -2171,7 +2283,8 @@ namespace quda { // and provides for several operator() operations to apply it, perhaps to apply // AXPYs etc. Once we have this, further classes diracM diracMdag etc // can implement the operator()-s as needed to apply the operator, MdagM etc etc. - class DiracMatrix { + class DiracMatrix + { protected: const Dirac *dirac; @@ -2199,42 +2312,17 @@ namespace quda { /** @brief return if the operator is a Wilson-type 4-d operator */ - bool isWilsonType() const - { - return (Type() == typeid(DiracWilson).name() || Type() == typeid(DiracWilsonPC).name() - || Type() == typeid(DiracClover).name() || Type() == typeid(DiracCloverPC).name() - || Type() == typeid(DiracCloverHasenbuschTwist).name() - || Type() == typeid(DiracCloverHasenbuschTwistPC).name() || Type() == typeid(DiracTwistedMass).name() - || Type() == typeid(DiracTwistedMassPC).name() || Type() == typeid(DiracTwistedClover).name() - || Type() == typeid(DiracTwistedCloverPC).name()) ? - true : - false; - } + bool isWilsonType() const { return dirac->isWilsonType(); } /** @brief return if the operator is a staggered operator */ - bool isStaggered() const - { - return (Type() == typeid(DiracStaggeredPC).name() || Type() == typeid(DiracStaggered).name() - || Type() == typeid(DiracImprovedStaggeredPC).name() || Type() == typeid(DiracImprovedStaggered).name() - || Type() == typeid(DiracStaggeredKD).name() || Type() == typeid(DiracImprovedStaggeredKD).name()) ? 
- true : - false; - } + bool isStaggered() const { return dirac->isStaggered(); } /** @brief return if the operator is a domain wall operator, that is, 5-dimensional */ - bool isDwf() const - { - return (Type() == typeid(DiracDomainWall).name() || Type() == typeid(DiracDomainWallPC).name() - || Type() == typeid(DiracDomainWall4D).name() || Type() == typeid(DiracDomainWall4DPC).name() - || Type() == typeid(DiracMobius).name() || Type() == typeid(DiracMobiusPC).name() - || Type() == typeid(DiracMobiusEofa).name() || Type() == typeid(DiracMobiusEofaPC).name()) ? - true : - false; - } + bool isDwf() const { return dirac->isDwf(); } /** @brief return if the operator is a coarse operator @@ -2261,17 +2349,14 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->M(out, in); for (auto i = 0u; i < in.size(); i++) if (shift != 0.0) blas::axpy(shift, in[i], out[i]); } - int getStencilSteps() const - { - return dirac->getStencilSteps(); - } + int getStencilSteps() const override { return dirac->getStencilSteps(); } }; /* Gloms onto a DiracOp and provides an operator() which applies its MdagM */ @@ -2286,19 +2371,19 @@ namespace quda { @param[out] out The vector of output fields @param[out] out The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->MdagM(out, in); for (auto i = 0u; i < in.size(); i++) if (shift != 0.0) blas::axpy(shift, in[i], out[i]); } - int getStencilSteps() const + int getStencilSteps() const override { return 2*dirac->getStencilSteps(); // 2 for M and M dagger } - virtual bool hermitian() const { return true; } // normal op is always Hermitian + virtual bool hermitian() const override { return true; } // normal op is always Hermitian }; /* Gloms onto a DiracOp 
and provides an operator() which applies its MdagMLocal */ @@ -2314,12 +2399,12 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->MdagMLocal(out, in); } - int getStencilSteps() const + int getStencilSteps() const override { return 2 * dirac->getStencilSteps(); // 2 for M and M dagger } @@ -2338,19 +2423,19 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->MMdag(out, in); for (auto i = 0u; i < in.size(); i++) if (shift != 0.0) blas::axpy(shift, in[i], out[i]); } - int getStencilSteps() const + int getStencilSteps() const override { return 2*dirac->getStencilSteps(); // 2 for M and M dagger } - virtual bool hermitian() const { return true; } // normal op is always Hermitian + virtual bool hermitian() const override { return true; } // normal op is always Hermitian }; /* Gloms onto a DiracMatrix and provides an operator() for its Mdag method */ @@ -2365,17 +2450,14 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->Mdag(out, in); for (auto i = 0u; i < in.size(); i++) if (shift != 0.0) blas::axpy(shift, in[i], out[i]); } - int getStencilSteps() const - { - return dirac->getStencilSteps(); - } + int getStencilSteps() const override { return dirac->getStencilSteps(); } }; /* Gloms onto a dirac matrix and gives back the dagger of whatever that was originally. 
@@ -2394,17 +2476,14 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->flipDagger(); mat(std::move(out), std::move(in)); dirac->flipDagger(); } - int getStencilSteps() const - { - return mat.getStencilSteps(); - } + int getStencilSteps() const override { return mat.getStencilSteps(); } }; /** @@ -2496,7 +2575,7 @@ namespace quda { @param[out] out The vector of output fields @param[in] in The vector of input fields */ - void operator()(cvector_ref &out, cvector_ref &in) const + void operator()(cvector_ref &out, cvector_ref &in) const override { dirac->M(out, in); for (auto i = 0u; i < in.size(); i++) { @@ -2505,12 +2584,12 @@ namespace quda { } } - int getStencilSteps() const { return dirac->getStencilSteps(); } + int getStencilSteps() const override { return dirac->getStencilSteps(); } /** @brief return if the operator is HPD */ - virtual bool hermitian() const + virtual bool hermitian() const override { auto dirac_type = dirac->getDiracType(); auto pc_type = dirac->getMatPCType(); diff --git a/lib/dirac.cpp b/lib/dirac.cpp index 8922de27b6..a5ec216943 100644 --- a/lib/dirac.cpp +++ b/lib/dirac.cpp @@ -254,55 +254,274 @@ namespace quda { return nullptr; } - - // Count the number of stencil applications per dslash application. 
- int Dirac::getStencilSteps() const + + bool Dirac::is_wilson_type(QudaDiracType type) + { + switch (type) { + case QUDA_WILSON_DIRAC: + case QUDA_WILSONPC_DIRAC: + case QUDA_CLOVER_DIRAC: + case QUDA_CLOVERPC_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: + case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_TWISTED_CLOVERPC_DIRAC: + case QUDA_TWISTED_MASS_DIRAC: + case QUDA_TWISTED_MASSPC_DIRAC: return true; + case QUDA_DOMAIN_WALL_DIRAC: + case QUDA_DOMAIN_WALLPC_DIRAC: + case QUDA_DOMAIN_WALL_4D_DIRAC: + case QUDA_DOMAIN_WALL_4DPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: + case QUDA_STAGGERED_DIRAC: + case QUDA_STAGGEREDPC_DIRAC: + case QUDA_STAGGEREDKD_DIRAC: + case QUDA_ASQTAD_DIRAC: + case QUDA_ASQTADPC_DIRAC: + case QUDA_ASQTADKD_DIRAC: + case QUDA_COARSE_DIRAC: + case QUDA_COARSEPC_DIRAC: + case QUDA_GAUGE_COVDEV_DIRAC: + case QUDA_GAUGE_LAPLACE_DIRAC: + case QUDA_GAUGE_LAPLACEPC_DIRAC: return false; + default: errorQuda("Invalid QudaDiracType %d", type); break; + } + return false; + } + + bool Dirac::is_wilson_type(QudaDslashType type) + { + switch (type) { + case QUDA_WILSON_DSLASH: + case QUDA_CLOVER_WILSON_DSLASH: + case QUDA_CLOVER_HASENBUSCH_TWIST_DSLASH: + case QUDA_TWISTED_MASS_DSLASH: + case QUDA_TWISTED_CLOVER_DSLASH: return true; + case QUDA_DOMAIN_WALL_DSLASH: + case QUDA_DOMAIN_WALL_4D_DSLASH: + case QUDA_MOBIUS_DWF_DSLASH: + case QUDA_MOBIUS_DWF_EOFA_DSLASH: + case QUDA_STAGGERED_DSLASH: + case QUDA_ASQTAD_DSLASH: + case QUDA_LAPLACE_DSLASH: + case QUDA_COVDEV_DSLASH: return false; + default: errorQuda("Invalid QudaDslashType %d", type); break; + } + return false; + } + + bool Dirac::is_staggered_type(QudaDiracType type) + { + switch (type) { + case QUDA_STAGGERED_DIRAC: + case QUDA_STAGGEREDPC_DIRAC: + case QUDA_STAGGEREDKD_DIRAC: + case QUDA_ASQTAD_DIRAC: + case 
QUDA_ASQTADPC_DIRAC: + case QUDA_ASQTADKD_DIRAC: return true; + case QUDA_WILSON_DIRAC: + case QUDA_WILSONPC_DIRAC: + case QUDA_CLOVER_DIRAC: + case QUDA_CLOVERPC_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: + case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_TWISTED_CLOVERPC_DIRAC: + case QUDA_TWISTED_MASS_DIRAC: + case QUDA_TWISTED_MASSPC_DIRAC: + case QUDA_DOMAIN_WALL_DIRAC: + case QUDA_DOMAIN_WALLPC_DIRAC: + case QUDA_DOMAIN_WALL_4D_DIRAC: + case QUDA_DOMAIN_WALL_4DPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: + case QUDA_COARSE_DIRAC: + case QUDA_COARSEPC_DIRAC: + case QUDA_GAUGE_COVDEV_DIRAC: + case QUDA_GAUGE_LAPLACE_DIRAC: + case QUDA_GAUGE_LAPLACEPC_DIRAC: return false; + default: errorQuda("Invalid QudaDiracType %d", type); break; + } + return false; + } + + bool Dirac::is_staggered_type(QudaDslashType type) + { + switch (type) { + case QUDA_STAGGERED_DSLASH: + case QUDA_ASQTAD_DSLASH: return true; + case QUDA_WILSON_DSLASH: + case QUDA_CLOVER_WILSON_DSLASH: + case QUDA_CLOVER_HASENBUSCH_TWIST_DSLASH: + case QUDA_TWISTED_MASS_DSLASH: + case QUDA_TWISTED_CLOVER_DSLASH: + case QUDA_DOMAIN_WALL_DSLASH: + case QUDA_DOMAIN_WALL_4D_DSLASH: + case QUDA_MOBIUS_DWF_DSLASH: + case QUDA_MOBIUS_DWF_EOFA_DSLASH: + case QUDA_LAPLACE_DSLASH: + case QUDA_COVDEV_DSLASH: return false; + default: errorQuda("Invalid QudaDslashType %d", type); break; + } + return false; + } + + bool Dirac::is_asqtad(QudaDiracType type) + { + switch (type) { + case QUDA_ASQTAD_DIRAC: + case QUDA_ASQTADPC_DIRAC: + case QUDA_ASQTADKD_DIRAC: return true; + case QUDA_WILSON_DIRAC: + case QUDA_WILSONPC_DIRAC: + case QUDA_CLOVER_DIRAC: + case QUDA_CLOVERPC_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: + case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_TWISTED_CLOVERPC_DIRAC: + case 
QUDA_TWISTED_MASS_DIRAC: + case QUDA_TWISTED_MASSPC_DIRAC: + case QUDA_DOMAIN_WALL_DIRAC: + case QUDA_DOMAIN_WALLPC_DIRAC: + case QUDA_DOMAIN_WALL_4D_DIRAC: + case QUDA_DOMAIN_WALL_4DPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: + case QUDA_STAGGERED_DIRAC: + case QUDA_STAGGEREDPC_DIRAC: + case QUDA_STAGGEREDKD_DIRAC: + case QUDA_COARSE_DIRAC: + case QUDA_COARSEPC_DIRAC: + case QUDA_GAUGE_COVDEV_DIRAC: + case QUDA_GAUGE_LAPLACE_DIRAC: + case QUDA_GAUGE_LAPLACEPC_DIRAC: return false; + default: errorQuda("Invalid QudaDiracType %d", type); break; + } + return false; + } + + bool Dirac::is_asqtad(QudaDslashType type) + { + switch (type) { + case QUDA_ASQTAD_DSLASH: return true; + case QUDA_WILSON_DSLASH: + case QUDA_CLOVER_WILSON_DSLASH: + case QUDA_CLOVER_HASENBUSCH_TWIST_DSLASH: + case QUDA_TWISTED_MASS_DSLASH: + case QUDA_TWISTED_CLOVER_DSLASH: + case QUDA_DOMAIN_WALL_DSLASH: + case QUDA_DOMAIN_WALL_4D_DSLASH: + case QUDA_MOBIUS_DWF_DSLASH: + case QUDA_MOBIUS_DWF_EOFA_DSLASH: + case QUDA_STAGGERED_DSLASH: + case QUDA_LAPLACE_DSLASH: + case QUDA_COVDEV_DSLASH: return false; + default: errorQuda("Invalid QudaDslashType %d", type); break; + } + return false; + } + + bool Dirac::is_dwf(QudaDiracType type) + { + switch (type) { + case QUDA_DOMAIN_WALL_DIRAC: + case QUDA_DOMAIN_WALLPC_DIRAC: + case QUDA_DOMAIN_WALL_4D_DIRAC: + case QUDA_DOMAIN_WALL_4DPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: return true; + case QUDA_WILSON_DIRAC: + case QUDA_WILSONPC_DIRAC: + case QUDA_CLOVER_DIRAC: + case QUDA_CLOVERPC_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: + case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_TWISTED_CLOVERPC_DIRAC: + case QUDA_TWISTED_MASS_DIRAC: + case 
QUDA_TWISTED_MASSPC_DIRAC: + case QUDA_STAGGERED_DIRAC: + case QUDA_STAGGEREDPC_DIRAC: + case QUDA_STAGGEREDKD_DIRAC: + case QUDA_ASQTAD_DIRAC: + case QUDA_ASQTADPC_DIRAC: + case QUDA_ASQTADKD_DIRAC: + case QUDA_COARSE_DIRAC: + case QUDA_COARSEPC_DIRAC: + case QUDA_GAUGE_COVDEV_DIRAC: + case QUDA_GAUGE_LAPLACE_DIRAC: + case QUDA_GAUGE_LAPLACEPC_DIRAC: return false; + default: errorQuda("Invalid QudaDiracType %d", type); break; + } + return false; + } + + bool Dirac::is_dwf(QudaDslashType type) { - int steps = 0; - switch (type) - { - case QUDA_WILSON_DIRAC: - case QUDA_CLOVER_DIRAC: - case QUDA_DOMAIN_WALL_DIRAC: - case QUDA_DOMAIN_WALL_4D_DIRAC: - case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: - case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: - case QUDA_TWISTED_CLOVER_DIRAC: - case QUDA_TWISTED_MASS_DIRAC: - case QUDA_STAGGERED_DIRAC: - case QUDA_ASQTAD_DIRAC: - case QUDA_STAGGEREDKD_DIRAC: - case QUDA_ASQTADKD_DIRAC: - case QUDA_COARSE_DIRAC: - case QUDA_GAUGE_LAPLACE_DIRAC: - case QUDA_GAUGE_COVDEV_DIRAC: - steps = 1; // single fused operator - break; - case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: // implemented as separate even, odd - steps = 2; // For D_{eo} and D_{oe} piece. 
- break; - case QUDA_WILSONPC_DIRAC: - case QUDA_CLOVERPC_DIRAC: - case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: - case QUDA_DOMAIN_WALLPC_DIRAC: - case QUDA_DOMAIN_WALL_4DPC_DIRAC: - case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: - case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: - case QUDA_STAGGEREDPC_DIRAC: - case QUDA_ASQTADPC_DIRAC: - case QUDA_TWISTED_CLOVERPC_DIRAC: - case QUDA_TWISTED_MASSPC_DIRAC: - case QUDA_COARSEPC_DIRAC: - case QUDA_GAUGE_LAPLACEPC_DIRAC: - steps = 2; - break; - default: - errorQuda("Unsupported Dslash type %d.\n", type); - steps = 0; - break; + switch (type) { + case QUDA_DOMAIN_WALL_DSLASH: + case QUDA_DOMAIN_WALL_4D_DSLASH: + case QUDA_MOBIUS_DWF_DSLASH: + case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true; + case QUDA_WILSON_DSLASH: + case QUDA_CLOVER_WILSON_DSLASH: + case QUDA_CLOVER_HASENBUSCH_TWIST_DSLASH: + case QUDA_TWISTED_MASS_DSLASH: + case QUDA_TWISTED_CLOVER_DSLASH: + case QUDA_STAGGERED_DSLASH: + case QUDA_ASQTAD_DSLASH: + case QUDA_LAPLACE_DSLASH: + case QUDA_COVDEV_DSLASH: return false; + default: errorQuda("Invalid QudaDslashType %d", type); break; } + return false; + } - return steps; + QudaDslashType Dirac::dirac_to_dslash_type(QudaDiracType type) + { + switch (type) { + case QUDA_WILSON_DIRAC: + case QUDA_WILSONPC_DIRAC: return QUDA_WILSON_DSLASH; + case QUDA_CLOVER_DIRAC: + case QUDA_CLOVERPC_DIRAC: return QUDA_CLOVER_WILSON_DSLASH; + case QUDA_CLOVER_HASENBUSCH_TWIST_DIRAC: + case QUDA_CLOVER_HASENBUSCH_TWISTPC_DIRAC: return QUDA_CLOVER_HASENBUSCH_TWIST_DSLASH; + case QUDA_TWISTED_CLOVER_DIRAC: + case QUDA_TWISTED_CLOVERPC_DIRAC: return QUDA_TWISTED_CLOVER_DSLASH; + case QUDA_TWISTED_MASS_DIRAC: + case QUDA_TWISTED_MASSPC_DIRAC: return QUDA_TWISTED_MASS_DSLASH; + case QUDA_DOMAIN_WALL_DIRAC: + case QUDA_DOMAIN_WALLPC_DIRAC: return QUDA_DOMAIN_WALL_DSLASH; + case QUDA_DOMAIN_WALL_4D_DIRAC: + case QUDA_DOMAIN_WALL_4DPC_DIRAC: return QUDA_DOMAIN_WALL_4D_DSLASH; + case QUDA_MOBIUS_DOMAIN_WALL_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_DIRAC: 
return QUDA_MOBIUS_DWF_DSLASH; + case QUDA_MOBIUS_DOMAIN_WALL_EOFA_DIRAC: + case QUDA_MOBIUS_DOMAIN_WALLPC_EOFA_DIRAC: return QUDA_MOBIUS_DWF_EOFA_DSLASH; + case QUDA_STAGGERED_DIRAC: + case QUDA_STAGGEREDPC_DIRAC: + case QUDA_STAGGEREDKD_DIRAC: return QUDA_STAGGERED_DSLASH; + case QUDA_ASQTAD_DIRAC: + case QUDA_ASQTADPC_DIRAC: + case QUDA_ASQTADKD_DIRAC: return QUDA_ASQTAD_DSLASH; + case QUDA_GAUGE_COVDEV_DIRAC: return QUDA_COVDEV_DSLASH; + case QUDA_GAUGE_LAPLACE_DIRAC: + case QUDA_GAUGE_LAPLACEPC_DIRAC: return QUDA_LAPLACE_DSLASH; + case QUDA_COARSE_DIRAC: + case QUDA_COARSEPC_DIRAC: return QUDA_INVALID_DSLASH; + default: errorQuda("Invalid QudaDiracType %d", type); break; + } + return QUDA_INVALID_DSLASH; } void Dirac::prefetch(QudaFieldLocation mem_space, qudaStream_t stream) const diff --git a/lib/interface_quda.cpp b/lib/interface_quda.cpp index 6928797062..f9e1181848 100644 --- a/lib/interface_quda.cpp +++ b/lib/interface_quda.cpp @@ -3468,37 +3468,32 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col checkInvertParam(param, _hp_x[0], _hp_b[0]); - bool is_staggered = false; - bool is_asqtad = false; + // Asqtad loads fat and long links; all others (including naive staggered) load thin links + bool is_asqtad = Dirac::is_asqtad(param->dslash_type); GaugeBundleBackup thin_links_bkup; GaugeBundleBackup fat_links_bkup; GaugeBundleBackup long_links_bkup; - if (gaugePrecise) { - is_staggered = false; - thin_links_bkup.backup(gaugePrecise, gaugeSloppy, gaugePrecondition, gaugeRefinement, gaugeEigensolver, - gaugeExtended); + if (is_asqtad) { + if (!gaugeFatPrecise || !gaugeLongPrecise) + errorQuda("Both milc_fatlinks and milc_longlinks need to be non-null for asqtad-type dslash"); - } else if (gaugeFatPrecise) { - is_staggered = true; fat_links_bkup.backup(gaugeFatPrecise, gaugeFatSloppy, gaugeFatPrecondition, gaugeFatRefinement, gaugeFatEigensolver, gaugeFatExtended); - if (param->dslash_type == QUDA_ASQTAD_DSLASH) { - if 
(!gaugeLongPrecise) errorQuda("milc_longlinks is null for an asqtad dslash"); - is_asqtad = true; - long_links_bkup.backup(gaugeLongPrecise, gaugeLongSloppy, gaugeLongPrecondition, gaugeLongRefinement, - gaugeLongEigensolver, gaugeLongExtended); - } + long_links_bkup.backup(gaugeLongPrecise, gaugeLongSloppy, gaugeLongPrecondition, gaugeLongRefinement, + gaugeLongEigensolver, gaugeLongExtended); } else { - errorQuda("Both h_gauge and milc_fatlinks are null."); + if (!gaugePrecise) errorQuda("h_gauge is null for a Wilson-type or naive staggered dslash"); + thin_links_bkup.backup(gaugePrecise, gaugeSloppy, gaugePrecondition, gaugeRefinement, gaugeEigensolver, + gaugeExtended); } // Deal with Spinors bool pc_solution = (param->solution_type == QUDA_MATPC_SOLUTION) || (param->solution_type == QUDA_MATPCDAG_MATPC_SOLUTION); - lat_dim_t X = is_staggered ? gaugeFatPrecise->X() : gaugePrecise->X(); + lat_dim_t X = is_asqtad ? gaugeFatPrecise->X() : gaugePrecise->X(); ColorSpinorParam spinorParam(_hp_b[0], *param, X, pc_solution, param->input_location); std::vector _h_b(param->num_src); @@ -3523,23 +3518,23 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col logQuda(QUDA_DEBUG_VERBOSE, "Spliting the grid into sub-partitions: (%2d,%2d,%2d,%2d) / (%2d,%2d,%2d,%2d)\n", comm_dim(0), comm_dim(1), comm_dim(2), comm_dim(3), split_key[0], split_key[1], split_key[2], split_key[3]); - if (!is_staggered) + if (!is_asqtad) gf_param = GaugeFieldParam(*(thin_links_bkup.precise)); else { milc_fatlink_param = GaugeFieldParam(*(fat_links_bkup.precise)); - if (is_asqtad) milc_longlink_param = GaugeFieldParam(*(long_links_bkup.precise)); + milc_longlink_param = GaugeFieldParam(*(long_links_bkup.precise)); } for (int d = 0; d < CommKey::n_dim; d++) { if (comm_dim(d) % split_key[d] != 0) { errorQuda("Split not possible: %2d %% %2d != 0", comm_dim(d), split_key[d]); } - if (!is_staggered) { + if (!is_asqtad) { gf_param.x[d] *= split_key[d]; gf_param.pad *= 
split_key[d]; } else { milc_fatlink_param.x[d] *= split_key[d]; - if (is_asqtad) milc_longlink_param.x[d] *= split_key[d]; + milc_longlink_param.x[d] *= split_key[d]; } } @@ -3547,7 +3542,7 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col quda::GaugeField *collected_milc_fatlink_field = nullptr; quda::GaugeField *collected_milc_longlink_field = nullptr; - if (!is_staggered) { + if (!is_asqtad) { gf_param.create = QUDA_NULL_FIELD_CREATE; collected_gauge = new quda::GaugeField(gf_param); quda::split_field(*collected_gauge, {*(thin_links_bkup.precise)}, split_key); @@ -3558,12 +3553,9 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col collected_milc_fatlink_field = new GaugeField(milc_fatlink_param); quda::split_field(*collected_milc_fatlink_field, {*(fat_links_bkup.precise)}, split_key); - if (is_asqtad) { - milc_longlink_param.create = QUDA_NULL_FIELD_CREATE; - collected_milc_longlink_field = new GaugeField(milc_longlink_param); - v_g[0] = long_links_bkup.precise; - quda::split_field(*collected_milc_longlink_field, {*(long_links_bkup.precise)}, split_key); - } + milc_longlink_param.create = QUDA_NULL_FIELD_CREATE; + collected_milc_longlink_field = new GaugeField(milc_longlink_param); + quda::split_field(*collected_milc_longlink_field, {*(long_links_bkup.precise)}, split_key); } // ------ Clover field @@ -3630,7 +3622,7 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col // Load 'collected gauge field' logQuda(QUDA_DEBUG_VERBOSE, "Split grid loading gauge field...\n"); - if (!is_staggered) { + if (!is_asqtad) { setupGaugeFields(collected_gauge, gaugePrecise, gaugeSloppy, gaugePrecondition, gaugeRefinement, gaugeEigensolver, gaugeExtended, thin_links_bkup, profile.profile); @@ -3638,10 +3630,8 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col setupGaugeFields(collected_milc_fatlink_field, gaugeFatPrecise, gaugeFatSloppy, 
gaugeFatPrecondition, gaugeFatRefinement, gaugeFatEigensolver, gaugeFatExtended, fat_links_bkup, profile.profile); - if (is_asqtad) { - setupGaugeFields(collected_milc_longlink_field, gaugeLongPrecise, gaugeLongSloppy, gaugeLongPrecondition, - gaugeLongRefinement, gaugeLongEigensolver, gaugeLongExtended, long_links_bkup, profile.profile); - } + setupGaugeFields(collected_milc_longlink_field, gaugeLongPrecise, gaugeLongSloppy, gaugeLongPrecondition, + gaugeLongRefinement, gaugeLongEigensolver, gaugeLongExtended, long_links_bkup, profile.profile); } logQuda(QUDA_DEBUG_VERBOSE, "Split grid loaded gauge field...\n"); @@ -3693,7 +3683,7 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col profileInvertMultiSrc.TPSTOP(QUDA_PROFILE_EPILOGUE); // Restore the gauge field - if (!is_staggered) { + if (!is_asqtad) { freeUniqueGaugeQuda(QUDA_WILSON_LINKS); gaugePrecise = thin_links_bkup.precise; @@ -3713,16 +3703,14 @@ void callMultiSrcQuda(void **_hp_x, void **_hp_b, QudaInvertParam *param, // col gaugeFatRefinement = fat_links_bkup.refinement; gaugeFatEigensolver = fat_links_bkup.eigensolver; gaugeFatExtended = fat_links_bkup.extended; - if (is_asqtad) { - - freeUniqueGaugeQuda(QUDA_ASQTAD_LONG_LINKS); - gaugeLongPrecise = long_links_bkup.precise; - gaugeLongSloppy = long_links_bkup.sloppy; - gaugeLongPrecondition = long_links_bkup.precondition; - gaugeLongRefinement = long_links_bkup.refinement; - gaugeLongEigensolver = long_links_bkup.eigensolver; - gaugeLongExtended = long_links_bkup.extended; - } + + freeUniqueGaugeQuda(QUDA_ASQTAD_LONG_LINKS); + gaugeLongPrecise = long_links_bkup.precise; + gaugeLongSloppy = long_links_bkup.sloppy; + gaugeLongPrecondition = long_links_bkup.precondition; + gaugeLongRefinement = long_links_bkup.refinement; + gaugeLongEigensolver = long_links_bkup.eigensolver; + gaugeLongExtended = long_links_bkup.extended; } if (is_clover) { diff --git a/tests/utils/host_utils.cpp b/tests/utils/host_utils.cpp index 
140eccd8a8..9df5050cc1 100644 --- a/tests/utils/host_utils.cpp +++ b/tests/utils/host_utils.cpp @@ -13,6 +13,7 @@ // QUDA headers #include #include +#include // External headers #include @@ -369,25 +370,9 @@ bool is_normal_residual(QudaInverterType type) } } -bool is_staggered(QudaDslashType type) -{ - switch (type) { - case QUDA_STAGGERED_DSLASH: - case QUDA_ASQTAD_DSLASH: return true; - default: return false; - } -} +bool is_staggered(QudaDslashType type) { return Dirac::is_staggered_type(type); } -bool is_chiral(QudaDslashType type) -{ - switch (type) { - case QUDA_DOMAIN_WALL_DSLASH: - case QUDA_DOMAIN_WALL_4D_DSLASH: - case QUDA_MOBIUS_DWF_DSLASH: - case QUDA_MOBIUS_DWF_EOFA_DSLASH: return true; - default: return false; - } -} +bool is_chiral(QudaDslashType type) { return Dirac::is_dwf(type); } bool is_laplace(QudaDslashType type) {