feat(CI): Add configuration for static code analysis and C++ code formatting. (microsoft#26)

* add configuration for CI.

* fix static code check.
lcy-seso authored Dec 26, 2024
1 parent 6ed8859 commit 87b1f6d
Showing 18 changed files with 298 additions and 107 deletions.
30 changes: 30 additions & 0 deletions .github/workflows/clang-format.yml
@@ -0,0 +1,30 @@
name: clang-format

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  clang-format:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install clang-format==18.1.5
      - name: Running clang-format on first-party source code
        run: |
          find include src tests/cpp examples/cpp -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cc' -o -name '*.cu' -o -name '*.cuh' \) -print \
            | xargs clang-format --dry-run --Werror
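
For local use, the same check can be approximated with a small Python sketch (an illustration only; it assumes clang-format 18 is on PATH and mirrors the directory and extension lists from the step above):

# check_cpp_format.py -- dry-run clang-format over first-party C++/CUDA sources.
import subprocess
from pathlib import Path

DIRS = ["include", "src", "tests/cpp", "examples/cpp"]
EXTS = {".h", ".hpp", ".cc", ".cu", ".cuh"}

files = [str(p) for d in DIRS for p in Path(d).rglob("*") if p.suffix in EXTS]
# --dry-run --Werror makes clang-format exit non-zero when any file needs reformatting.
result = subprocess.run(["clang-format", "--dry-run", "--Werror", *files])
raise SystemExit(result.returncode)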
41 changes: 41 additions & 0 deletions .github/workflows/ruff.yml
@@ -0,0 +1,41 @@
name: ruff

on:
  push:
    branches:
      - master
  pull_request:
    branches:
      - master

jobs:
  ruff:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.11"]
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v2
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install yapf==0.32.0 toml==0.10.2 tomli==2.0.1
          pip install ruff==0.6.5 codespell==2.3.0
          pip install isort==5.13.2 clang-format==18.1.5
      - name: Analysing the code with ruff
        run: |
          ruff check .
      - name: Spelling check with codespell
        run: |
          codespell --toml pyproject.toml
      - name: Run isort
        run: |
          isort . --check-only
      - name: Running yapf
        run: |
echo "please run \" yapf --recursive .\" if errors"
yapf --diff --recursive .
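
The Python-side checks can likewise be reproduced before pushing; a minimal sketch assuming the pinned tools above are installed in the current environment:

# run_lint.py -- run the same checks as the ruff workflow, stopping at the first failure.
import subprocess
import sys

CHECKS = [
    ["ruff", "check", "."],
    ["codespell", "--toml", "pyproject.toml"],
    ["isort", ".", "--check-only"],
    ["yapf", "--diff", "--recursive", "."],
]

for cmd in CHECKS:
    print("$", " ".join(cmd))
    if subprocess.run(cmd).returncode != 0:
        sys.exit(1)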
1 change: 1 addition & 0 deletions examples/python/README.md
@@ -0,0 +1 @@
[TBD]
57 changes: 33 additions & 24 deletions examples/python/gemm/compile.py
@@ -15,10 +15,12 @@
"Compile",
]

cutlass_include_dir = os.path.join(os.path.dirname(__file__),
"../../../3rd-party/cutlass/include")
tilefusion_include_dir = os.path.join(os.path.dirname(__file__),
"../../../include/")
cutlass_include_dir = os.path.join(
os.path.dirname(__file__), "../../../3rd-party/cutlass/include"
)
tilefusion_include_dir = os.path.join(
os.path.dirname(__file__), "../../../include/"
)
csrc_include_dir = os.path.join(os.path.dirname(__file__), "csrc")


@@ -45,9 +47,9 @@ def py_str(x):
return os.environ["CUDA_PATH"]

cmd = ["which", "nvcc"]
proc = subprocess.Popen(cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT)
proc = subprocess.Popen(
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT
)
(out, _) = proc.communicate()

if proc.returncode == 0:
@@ -67,8 +69,9 @@ def _create_entry_code(
warp_per_col: int,
):
entry_code_path = "entry.py"
spec = importlib.util.spec_from_file_location("entry_code",
entry_code_path)
spec = importlib.util.spec_from_file_location(
"entry_code", entry_code_path
)
foo = importlib.util.module_from_spec(spec)
spec.loader.exec_module(foo)

@@ -84,27 +87,32 @@

return foo.types.format_map(shape) + foo.entry

def compile(self,
M: int,
N: int,
K: int,
TM: int,
TN: int,
kChunkK: int,
warp_per_row: int,
warp_per_col: int,
timeout: float = None):
def compile(
self,
M: int,
N: int,
K: int,
TM: int,
TN: int,
kChunkK: int,
warp_per_row: int,
warp_per_col: int,
timeout: float = None
):
temp_dir = self.tmp_dir

file_name = (f"{self.file_prefix}_{M}_{N}_{K}"
f"_{TM}_{TN}_{warp_per_row}_{warp_per_col}")
file_name = (
f"{self.file_prefix}_{M}_{N}_{K}"
f"_{TM}_{TN}_{warp_per_row}_{warp_per_col}"
)
lib_path = os.path.join(temp_dir, f"{file_name}.so")

if os.path.exists(lib_path):
return lib_path

entry_code = self._create_entry_code(M, N, K, TM, TN, kChunkK,
warp_per_row, warp_per_col)
entry_code = self._create_entry_code(
M, N, K, TM, TN, kChunkK, warp_per_row, warp_per_col
)

source_path = os.path.join(temp_dir, f"{file_name}.cu")
with open(source_path, "w") as f:
@@ -137,5 +145,6 @@ def apply(self, lib_path, torch_array: list, device: int):
torch.cuda.set_device(device)

ret = lib.kernel_entry(
*[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array])
*[ctypes.c_void_p(arr.data_ptr()) for arr in torch_array]
)
return ret
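
For context, the loading half of apply() is elided in this hunk; a sketch of how such a shared object is typically opened and called through ctypes (an assumption for illustration, not the exact implementation):

import ctypes

import torch


def call_kernel(lib_path: str, tensors: list, device: int = 0):
    torch.cuda.set_device(device)
    lib = ctypes.CDLL(lib_path)  # open the freshly built shared object
    # Forward each tensor's raw CUDA pointer to the exported kernel entry point.
    return lib.kernel_entry(*[ctypes.c_void_p(t.data_ptr()) for t in tensors])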
5 changes: 3 additions & 2 deletions examples/python/gemm/gemm.py
@@ -30,8 +30,9 @@ def forward(
warp_per_col: int,
) -> Tensor:
builder = Compile(file_prefix="gemm", tmp_dir="tmp")
lib_name = builder.compile(M, N, K, kM, kN, kChunkK, warp_per_row,
warp_per_col)
lib_name = builder.compile(
M, N, K, kM, kN, kChunkK, warp_per_row, warp_per_col
)

if lib_name is None:
raise RuntimeError("Failed to compile the library.")
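
An end-to-end usage sketch of this compile-then-run path, with hypothetical shapes and a half-precision dtype that are not taken from the commit:

import torch

from compile import Compile  # examples/python/gemm/compile.py

M, N, K, TM, TN, kChunkK = 1024, 1024, 1024, 128, 128, 64
warp_per_row, warp_per_col = 2, 2

a = torch.randn(M, K, device="cuda", dtype=torch.float16)
b = torch.randn(N, K, device="cuda", dtype=torch.float16)  # C = A @ B^T
c = torch.zeros(M, N, device="cuda", dtype=torch.float16)

builder = Compile(file_prefix="gemm", tmp_dir="tmp")
lib_path = builder.compile(M, N, K, TM, TN, kChunkK, warp_per_row, warp_per_col)
# apply() launches the compiled kernel on the raw data pointers of a, b, and c.
builder.apply(lib_path, [a, b, c], device=0)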
70 changes: 39 additions & 31 deletions examples/python/gemm/main.py
@@ -10,18 +10,20 @@
from torch import Tensor


def run_unittest(a: Tensor,
b: Tensor,
c: Tensor,
M: int,
N: int,
K: int,
TM: int,
TN: int,
kChunkK: int,
warp_layout: Tuple,
epsilon: float = 5e-2,
debug_print=False):
def run_unittest(
a: Tensor,
b: Tensor,
c: Tensor,
M: int,
N: int,
K: int,
TM: int,
TN: int,
kChunkK: int,
warp_layout: Tuple,
epsilon: float = 5e-2,
debug_print=False
):
gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout)
ref_c = a @ b.t()

@@ -33,10 +35,7 @@ def run_unittest(a: Tensor,
print(ref_c)

avg_diff = (torch.sum(torch.abs(ref_c - c)) / (M * N)).item()
if avg_diff > epsilon:
return False
else:
return True
return not avg_diff > epsilon


def run_test(
@@ -60,7 +59,7 @@ def run_test(

for _ in range(5): # warm up
gemm_func(a, b, c, M, N, K, TM, TN, kChunkK, *warp_layout)
ref_c = a @ b.t()
a @ b.t()

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
@@ -75,8 +74,8 @@
time1 = start_event.elapsed_time(end_event) / iters

start_event.record()
for i in range(iters):
ref_c = a @ b.t()
for _ in range(iters):
a @ b.t()
end_event.record()
torch.cuda.synchronize()

@@ -89,29 +88,38 @@
N = 4096
K = 4096

print(("Whole Shape\tBlock Shape\tthreads"
"\ttilefusion(ms)\tcublass(ms)\tRatio"))
print((
"Whole Shape\tBlock Shape\tthreads"
"\ttilefusion(ms)\tcublass(ms)\tRatio"
))

warp_layout = (1, 2)
threads = warp_layout[0] * warp_layout[1] * 32
for TM in [64, 128]:
for TN in [64, 128]:
for kChunkK in [32, 64, 128]:
time1, time2 = run_test(M, N, K, TM, TN, kChunkK, warp_layout)
print(("[{}, {}, {}]\t[{}, {}, {}]"
"\t{}\t{:.4f}\t{:.4f}\t{:.3f}").format(
M, N, K, TM, TN, kChunkK, threads, time1, time2,
time1 / time2))
print((
"[{}, {}, {}]\t[{}, {}, {}]"
"\t{}\t{:.4f}\t{:.4f}\t{:.3f}"
).format(
M, N, K, TM, TN, kChunkK, threads, time1, time2,
time1 / time2
))

for warp_layout in [(2, 2), (2, 4)]:
threads = warp_layout[0] * warp_layout[1] * 32

for TM in [64, 128, 256]:
for TN in [64, 128, 256]:
for kChunkK in [32, 64, 128]:
time1, time2 = run_test(M, N, K, TM, TN, kChunkK,
warp_layout)
print(("[{}, {}, {}]\t[{}, {}, {}]"
"\t{}\t{:.4f}\t{:.4f}\t{:.3f}").format(
M, N, K, TM, TN, kChunkK, threads, time1, time2,
time1 / time2))
time1, time2 = run_test(
M, N, K, TM, TN, kChunkK, warp_layout
)
print((
"[{}, {}, {}]\t[{}, {}, {}]"
"\t{}\t{:.4f}\t{:.4f}\t{:.3f}"
).format(
M, N, K, TM, TN, kChunkK, threads, time1, time2,
time1 / time2
))
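
As a side note, the reported milliseconds translate to throughput via the usual GEMM operation count; a small helper for reference (not part of this commit):

def tflops(M: int, N: int, K: int, ms: float) -> float:
    # An M x K by K x N GEMM performs 2 * M * N * K floating-point operations.
    return 2.0 * M * N * K / (ms * 1e-3) / 1e12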
2 changes: 0 additions & 2 deletions include/cell/copy/constants.hpp
@@ -14,8 +14,6 @@ enum class CopyInst {
};

enum class WarpReuse {
// TODO(haruhi): It seems that Cir/RowReuseCir/ColReuseCir are not ncessary,
// thus the reuse mode can be simplified.
// data are evenly partitioned to be loaded by warps.
kCont = 0, // all warps continuously load data, no reuse
kRowReuseCont = 1, // Row-wise even reuse, warps in the same row
6 changes: 3 additions & 3 deletions include/cell/copy/copy_atom.hpp
@@ -18,8 +18,8 @@ namespace tl = tile_layout;
using namespace cute;

template <typename Element>
requires std::is_same_v<Element, __half> ||
std::is_same_v<Element, cutlass::half_t>
requires std::is_same_v<Element, __half> ||
std::is_same_v<Element, cutlass::half_t>
struct LoadMatBase {
using DType = Element;
using ThreadLayout = tile_layout::ColMajor<16, 2>;
@@ -298,7 +298,7 @@ template <class Global, class Shared>
struct GlobalToSharedBaseTileLoader<Global, Shared, tl::Layout::kRowMajor> {
using DType = Shared::DType;

// NOTE: Please keep this thread layout striclty consistent with the thread
// NOTE: Please keep this thread layout strictly consistent with the thread
// layout for ldmatrix.
// The macro kernel breaks down the entire copy operation into iterations
// over 16x16 BaseTiles. To transfer a single BaseTile, threads in a warp
2 changes: 1 addition & 1 deletion include/cell/copy/shared_to_register.hpp
@@ -234,7 +234,7 @@ struct SharedToRegLoader {
};

/// @brief partial specialization for 16x16x16 wmma's output, and st.shared.f32
/// to revert the data distrubution into an comphrehensive row-major
/// to revert the data distribution into an comprehensive row-major
/// matrix.
template <typename Reg_, typename WarpLayout_>
struct RegToSharedStorer {
18 changes: 7 additions & 11 deletions include/cell/copy/warp.hpp
Original file line number Diff line number Diff line change
@@ -56,11 +56,9 @@ struct WarpOffsetHelper<WarpReuse::kRowReuseCont, kRowStride_, kColStride_> {
};
} // namespace

/*
* @brief In a thread block, warps are organized as 2-D matrices, each with
* a row index and a column index. Given `threadIdx.x`, this function
* calculates the row index of the current thread.
*/
// @brief In a thread block, warps are organized as 2-D matrices, each with
// a row index and a column index. Given `threadIdx.x`, this function
// calculates the row index of the current thread.
template <typename WarpLayout>
DEVICE int warp_row_id() {
/*
@@ -91,11 +89,9 @@ DEVICE int warp_row_id() {
}
}

/*
* @brief In a thread block, warps are organized as 2-D matrices, each with
* a row index and a column index. Given `threadIdx.x`, this function
* calculates the column index of the current thread.
*/
// @brief In a thread block, warps are organized as 2-D matrices, each with
// a row index and a column index. Given `threadIdx.x`, this function
// calculates the column index of the current thread.
template <typename WarpLayout>
DEVICE int warp_col_id() {
/*
@@ -179,7 +175,7 @@ struct ExecCounter {
static constexpr int kColExec = col_exec_count();
};

/// @brief Determine the automic shape of a single warp based on the shape of
/// @brief Determine the automatic shape of a single warp based on the shape of
/// the entire tile. The final warp tile shape is multiple of this atomic
/// shape.
template <typename DType, typename TileLayout, const tl::Layout kType>
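
The warp-indexing convention that warp_row_id() and warp_col_id() document can be pictured with a small sketch (illustrative only; it assumes 32 threads per warp and a row-major placement of warps, which may differ from the exact WarpLayout convention):

WARP_SIZE = 32


def warp_row_col(thread_idx: int, warps_per_col: int) -> tuple:
    # Map a linear thread index within a block to a (row, col) warp coordinate.
    warp_id = thread_idx // WARP_SIZE
    # Row-major placement: consecutive warps fill a row before moving to the next.
    return warp_id // warps_per_col, warp_id % warps_per_col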
4 changes: 2 additions & 2 deletions include/types/layout.hpp
@@ -330,7 +330,7 @@ using RowMajor = MatrixLayout<kRow, kCol, kStride, 1>;
template <const int kRow, const int kCol, const int kStride = kRow>
using ColMajor = MatrixLayout<kRow, kCol, 1, kStride>;

/// @brief: Wapper for creating non-swizzled or swizzled shared memory layout.
/// @brief: Wrapper for creating non-swizzled or swizzled shared memory layout.
template <typename Shared, const int kBitsPerAccess>
struct SharedLayoutWrapper {
using Layout =
@@ -356,7 +356,7 @@ static constexpr size_t get_numel = Layout::kNumel;
// We wrap CuTe's `Layout`, which consists of `Shape` and `Stride`, into an
// intelligent row-major or column-major layout. In a row-major layout, the
// column stride is 1, whereas in a column-major layout, the row stride is 1.
// NOTE: A potential issue is that `ColMajor<1, 1>` will also be indentified as
// NOTE: A potential issue is that `ColMajor<1, 1>` will also be identified as
// a row-major layout.
template <typename Layout_>
static constexpr Layout layout_type = Layout_::kType;
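
The stride convention described in the comments above reduces to a single indexing rule; a minimal Python sketch mirroring a MatrixLayout-style (row stride, column stride) pair:

def offset(row: int, col: int, row_stride: int, col_stride: int) -> int:
    # Row-major layouts have col_stride == 1; column-major layouts have row_stride == 1.
    return row * row_stride + col * col_stride


# ColMajor<1, 1> has both strides equal to 1, so it cannot be told apart from a
# 1x1 row-major layout -- the caveat raised in the NOTE above.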
