Skip to content

Commit

Permalink
squash-patch changes
Browse files Browse the repository at this point in the history
move heuristic into C++ code

fix unit tests + format

update for 3.5.1

remove custom scheduler

codespell

cleanup comment

cleanup diff

review comments

review comments

review comment changes

review comments

fix codespell

cleanup util logic

make dim names for prepack layout more canonical

missed refactor

wip

interleaving + recasting

tweak tolerances

comments plus interleaving

format

codespell

review comments

end2end first pass

separate out kernels, format

add machete as a gptq backend

update to use  ModelWeightParameter

formatting

update parameter.py

refactor permute layout

wip
  • Loading branch information
LucasWilkinson committed Aug 20, 2024
1 parent 6e4658c commit ab9d0c0
Show file tree
Hide file tree
Showing 29 changed files with 978 additions and 571 deletions.
14 changes: 6 additions & 8 deletions csrc/cutlass_extensions/cute_utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ CUTE_HOST_DEVICE static constexpr bool is_identity_layout() {
else {
constexpr auto coalesced_layout = coalesce(Layout{});
if constexpr (rank(coalesced_layout) == 1 &&
stride<0>(coalesced_layout) == 1) {
stride<0>(coalesced_layout) == 1)
return true;
}
return false;
}
}
Expand All @@ -52,17 +51,16 @@ static constexpr auto get_logical_ptr(PointerType* ptr) {
template <typename T, typename Elements>
CUTE_HOST_DEVICE static constexpr auto create_auto_vectorizing_copy() {
constexpr auto bits = sizeof_bits_v<T> * Elements{};
if constexpr (bits % 128 == 0) {
if constexpr (bits % 128 == 0)
return AutoVectorizingCopyWithAssumedAlignment<128>{};
} else if constexpr (bits % 64 == 0) {
else if constexpr (bits % 64 == 0)
return AutoVectorizingCopyWithAssumedAlignment<64>{};
} else if constexpr (bits % 32 == 0) {
else if constexpr (bits % 32 == 0)
return AutoVectorizingCopyWithAssumedAlignment<32>{};
} else if constexpr (bits % 16 == 0) {
else if constexpr (bits % 16 == 0)
return AutoVectorizingCopyWithAssumedAlignment<16>{};
} else {
else
return AutoVectorizingCopyWithAssumedAlignment<8>{};
}
}

}; // namespace cute
7 changes: 3 additions & 4 deletions csrc/cutlass_extensions/torch_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ namespace detail {
template <class T, class F, class G, int... I>
CUTE_HOST_DEVICE constexpr auto tapply_with_idx(T&& t, F&& f, G&& g,
seq<I...>) {
return g(f(cute::get<I>(static_cast<T&&>(t)), I)...);
return g(f(get<I>(static_cast<T&&>(t)), I)...);
}

template <class F, int... I>
Expand All @@ -29,7 +29,7 @@ CUTE_HOST_DEVICE constexpr auto make_shape_from_idx(F&& f, seq<I...>) {

template <class T, class F>
CUTE_HOST_DEVICE constexpr auto transform_with_idx(T const& t, F&& f) {
if constexpr (cute::is_tuple<T>::value) {
if constexpr (is_tuple<T>::value) {
return detail::tapply_with_idx(
t, f, [](auto const&... a) { return cute::make_tuple(a...); },
tuple_seq<T>{});
Expand Down Expand Up @@ -72,9 +72,8 @@ static inline auto make_cute_layout(torch::Tensor const& tensor,
}
} else {
// Extra strides are assumed to be 0 or 1
if constexpr (cute::is_static_v<StrideEle>) {
if constexpr (cute::is_static_v<StrideEle>)
static_assert(StrideEle::value == 0 || StrideEle::value == 1);
}
return StrideEle{};
}
});
Expand Down
2 changes: 1 addition & 1 deletion csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ struct NumericArrayConverter<cutlass::bfloat16_t, vllm_uint4b8_t, N, Round> {
// Below constructs the following temporary:
uint32_t const prmt_indices[4] = {0xF4F0, 0xF5F1, 0xF6F2, 0xF7F3};
static_assert(RegArray::kElements <= 4,
"Too many inputs for uint4b8_t -> BF16 vector converter");
"Too many inputs for BF16 -> I4 vector converter");
CUTLASS_PRAGMA_UNROLL
for (int ii = 0; ii < RegArray::kElements; ++ii) {
asm volatile(
Expand Down
4 changes: 2 additions & 2 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ namespace machete {
std::vector<std::string> supported_schedules(
vllm::ScalarTypeTorchPtr const& btype);

torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
torch::Tensor gemm(torch::Tensor const A, torch::Tensor const B,
vllm::ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
Expand All @@ -97,7 +97,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule);

torch::Tensor prepack_B(torch::Tensor const& B,
torch::Tensor prepack_B(torch::Tensor const B,
vllm::ScalarTypeTorchPtr const& btype);

}; // namespace machete
Expand Down
2 changes: 0 additions & 2 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ def create_sources(impl_config: ImplConfig, num_impl_files=2):


def generate():
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
# about how this works
SCRIPT_DIR = os.path.dirname(__file__)

schedules = [
Expand Down
Loading

0 comments on commit ab9d0c0

Please sign in to comment.