Skip to content

Commit

Permalink
Refactor(cell): Enhance the implementation of warp tile offset calcul…
Browse files Browse the repository at this point in the history
…ation. (microsoft#22)

* fix global offset calculation.

* reduce redundant code.

* clean up the implementation.
  • Loading branch information
lcy-seso authored Dec 25, 2024
1 parent 1afb9e5 commit 6ed8859
Show file tree
Hide file tree
Showing 15 changed files with 345 additions and 500 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@ repos:
rev: v0.32.0
hooks:
- id: yapf
additional_dependencies: [toml]
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
Expand Down
7 changes: 3 additions & 4 deletions examples/python/gemm/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import os
import ctypes
import importlib.util
import shutil
import os
import subprocess
from collections import defaultdict

import subprocess
import ctypes
import torch

__all__ = [
Expand Down
3 changes: 1 addition & 2 deletions examples/python/gemm/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@
# --------------------------------------------------------------------------

import torch
from torch import Tensor

from compile import Compile
from torch import Tensor

__all__ = [
"gemm_func",
Expand Down
4 changes: 2 additions & 2 deletions examples/python/gemm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import torch
from torch import Tensor
from typing import Tuple

import torch
from gemm import gemm_func
from torch import Tensor


def run_unittest(a: Tensor,
Expand Down
9 changes: 2 additions & 7 deletions include/cell/copy/constants.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "config.hpp"

namespace tilefusion::cell::copy {

enum class CopyInst {
kLoadMat = 0, // ldmatrix for loading data from shared memory to register.
kStoreMat = 1, // stmatrix for storing data from register to shared memory.
Expand All @@ -19,13 +18,9 @@ enum class WarpReuse {
// thus the reuse mode can be simplified.
// data are evenly partitioned to be loaded by warps.
kCont = 0, // all warps continuously load data, no reuse
kCir = 1, // all warps circularly load data, no reuse
kRowReuseCont = 2, // Row-wise even reuse, warps in the same row
kRowReuseCont = 1, // Row-wise even reuse, warps in the same row
// repeatedly load the same data
kRowReuseCir = 3, // Row-wise circular reuse
kColReuseCont = 4, // Column-wise even reuse, warps in the same column
kColReuseCont = 2 // Column-wise even reuse, warps in the same column
// repeatedly load the same data
kColReuseCir = 5 // Column-wise circular reuse
};

} // namespace tilefusion::cell::copy
59 changes: 27 additions & 32 deletions include/cell/copy/global_to_register.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -303,36 +303,34 @@ struct RegToGlobalStorerImpl<Global_, Reg_, kRowExec_, kColExec_,
* @tparam kMode_ Warp reuse mode.
* @tparam Base Copy base.
*/
template <typename Reg_, typename WarpLayout_, const WarpReuse kMode_,
typename Base = warp::CopyBase<WarpLayout_, kMode_>>
struct GlobalToRegLoader : public Base {
template <typename Reg_, typename WarpLayout_, const WarpReuse kMode_>
struct GlobalToRegLoader {
using Reg = Reg_;
using DType = typename Reg::DType::DType;
using BaseShape = BaseTileShape<DType>;

using WarpLayout = WarpLayout_;
static constexpr WarpReuse kMode = kMode_;

// how many times a `BaseTile` is executed along the row and column
// direction.
static constexpr int kRowExec = Reg::kRows;
static constexpr int kColExec = Reg::kCols;

template <typename Global>
DEVICE void operator()(const Global& src, Reg& dst) {
const DType* src_ptr = src.data();

// 1. advance the pointer to input data to the current warp
// according to warp reuse mode.
src_ptr += Base::template get_warp_offset<Global>();

// how many times a `BaseTile` is executed along the row and column
// direction.
static constexpr int kRowExec =
Base::template row_exec_count<BaseShape, Global::kRows>();
static constexpr int kColExec =
Base::template col_exec_count<BaseShape, Global::kCols>();
int offset = global_offset_.template get_warp_offset<Global>();

using Loader = GlobalToRegLoaderImpl<Global, Reg, kRowExec, kColExec,
Global::kType>;
Loader loader;
loader(src_ptr, dst);
loader(src.data() + offset, dst);
}

private:
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, WarpReuse::kCont>;

GlobalOffset global_offset_;
};

/**
Expand All @@ -342,37 +340,34 @@ struct GlobalToRegLoader : public Base {
* @tparam Reg_ Register tile type.
* @tparam WarpLayout_ Warp layout type.
* @tparam kMode_ Warp reuse mode.
* @tparam Base Copy base.
*/
template <typename Global_, typename Reg_, typename WarpLayout_,
typename Base = warp::CopyBase<WarpLayout_, WarpReuse::kCont>>
struct RegToGlobalStorer : public Base {
template <typename Global_, typename Reg_, typename WarpLayout_>
struct RegToGlobalStorer {
using Global = Global_;
using Reg = Reg_;
using DType = typename Global::DType;
using BaseShape = BaseTileShape<DType>;

using WarpLayout = WarpLayout_;

// how many times a `BaseTile` is executed along the row and column
// direction.
static constexpr int kRowExec = Reg::kRows;
static constexpr int kColExec = Reg::kCols;

DEVICE void operator()(const Reg& src, Global& dst) {
DType* dst_ptr = dst.mutable_data();

// 1. advance the pointer to output data to the current warp
// according to warp reuse mode.
dst_ptr += Base::template get_warp_offset<Global>();

// how many times a `BaseTile` is executed along the row and column
// direction.
static constexpr int kRowExec =
Base::template row_exec_count<BaseShape, Global::kRows>();
static constexpr int kColExec =
Base::template col_exec_count<BaseShape, Global::kCols>();
int offset = global_offset_.template get_warp_offset<Global>();

using Storer = RegToGlobalStorerImpl<Global, Reg, kRowExec, kColExec,
Global::kType>;
Storer storer;
storer(src, dst_ptr);
storer(src, dst_ptr + offset);
}
};

using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, WarpReuse::kCont>;

GlobalOffset global_offset_;
};
} // namespace tilefusion::cell::copy
86 changes: 46 additions & 40 deletions include/cell/copy/global_to_shared.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -253,44 +253,47 @@ struct SharedToGlobalStorerImpl<Shared_, Global_, kRowExec_, kColExec_,
/// @brief The thread-block level API that cooperatively transfers a data tile
/// from global memory to shared memory by all the threads within a
/// thread block.
template <typename Shared_, typename WarpLayout_,
typename Base = warp::CopyBase<WarpLayout_, WarpReuse::kCont>>
struct GlobalToSharedLoader : public Base {
template <typename Shared_, typename WarpLayout_>
struct GlobalToSharedLoader {
using Shared = Shared_;
using DType = Shared::DType;
using WarpLayout = WarpLayout_;

// TODO(ying): The atomic tile shape that a single warp loads. The atomic
// tile shape should be automatically determined to choose the best
// performance.
using BaseShape = traits::BaseTileShape<DType>;
// FIXME(ying): automatically infer the warp-level tile shape instead
// of using a fixed `BaseShape`.
// using WarpShape =
// warp::WarpTileShape<DType, typename Shared::Layout, Shared::kType>;

using WarpShape = traits::BaseTileShape<DType>;
static_assert(Shared::kRows % WarpShape::kRows == 0,
"Shared::kRows must be divisible by WarpShape::kRows.");
static_assert(Shared::kCols % WarpShape::kCols == 0,
"Shared::kCols must be divisible by WarpShape::kCols.");

static_assert((Shared::kSwizzled && sizeof(DType) == 2 ||
Shared::kSwizzled == false),
"Not implemented for swizzled layout with 4-byte data types "
"(single precision floating point).");
static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode
using ExecCounter = warp::ExecCounter<WarpShape, Shared, WarpLayout, kMode>;
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode>;
using SharedOffset =
warp::SharedOffsetHelper<WarpLayout, WarpShape, Shared, kMode>;

static constexpr int kRowExec =
Base::template row_exec_count<BaseShape, Shared::kRows>();
static constexpr int kColExec =
Base::template col_exec_count<BaseShape, Shared::kCols>();
static constexpr int kRowExec = ExecCounter::kRowExec;
static constexpr int kColExec = ExecCounter::kColExec;

static_assert(kRowExec && kColExec,
"Execution count should be greater than 0.");
"Ensure that the execution count for all "
"rows and columns is greater than 0.");

template <typename Global>
DEVICE void operator()(const Global& src, Shared& dst) {
static_assert(Shared::kNumel == Global::kNumel,
"Global and shared memory should have the same shape.");
static_assert(
Global::kRows == Shared::kRows && Global::kCols == Shared::kCols,
"Global and shared memory should have the same shape.");

const DType* src_ptr = src.data();
DType* dst_ptr = dst.mutable_data();

int offset_src = Base::template get_warp_offset<Global>(); // global
int offset_dst = offset_helper_.get_warp_offset(); // shared
int offset_src = global_offset_.template get_warp_offset<Global>();
int offset_dst = shared_offset_.get_warp_offset();

using Loader = GlobalToSharedLoaderImpl<Global, Shared, kRowExec,
kColExec, Shared::kType>;
Expand All @@ -300,28 +303,32 @@ struct GlobalToSharedLoader : public Base {
}

private:
using OffsetHelper =
warp::SharedOffsetHelper<WarpLayout, WarpReuse::kCont, Shared>;
OffsetHelper offset_helper_;
GlobalOffset global_offset_;
SharedOffset shared_offset_;
};

template <typename Shared_, typename WarpLayout_,
typename Base = warp::CopyBase<WarpLayout_, WarpReuse::kCont>>
struct SharedToGlobalStorer : public Base {
template <typename Shared_, typename WarpLayout_>
struct SharedToGlobalStorer {
using Shared = Shared_;
using DType = Shared::DType;
using WarpLayout = WarpLayout_;
using BaseShape = traits::BaseTileShape<DType>;

static_assert(Shared::kRows % BaseShape::kRows == 0,
"Shared::kRows must be divisible by BaseShape::kRows.");
static_assert(Shared::kCols % BaseShape::kCols == 0,
"Shared::kCols must be divisible by BaseShape::kCols.");
// FIXME(ying): automatically infer the warp-level tile shape instead
// of using a fixed `BaseShape`.
using WarpShape = traits::BaseTileShape<DType>;
static_assert(Shared::kRows % WarpShape::kRows == 0,
"Shared::kRows must be divisible by WarpShape::kRows.");
static_assert(Shared::kCols % WarpShape::kCols == 0,
"Shared::kCols must be divisible by WarpShape::kCols.");

static const WarpReuse kMode = WarpReuse::kCont; // warp reuse mode
using SharedOffset =
warp::SharedOffsetHelper<WarpLayout, WarpShape, Shared, kMode>;
using GlobalOffset = warp::GlobalOffsetHelper<WarpLayout, kMode>;
using ExecCounter = warp::ExecCounter<WarpShape, Shared, WarpLayout, kMode>;

static constexpr int kRowExec =
Shared::kRows / BaseShape::kRows / tl::num_rows<WarpLayout>;
static constexpr int kColExec =
Shared::kCols / BaseShape::kCols / tl::num_cols<WarpLayout>;
static constexpr int kRowExec = ExecCounter::kRowExec;
static constexpr int kColExec = ExecCounter::kColExec;

static_assert(kRowExec && kColExec,
"Execution count should be greater than 0.");
Expand All @@ -331,8 +338,8 @@ struct SharedToGlobalStorer : public Base {
const DType* src = src_.data();
DType* dst = dst_.mutable_data();

int offset_src = offset_helper_.get_warp_offset(); // shared
int offset_dst = Base::template get_warp_offset<Global>(); // global
int offset_src = shared_offset_.get_warp_offset();
int offset_dst = global_offset_.template get_warp_offset<Global>();

using Storer = SharedToGlobalStorerImpl<Shared, Global, kRowExec,
kColExec, Shared::kType>;
Expand All @@ -342,8 +349,7 @@ struct SharedToGlobalStorer : public Base {
}

private:
using OffsetHelper =
warp::SharedOffsetHelper<WarpLayout, WarpReuse::kCont, Shared>;
OffsetHelper offset_helper_;
SharedOffset shared_offset_;
GlobalOffset global_offset_;
};
} // namespace tilefusion::cell::copy
Loading

0 comments on commit 6ed8859

Please sign in to comment.