From 0ae16589601cedb3323d59970b65eab50dbfc2fd Mon Sep 17 00:00:00 2001 From: tsewei-lin Date: Tue, 21 Jan 2025 02:46:21 -0800 Subject: [PATCH] vector: crypto: fix overlap check when EGW > VLEN --- riscv/insns/vsm4r_vs.h | 2 +- riscv/zvk_ext_macros.h | 12 ++++++++++++ riscv/zvkned_ext_macros.h | 3 ++- 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/riscv/insns/vsm4r_vs.h b/riscv/insns/vsm4r_vs.h index 649eada96f..6a6ed04793 100644 --- a/riscv/insns/vsm4r_vs.h +++ b/riscv/insns/vsm4r_vs.h @@ -7,7 +7,7 @@ const uint32_t EGS = 4; require_vsm4_constraints; require_align(insn.rd(), P.VU.vflmul); // No overlap of vd and vs2. -require_noover(insn.rs2(), 1, insn.rd(), P.VU.vflmul); +require_noover_eglmul(insn.rd(), insn.rs2()); VI_ZVK_VD_VS2_NOOPERANDS_PRELOOP_EGU32x4_NOVM_LOOP( {}, diff --git a/riscv/zvk_ext_macros.h b/riscv/zvk_ext_macros.h index f094629835..490de7020f 100644 --- a/riscv/zvk_ext_macros.h +++ b/riscv/zvk_ext_macros.h @@ -86,6 +86,18 @@ // (LMUL * VLEN) <= EGW #define require_egw_fits(EGW) require((EGW) <= (P.VU.VLEN * P.VU.vflmul)) +// ensure that rs2 and rd do not overlap, assuming rd encodes an LMUL wide +// vector register group and rs2 encodes an vs2_EMUL=ceil(EGW / VLEN) vector register +// group. +// Assumption: LMUL >= vs2_EMUL which is enforced independently through require_egw_fits. +#define require_noover_eglmul(vd, vs2) \ + do { \ + int vd_emul = P.VU.vflmul < 1.f ? 1 : (int) P.VU.vflmul; \ + int aligned_vd = vd / vd_emul; \ + int aligned_vs2 = vs2 / vd_emul; \ + require(aligned_vd != aligned_vs2); \ + } while (0) + // Checks that the vector unit state (vtype and vl) can be interpreted // as element groups with EEW=32, EGS=4 (four 32-bits elements per group), // for an effective element group width of EGW=128 bits. diff --git a/riscv/zvkned_ext_macros.h b/riscv/zvkned_ext_macros.h index 8ece5687d9..db12d1593c 100644 --- a/riscv/zvkned_ext_macros.h +++ b/riscv/zvkned_ext_macros.h @@ -2,6 +2,7 @@ // the RISC-V Zvkned extension (vector AES single round). #include "insns/aes_common.h" +#include "zvk_ext_macros.h" #ifndef RISCV_ZVKNED_EXT_MACROS_H_ #define RISCV_ZVKNED_EXT_MACROS_H_ @@ -21,7 +22,7 @@ require(P.VU.vsew == 32); \ require_egw_fits(128); \ require_align(insn.rd(), P.VU.vflmul); \ - require_noover(insn.rs2(), 1, insn.rd(), P.VU.vflmul); \ + require_noover_eglmul(insn.rd(), insn.rs2()); \ } while (false) // vaes*.vv instruction constraints. Those are the same as the .vs ones,