diff --git a/crystals-kyber/poly.go b/crystals-kyber/poly.go index 1d19a70..442614f 100644 --- a/crystals-kyber/poly.go +++ b/crystals-kyber/poly.go @@ -195,7 +195,7 @@ func (p *Poly) compress(d int) []byte { var t [8]uint16 id := 0 for i := 0; i < n/8; i++ { - for j := 0; j < 8; j++ { + for j := 0; j < 8; j++ { //TODO: fix KyberSlash2 here t[j] = uint16(((uint32(p[8*i+j])<<3)+uint32(q)/2)/ uint32(q)) & ((1 << 3) - 1) } @@ -207,11 +207,18 @@ func (p *Poly) compress(d int) []byte { case 4: var t [8]uint16 + var d0 uint32 /* accumulation value for fixing KyberSlash2 */ id := 0 for i := 0; i < n/8; i++ { for j := 0; j < 8; j++ { - t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/ - uint32(q)) & ((1 << 4) - 1) + /* t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(q)/2)/ + uint32(q)) & ((1 << 4) - 1)*/ + t[j] = uint16(p[8*i+j]) + d0 = uint32(t << 4) + d0 += 1665 + d0 *= 80635 + d0 >>= 28 + t[j] = d0 & 0xf; } c[id] = byte(t[0]) | byte(t[1]<<4) c[id+1] = byte(t[2]) | byte(t[3]<<4) @@ -240,7 +247,7 @@ func (p *Poly) compress(d int) []byte { var t [4]uint16 id := 0 for i := 0; i < n/4; i++ { - for j := 0; j < 4; j++ { + for j := 0; j < 4; j++ {//TODO: fix KyberSlash2 here t[j] = uint16(((uint32(p[4*i+j])<<6)+uint32(q)/2)/ uint32(q)) & ((1 << 6) - 1) }