diff --git a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp index c0651d81c96..aa7bce903ae 100644 --- a/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp +++ b/src/cpu/x64/brgemm/jit_brgemm_kernel.cpp @@ -305,7 +305,7 @@ struct jit_brgemm_kernel_t : public jit_generator { used_vregs = 5; else if (brg.is_f16_b_non_amx_vnni()) used_vregs = 2; - + if (one_of(brg.dt_b, data_type::nf4) && brg.isa_impl == avx2) { used_vregs += 5; } @@ -2431,7 +2431,11 @@ void jit_brgemm_kernel_t::gemm_microkernel_dyn_quant(int bd_block2, for (int bd = bd_b; bd < bd_e; bd++) { uni_vbroadcastss(vmm_src_scales, ptr[reg_local_src_scales + bd * brg.src_scales_stride * sizeof(float)]); for (int ld = 0; ld < ld_block2; ld++) { - uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + if (brg.wei_decomp_scales_stride == 0) { + uni_vbroadcastss(load(ld), ptr[reg_local_wei_scales]); + } else { + uni_vmovups(load(ld), ptr[reg_local_wei_scales + ld * brg.ld_block * sizeof(float)]); + } } for (int ld = 0; ld < ld_block2; ld++) { auto vmm_accm_aux = vmm_accm_tmp(ld_block2, bd, ld); @@ -2901,7 +2905,11 @@ void jit_brgemm_kernel_t::gemm_microkernel(int bd_block2, bool is_bdb_tail, for (int ld = 0; ld < ld_block2; ld++) { auto vmm_accm_tmp = accm_tmp(ld_block2, 0, ld); auto vmm_accm = accm(ld_block2, 0, ld); - load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); + if (brg.wei_decomp_scales_stride == 0) { + load_scales(bcst(), ptr[reg_local_wei_scales]); + } else { + load_scales(bcst(), ptr[reg_local_wei_scales + ld * brg.ld_block * types::data_type_size(brg.wei_decomp_scales_dt)]); + } uni_vfmadd231ps(vmm_accm, vmm_accm_tmp, bcst()); } } @@ -3025,8 +3033,8 @@ void jit_brgemm_kernel_t::ldb_loop(int bd_block2, bool is_bdb_tail, mov(reg_rdb_loop, brg.rdb); L_aligned(rdb_loop_label, 64); { - if (brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || - brg.wei_decomp_zero_points_stride != 0)) { + if ((brg.with_grouped_wei_decomp && (brg.wei_decomp_scales_stride != 0 || + brg.wei_decomp_zero_points_stride != 0)) || brg.with_src_dyn_quant) { auto reg_local_ic = reg_aux_D; auto reg_local_wei_params = reg_bdb_loop; auto reg_local_ic_group = reg_ldb_loop;