Skip to content

Commit

Permalink
Fix gradlib fp8 output (opendatahub-io#76)
Browse files Browse the repository at this point in the history
* fix gradlib fp8 output

* add condition check for existing tune result

* fix linter

* fix import order

* fix lint
  • Loading branch information
charlifu authored Jul 1, 2024
1 parent d6e7862 commit 52df169
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 8 deletions.
22 changes: 15 additions & 7 deletions gradlib/gradlib/GemmTuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,21 @@ def find_hipblas_sols(self):
self.hipb_sols = sols

def check_gemm_ref(self, libtype, solidx):
ref = F.linear(self.inp.to(torch.float32),
self.weights.to(torch.float32)).to(self.outdtype)
if self.indtype == torch.float8_e4m3fnuz:
ref, _ = torch._scaled_mm(self.inp,
self.weights.t(),
out_dtype=self.outdtype)
else:
ref = F.linear(self.inp, self.weights)
if libtype == 'hipblaslt':
c = hipbsolidxgemm.hipb_mm(self.inp, self.weights.t(), solidx,
self.outdtype)
elif libtype == 'rocblas':
c = rocsolidxgemm.rocb_mm(self.inp, self.weights.t(), solidx)
if torch.allclose(c, ref, atol=self.atol, rtol=self.rtol):
#print('>>>',libtype,'Solidx',solidx,'passed reference test')
if torch.allclose(c.to(torch.float32),
ref.to(torch.float32),
atol=self.atol,
rtol=self.rtol):
return True

print('>>>',
Expand Down Expand Up @@ -263,7 +269,8 @@ def add_gemm(self, m, n, k):

def find_best_sols(self):
df = self.gemm_problems
soldf = pd.DataFrame()
soldf = pd.DataFrame(
columns=['libtype', 'solidx', 'soltimems', 'indtype', 'outdtype'])
for i in range(len(df)):
ds = df.loc[i, :]
gemmobj = Gemm(ds['M'],
Expand All @@ -279,7 +286,8 @@ def find_best_sols(self):
soldf['indtype'] = self.indtype
soldf['outdtype'] = self.outdtype
finaldf = pd.concat([self.gemm_problems, soldf], axis=1)
finaldf = pd.concat([finaldf, self.gdf])
finaldf['solidx'] = finaldf['solidx'].astype('int64')
if self.gdf is not None:
finaldf = pd.concat([finaldf, self.gdf])
finaldf['solidx'] = finaldf['solidx'].convert_dtypes('int64')
finaldf.to_csv(self.tuned_file, index=False)
print(finaldf)
2 changes: 1 addition & 1 deletion gradlib/gradlib/gemm_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import os
from pathlib import Path

import torch # isort: split
import hipbsolidxgemm
import pandas as pd
import rocsolidxgemm
import torch

from gradlib.GemmTuner import GemmTuner

Expand Down

0 comments on commit 52df169

Please sign in to comment.