-
Notifications
You must be signed in to change notification settings - Fork 11
/
Copy pathpackage_specific.sh
62 lines (52 loc) · 2.2 KB
/
package_specific.sh
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
#!/bin/bash
# Applies per-package build tweaks (source patches, extra pip installs, CUB
# headers, exported environment variables), selected by REPO, OS,
# COMPUTE_PLATFORM and TORCH_VERSION.
#
# Required environment: REPO (GitHub "owner/name" of the package).
# Read only for matching packages: OS, COMPUTE_PLATFORM, TORCH_VERSION,
# Python_ROOT_DIR, GITHUB_ENV.
set -eu -o pipefail
# Absolute directory containing this script. ${BASH_SOURCE%/*} is not enough:
# when the script is invoked without a directory component (e.g.
# `bash package_specific.sh`) it yields the file name itself, which breaks every
# "$SCRIPT_DIR"/package_specific/... path below.
SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" > /dev/null && pwd)
# Mask2Former CPU-only builds: apply package_specific/Mask2Former_cpu.patch
# to the current directory (paths taken as-is, -p0).
if [[ $REPO == "facebookresearch/Mask2Former" && $COMPUTE_PLATFORM == "cpu" ]]; then
  patch -p0 < "$SCRIPT_DIR/package_specific/Mask2Former_cpu.patch"
fi
# pytorch_cluster / fairseq on Windows with torch 1.12.x and cu116: patch the
# pybind11 headers bundled inside the installed torch package before building.
case $REPO in
  "rusty1s/pytorch_cluster" | "facebookresearch/fairseq")
    if [[ $OS == "Windows" && ${TORCH_VERSION:0:4} == "1.12" && $COMPUTE_PLATFORM == "cu116" ]]; then
      # Fixes https://github.com/facebookresearch/pytorch3d/issues/1024
      # shellcheck disable=SC2154
      TORCH_PYBIND_DIR="$Python_ROOT_DIR/lib/site-packages/torch/include/pybind11"
      patch -d "$TORCH_PYBIND_DIR" < "$SCRIPT_DIR/package_specific/torch_pybind_cast_h.patch"
    fi
    ;;
esac
# pytorch3d / fairseq: some OS + COMPUTE_PLATFORM combinations get a pinned
# NVIDIA CUB release. When a version is selected, the tarball is unpacked into
# ./cub and CUB_HOME is appended to $GITHUB_ENV for later workflow steps.
if [[ $REPO == "facebookresearch/pytorch3d" ]] || [[ $REPO == "facebookresearch/fairseq" ]]; then
  CUB_VERSION=""
  # Windows: pytorch3d only, on cu117 / cu118 / cu121.
  if [[ $OS == "Windows" ]] \
    && [[ $REPO == "facebookresearch/pytorch3d" ]] \
    && { [[ $COMPUTE_PLATFORM == "cu117" ]] || [[ $COMPUTE_PLATFORM == "cu118" ]] || [[ $COMPUTE_PLATFORM == "cu121" ]]; }; then
    CUB_VERSION="1.17.2"
  fi
  # Linux: either package, on cu102 / cu113.
  if [[ $OS == "Linux" ]] \
    && { [[ $COMPUTE_PLATFORM == "cu102" ]] || [[ $COMPUTE_PLATFORM == "cu113" ]]; }; then
    CUB_VERSION="1.10.0"
  fi
  if [[ -n $CUB_VERSION ]]; then
    # -p: do not abort (set -e) if ./cub is left over from a previous run.
    mkdir -p cub
    # -f: fail on HTTP errors instead of piping an HTML error page into tar
    # (which would surface as a misleading gzip/tar error).
    curl -fsSL "https://github.com/NVIDIA/cub/archive/${CUB_VERSION}.tar.gz" \
      | tar -xzf - --strip-components=1 --directory cub
    echo "CUB_HOME=$PWD/cub" >> "$GITHUB_ENV"
  fi
fi
# pytorch3d on Linux with cu102: apply package_specific/pytorch3d_cpp14.patch
# (presumably a C++14 compatibility fix, per its name — see the patch file).
if [[ $REPO == "facebookresearch/pytorch3d" && $OS == "Linux" && $COMPUTE_PLATFORM == "cu102" ]]; then
  patch -p0 < "$SCRIPT_DIR/package_specific/pytorch3d_cpp14.patch"
fi
# fairseq (any OS/platform): install Cython, then apply the fairseq CUB patch.
case $REPO in
  "facebookresearch/fairseq")
    pip install cython
    patch -p0 < "$SCRIPT_DIR/package_specific/fairseq_cub.patch"
    ;;
esac
# mmcv with torch exactly 1.12.1 on cu102: apply package_specific/mmcv_cpp14.patch.
if [[ $REPO == "open-mmlab/mmcv" && $TORCH_VERSION == "1.12.1" && $COMPUTE_PLATFORM == "cu102" ]]; then
  patch -p0 < "$SCRIPT_DIR/package_specific/mmcv_cpp14.patch"
fi
#######################################
# Convert a TORCH_CUDA_ARCH_LIST value (e.g. "7.5;8.6+PTX") into the dot-free
# form exported as TCNN_CUDA_ARCHITECTURES (e.g. "75;86"). Separators are kept.
# Arguments: $1 - architecture list (may be empty)
# Outputs:   converted list on stdout
#######################################
tcnn_cuda_architectures() {
  local archs=${1-}
  # Drop "+PTX" first, then the dots, so dot removal cannot assemble a new
  # "+PTX" out of the remaining characters.
  archs=${archs//+PTX/}
  printf '%s\n' "${archs//./}"
}

# tiny-cuda-nn: load the per-OS CUDA environment script, then export the CUDA
# stub library path and the architecture list through $GITHUB_ENV.
if [[ $REPO == "NVlabs/tiny-cuda-nn" ]]; then
  # NOTE(review): presumably this script defines TORCH_CUDA_ARCH_LIST (read
  # below) — confirm against .github/workflows/cuda/<OS>_env.sh.
  # shellcheck source=/dev/null
  source "$SCRIPT_DIR/.github/workflows/cuda/${OS}_env.sh"
  echo "LIBRARY_PATH=/usr/local/cuda/lib64/stubs" >> "$GITHUB_ENV"
  # Assigned separately so an unset TORCH_CUDA_ARCH_LIST still aborts under
  # set -u instead of silently exporting an empty value.
  TCNN_ARCHS=$(tcnn_cuda_architectures "$TORCH_CUDA_ARCH_LIST")
  echo "TCNN_CUDA_ARCHITECTURES=${TCNN_ARCHS}" >> "$GITHUB_ENV"
fi