diff --git a/crystals-kyber/poly.go b/crystals-kyber/poly.go index 1d19a70..5c74780 100644 --- a/crystals-kyber/poly.go +++ b/crystals-kyber/poly.go @@ -186,18 +186,24 @@ func polyToMsg(p Poly) []byte { return msg } -//compress packs a polynomial into a byte array using d bits per coefficient +//compress packs a polynomial into a byte array using d bits per coefficient - fixed against https://kyberslash.cr.yp.to/faq.html (cases d=3,4,5,6,10,11) func (p *Poly) compress(d int) []byte { c := make([]byte, n*d/8) switch d { case 3: var t [8]uint16 + var d0 uint32 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<3)+uint32(q)/2)/ - uint32(q)) & ((1 << 3) - 1) + /* t[j] = uint16(((uint32(p[8*i+j])<<3)+uint32(q)/2)/ + uint32(q)) & ((1 << 3) - 1) */ + d0 = uint32(p[8*i+j]) << 3 + d0 += 1664 + d0 *= 161271 + d0 >>= 29 + t[j] = uint16(d0 & 0x7) } c[id] = byte(t[0]) | byte(t[1]<<3) | byte(t[2]<<6) c[id+1] = byte(t[2]>>2) | byte(t[3]<<1) | byte(t[4]<<4) | byte(t[5]<<7) @@ -207,11 +213,17 @@ func (p *Poly) compress(d int) []byte { case 4: var t [8]uint16 + var d0 uint32 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/ - uint32(q)) & ((1 << 4) - 1) + /* t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/ + uint32(q)) & ((1 << 4) - 1)*/ + d0 = uint32(p[8*i+j]) << 4 + d0 += 1665 + d0 *= 80635 + d0 >>= 28 + t[j] = uint16(d0 & 0xf) } c[id] = byte(t[0]) | byte(t[1]<<4) c[id+1] = byte(t[2]) | byte(t[3]<<4) @@ -222,11 +234,17 @@ func (p *Poly) compress(d int) []byte { case 5: var t [8]uint16 + var d0 uint32 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(q)/2)/ - uint32(q)) & ((1 << 5) - 1) + /* t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(q)/2)/ + uint32(q)) & ((1 << 5) - 1) */ 
+ d0 = uint32(p[8*i+j]) << 5 + d0 += 1664 + d0 *= 40318 + d0 >>= 27 + t[j] = uint16(d0 & 0x1f) } c[id] = byte(t[0]) | byte(t[1]<<5) c[id+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7) @@ -238,11 +256,17 @@ func (p *Poly) compress(d int) []byte { case 6: var t [4]uint16 + var d0 uint32 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/4; i++ { - for j := 0; j < 4; j++ { - t[j] = uint16(((uint32(p[4*i+j])<<6)+uint32(q)/2)/ - uint32(q)) & ((1 << 6) - 1) + for j := 0; j < 4; j++ { + /* t[j] = uint16(((uint32(p[4*i+j])<<6)+uint32(q)/2)/ + uint32(q)) & ((1 << 6) - 1) */ + d0 = uint32(p[4*i+j]) << 6 + d0 += 1664 + d0 *= 20159 + d0 >>= 26 + t[j] = uint16(d0 & 0x3f) } c[id] = byte(t[0]) | byte(t[1]<<6) c[id+1] = byte(t[1]>>2) | byte(t[2]<<4) @@ -252,11 +276,17 @@ func (p *Poly) compress(d int) []byte { case 10: var t [4]uint16 + var d0 uint64 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/4; i++ { for j := 0; j < 4; j++ { - t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(q)/2)/ - uint32(q)) & ((1 << 10) - 1) + /* t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(q)/2)/ + uint32(q)) & ((1 << 10) - 1) */ + d0 = uint64(p[4*i+j]) << 10 + d0 += 1665 + d0 *= 1290167 + d0 >>= 32 + t[j] = uint16(d0 & 0x3ff) } c[id] = byte(t[0]) c[id+1] = byte(t[0]>>8) | byte(t[1]<<2) @@ -267,11 +297,17 @@ func (p *Poly) compress(d int) []byte { } case 11: var t [8]uint16 + var d0 uint64 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(q)/2)/ - uint32(q)) & ((1 << 11) - 1) + /* t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(q)/2)/ + uint32(q)) & ((1 << 11) - 1) */ + d0 = uint64(p[8*i+j]) << 11 + d0 += 1664 + d0 *= 645084 + d0 >>= 31 + t[j] = uint16(d0 & 0x7ff) } c[id] = byte(t[0]) c[id+1] = byte(t[0]>>8) | byte(t[1]<<3)