diff --git a/csrc/causal_conv1d_bwd.cu b/csrc/causal_conv1d_bwd.cu index 6660975..2c37da6 100644 --- a/csrc/causal_conv1d_bwd.cu +++ b/csrc/causal_conv1d_bwd.cu @@ -49,9 +49,9 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) { constexpr int kWidth = Ktraits::kWidth; constexpr int kNThreads = Ktraits::kNThreads; constexpr bool kSiluAct = Ktraits::kSiluAct; - constexpr int kNElts = Ktraits::kNElts; + static constexpr int kNElts = Ktraits::kNElts; constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; using input_t = typename Ktraits::input_t; using vec_t = typename Ktraits::vec_t; using weight_t = typename Ktraits::weight_t; diff --git a/csrc/causal_conv1d_fwd.cu b/csrc/causal_conv1d_fwd.cu index 74a1459..642b258 100644 --- a/csrc/causal_conv1d_fwd.cu +++ b/csrc/causal_conv1d_fwd.cu @@ -42,7 +42,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNElts = Ktraits::kNElts; - constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; using input_t = typename Ktraits::input_t; using vec_t = typename Ktraits::vec_t; using weight_t = typename Ktraits::weight_t; diff --git a/setup.py b/setup.py index 84f5640..23b6625 100644 --- a/setup.py +++ b/setup.py @@ -3,6 +3,7 @@ import warnings import os import re +import shutil import ast from pathlib import Path from packaging.version import parse, Version @@ -147,7 +148,7 @@ def append_nvcc_threads(nvcc_extra_args): + cc_flag ), }, - include_dirs=[this_dir], + include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"], ) ) @@ -216,7 +217,7 @@ def run(self): wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) + shutil.move(wheel_filename, wheel_path) except urllib.error.HTTPError: print("Precompiled wheel not found. Building from source...") # If the wheel could not be downloaded, build from source