Skip to content

Commit

Permalink
Fix some build issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 8, 2023
1 parent 84930ce commit ef3efba
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 5 deletions.
4 changes: 2 additions & 2 deletions csrc/causal_conv1d_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr bool kSiluAct = Ktraits::kSiluAct;
constexpr int kNElts = Ktraits::kNElts;
static constexpr int kNElts = Ktraits::kNElts;
constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
Expand Down
2 changes: 1 addition & 1 deletion csrc/causal_conv1d_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
constexpr int kWidth = Ktraits::kWidth;
constexpr int kNThreads = Ktraits::kNThreads;
constexpr int kNElts = Ktraits::kNElts;
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
using input_t = typename Ktraits::input_t;
using vec_t = typename Ktraits::vec_t;
using weight_t = typename Ktraits::weight_t;
Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import warnings
import os
import re
import shutil
import ast
from pathlib import Path
from packaging.version import parse, Version
Expand Down Expand Up @@ -147,7 +148,7 @@ def append_nvcc_threads(nvcc_extra_args):
+ cc_flag
),
},
include_dirs=[this_dir],
include_dirs=[Path(this_dir) / "csrc" / "causal_conv1d"],
)
)

Expand Down Expand Up @@ -216,7 +217,7 @@ def run(self):

wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
shutil.move(wheel_filename, wheel_path)
except urllib.error.HTTPError:
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
Expand Down

0 comments on commit ef3efba

Please sign in to comment.