Skip to content

Commit

Permalink
FlashAttention-2 fix (#104)
Browse files Browse the repository at this point in the history
* dnn: Increase error threshold in FA-2

* dnn: Switch to naive FP32 as baseline fails in FlashAttention-2

* target: Re-introduce FlashAttention in CI

---------

Co-authored-by: Viviane Potocnik <vivianep@iis.ee.ethz.ch>
Co-authored-by: Luca Colagrande <luca.colagrande3@gmail.com>
  • Loading branch information
3 people authored Mar 4, 2024
1 parent 1a0ff31 commit 454ab4c
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
7 changes: 4 additions & 3 deletions sw/blas/gemm/src/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -1232,9 +1232,10 @@ void sc_st_gemm(precision_t prec, uint32_t expand, uint32_t setup_ssr,
break;
case FP32:
if (baseline) {
gemm_fp32_baseline(frac_m, n, k, (float*)a + offsetA,
lda_strided, (float*)b, ldb,
(float*)c + offsetC, ldc_strided, beta);
gemm_fp32_naive(frac_m, n, k, (float*)a + offsetA,
lda_strided, transa, (float*)b, ldb, transb,
(float*)c + offsetC, ldc_strided,
(float)beta);
} else {
gemm_fp32_opt(frac_m, n, k, (float*)a + offsetA,
lda_strided, (float*)b, ldb,
Expand Down
2 changes: 1 addition & 1 deletion sw/dnn/flashattention_2/verify.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from data_utils import from_buffer, ctype_from_precision_t # noqa: E402


ERR_THRESHOLD = 1E-6
ERR_THRESHOLD = 1E-4


def main():
Expand Down
4 changes: 2 additions & 2 deletions target/snitch_cluster/sw/fdiv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@ runs:
cmd: [../../../sw/dnn/layernorm/verify.py, "${sim_bin}", "${elf}"]
- elf: apps/dnn/gelu/build/gelu.elf
cmd: [../../../sw/dnn/gelu/verify.py, "${sim_bin}", "${elf}"]
# - elf: apps/dnn/flashattention_2/build/flashattention_2.elf
# cmd: [../../../sw/dnn/flashattention_2/verify.py, "${sim_bin}", "${elf}"]
- elf: apps/dnn/flashattention_2/build/flashattention_2.elf
cmd: [../../../sw/dnn/flashattention_2/verify.py, "${sim_bin}", "${elf}"]

0 comments on commit 454ab4c

Please sign in to comment.