diff --git a/ff/mersenne31.hpp b/ff/mersenne31.hpp index 83338a2..3c4ae7d 100644 --- a/ff/mersenne31.hpp +++ b/ff/mersenne31.hpp @@ -140,7 +140,7 @@ class mrs31_t { friend inline mrs31_t operator-(mrs31_t a, const mrs31_t b) { return a -= b; } - inline mrs31_t cneg(bool flag) + inline mrs31_t& cneg(bool flag) { if (flag && val != 0) val = MOD - val; @@ -264,6 +264,20 @@ class mrs31_t { return mrs31_t{ret}; } + static inline mrs31_t dot_product(mrs31_t a, mrs31_t b, + mrs31_t c, mrs31_t d) + { + uint64_t tmp = a.val*(uint64_t)b.val + c.val*(uint64_t)d.val; + uint32_t ret = (uint32_t)(tmp >> 31); + if (ret >= MOD) + ret -= MOD; + ret += (uint32_t)tmp & MOD; + if (ret >= MOD) + ret -= MOD; + + return mrs31_t{ret}; + } + private: static inline mrs31_t sqr_n(mrs31_t s, uint32_t n) { diff --git a/ff/mont32_t.cuh b/ff/mont32_t.cuh index 2d7b96b..cd868b3 100644 --- a/ff/mont32_t.cuh +++ b/ff/mont32_t.cuh @@ -125,7 +125,7 @@ public: friend inline mont32_t operator-(mont32_t a, const mont32_t b) { return a -= b; } - inline mont32_t cneg(bool flag) + inline mont32_t& cneg(bool flag) { asm("{"); asm(".reg.pred %flag;"); @@ -334,6 +334,26 @@ public: return final_sub(acc[1]); } + static inline mont32_t dot_product(mont32_t a, mont32_t b, + mont32_t c, mont32_t d) + { + uint32_t acc[2]; + + asm("mul.lo.u32 %0, %2, %3; mul.hi.u32 %1, %2, %3;" + : "=r"(acc[0]), "=r"(acc[1]) : "r"(*a), "r"(*b)); + asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.u32 %1, %2, %3, %1;" + : "+r"(acc[0]), "+r"(acc[1]) : "r"(*c), "r"(*d)); + if (N == 32) + acc[1] = final_sub(acc[1]); + + uint32_t red; + asm("mul.lo.u32 %0, %1, %2;" : "=r"(red) : "r"(acc[0]), "r"(M0)); + asm("mad.lo.cc.u32 %0, %2, %3, %0; madc.hi.cc.u32 %1, %2, %3, %1;" + : "+r"(acc[0]), "+r"(acc[1]) : "r"(red), "r"(MOD)); + + return final_sub(acc[1]); + } + inline mont32_t reciprocal() const { return *this ^ (MOD-2); } friend inline mont32_t operator/(int one, mont32_t a)