Skip to content

Commit

Permalink
[Feature] Compiled and cudagraph for policies
Browse files Browse the repository at this point in the history
ghstack-source-id: aab4403c9dcc4f0692f48304d1781ac7ac9e6497
Pull Request resolved: #2478
  • Loading branch information
vmoens committed Oct 11, 2024
1 parent 205b83d commit c1c2e84
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 145 deletions.
17 changes: 12 additions & 5 deletions .github/unittest/linux/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
set -euxo pipefail
set -v

# ==================================================================================== #
# ================================ Setup env ========================================= #
# =============================================================================== #
# ================================ Init ========================================= #


if [[ $OSTYPE != 'darwin'* ]]; then
Expand All @@ -31,6 +31,10 @@ if [[ $OSTYPE != 'darwin'* ]]; then
cp $this_dir/10_nvidia.json /usr/share/glvnd/egl_vendor.d/10_nvidia.json
fi


# ==================================================================================== #
# ================================ Setup env ========================================= #

# Avoid error: "fatal: unsafe repository"
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"
Expand Down Expand Up @@ -61,7 +65,7 @@ if [ ! -d "${env_dir}" ]; then
fi
conda activate "${env_dir}"

# 4. Install Conda dependencies
# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
cat "${this_dir}/environment.yml"
Expand Down Expand Up @@ -185,7 +189,9 @@ fi

export PYTORCH_TEST_WITH_SLOW='1'
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
## Avoid error: "fatal: unsafe repository"
#git config --global --add safe.directory '*'
#root_dir="$(git rev-parse --show-toplevel)"

# solves ImportError: /lib64/libstdc++.so.6: version `GLIBCXX_3.4.21' not found
#export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$lib_dir
Expand All @@ -202,7 +208,8 @@ if [ "${CU_VERSION:-}" != cpu ] ; then
--timeout=120 --mp_fork_if_no_cuda
else
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py --ignore test/test_distributed.py \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
--ignore test/test_distributed.py \
--timeout=120 --mp_fork_if_no_cuda
fi

Expand Down
36 changes: 0 additions & 36 deletions .github/unittest/linux_optdeps/scripts/install.sh

This file was deleted.

126 changes: 121 additions & 5 deletions .github/unittest/linux_optdeps/scripts/run_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

set -euxo pipefail
set -v
set -e

# ==================================================================================== #
# ================================ Init ============================================== #
# =============================================================================== #
# ================================ Init ========================================= #


if [[ $OSTYPE != 'darwin'* ]]; then
Expand Down Expand Up @@ -35,18 +36,133 @@ fi
# ==================================================================================== #
# ================================ Setup env ========================================= #

bash ${this_dir}/setup_env.sh
# Avoid error: "fatal: unsafe repository"
# This has to run before the `git rev-parse` call below.
git config --global --add safe.directory '*'
# Top-level directory of the git checkout; every other path derives from it.
root_dir="$(git rev-parse --show-toplevel)"
# Miniconda install prefix, conda environment prefix and the environment's
# lib directory, all located inside the checkout.
conda_dir="${root_dir}/conda"
env_dir="${root_dir}/env"
lib_dir="${env_dir}/lib"

# Run the remaining steps from the repository root so that relative paths
# (e.g. ./miniconda.sh, conda/, test/) resolve there.
cd "${root_dir}"

# Pick the OS label that appears in the Miniconda installer file name:
# "MacOSX" on Darwin kernels, "Linux" on everything else.
if [[ "$(uname -s)" == Darwin* ]]; then
  os=MacOSX
else
  os=Linux
fi

# 1. Install conda at ./conda
# The download and install are skipped when ${conda_dir} already exists.
if [ ! -d "${conda_dir}" ]; then
  printf "* Installing conda\n"
  # Fetch the installer over HTTPS from the canonical Anaconda host: the file
  # is executed right below, so it must not travel over plain HTTP where it
  # could be tampered with in transit.
  wget -O miniconda.sh "https://repo.anaconda.com/miniconda/Miniconda3-latest-${os}-x86_64.sh"
  # -b: batch (non-interactive) mode, -f: do not fail if the prefix exists,
  # -p: install prefix.
  bash ./miniconda.sh -b -f -p "${conda_dir}"
fi
# Load conda's shell integration into this shell so `conda activate` works.
# The conda path is quoted so a checkout path with spaces does not split.
eval "$("${conda_dir}/bin/conda" shell.bash hook)"

# 2. Create test environment at ./env
# PYTHON_VERSION is passed as a printf argument rather than embedded in the
# format string, so a stray '%' or backslash in it cannot corrupt the output.
printf 'python: %s\n' "${PYTHON_VERSION}"
# Only create the environment when its prefix directory does not exist yet.
if [ ! -d "${env_dir}" ]; then
  printf "* Creating a test environment\n"
  conda create --prefix "${env_dir}" -y python="${PYTHON_VERSION}"
fi
conda activate "${env_dir}"

# 3. Install Conda dependencies
printf "* Installing dependencies (except PyTorch)\n"
# Append a python pin to the environment file before it is applied.
# NOTE(review): this modifies environment.yml in place; running the script
# twice on the same checkout appends the line twice - confirm that is fine.
echo " - python=${PYTHON_VERSION}" >> "${this_dir}/environment.yml"
# Print the final environment file so it shows up in the job log.
cat "${this_dir}/environment.yml"

# Upgrade pip itself before any package installation.
pip3 install pip --upgrade

# Apply the environment file to the active environment; per the conda docs,
# --prune removes installed packages that the file does not list.
conda env update --file "${this_dir}/environment.yml" --prune

# ============================================================================================ #
# ================================ PyTorch & TorchRL ========================================= #

bash ${this_dir}/install.sh
# Drop any PYTORCH_VERSION inherited from the calling environment.
unset PYTORCH_VERSION

#######################################
# Convert a CU_VERSION tag into a dotted CUDA version.
# Arguments: $1 - tag such as "cu92", "cu118" or "cu121"
# Outputs:   "<major>.<minor>" on stdout (e.g. "12.1"); diagnostics on stderr
# Returns:   0 on success, 1 if the tag is not "cu" followed by 2-3 digits
#######################################
cuda_version_from_cu_version() {
  local tag="${1:-}"
  local digits="${tag#cu}"
  if [[ "${tag}" != cu* || ! "${digits}" =~ ^[0-9]{2,3}$ ]]; then
    printf 'Unsupported CU_VERSION "%s": expected "cpu" or "cu" followed by 2-3 digits (e.g. cu121)\n' "${tag}" >&2
    return 1
  fi
  # The last digit is the minor version; everything before it is the major.
  printf '%s.%s\n' "${digits%?}" "${digits: -1}"
}

if [ "${CU_VERSION:-}" == cpu ] ; then
  version="cpu"
  echo "Using cpu build"
else
  # A missing or malformed CU_VERSION makes the helper return non-zero after
  # printing a diagnostic. This script enables `set -e` at the top, so the
  # failed assignment aborts the run here instead of continuing with an unset
  # CUDA_VERSION.
  CUDA_VERSION="$(cuda_version_from_cu_version "${CU_VERSION:-}")"
  echo "Using CUDA $CUDA_VERSION as determined by CU_VERSION ($CU_VERSION)"
  # CUDA_VERSION is already "<major>.<minor>", so no further trimming is needed.
  version="${CUDA_VERSION}"
fi

# submodules
# Run the two steps as separate commands: in `a && b` a failure of `a` is
# exempt from `set -e`, so a failed sync would be silently ignored and the
# script would carry on without updating the submodules.
git submodule sync
git submodule update --init --recursive

printf "Installing PyTorch with %s\n" "${CU_VERSION}"
# TORCH_VERSION selects the release channel ("nightly" or "stable").
# The wheel index path ends in the CU_VERSION tag itself ("cpu", "cu121", ...),
# so one URL per channel covers both the cpu and the CUDA builds.
case "${TORCH_VERSION:-}" in
  nightly)
    pip3 install --pre torch --index-url "https://download.pytorch.org/whl/nightly/${CU_VERSION}" -U
    ;;
  stable)
    pip3 install torch --index-url "https://download.pytorch.org/whl/${CU_VERSION}" -U
    ;;
  *)
    # Name the actual cause and report it on stderr with a trailing newline.
    printf 'Failed to install pytorch: TORCH_VERSION must be "nightly" or "stable", got "%s"\n' "${TORCH_VERSION:-}" >&2
    exit 1
    ;;
esac

# smoke test
# Abort early (the script runs under `set -e`) if functorch cannot be imported
# after the PyTorch install above.
python -c "import functorch"

## install snapshot
#if [[ "$TORCH_VERSION" == "nightly" ]]; then
# pip3 install git+https://github.com/pytorch/torchsnapshot
#else
# pip3 install torchsnapshot
#fi

# install tensordict
# RELEASE == 0 installs tensordict straight from its GitHub repository;
# any other value installs the published package.
case "$RELEASE" in
  0)
    pip3 install git+https://github.com/pytorch/tensordict.git
    ;;
  *)
    pip3 install tensordict
    ;;
esac

# Install torchrl from this checkout in development mode; the script changed
# into the repository root earlier, which is where setup.py is expected.
printf "* Installing torchrl\n"
python setup.py develop

# smoke test
# Abort early (via `set -e`) if the package cannot be imported.
python -c "import torchrl"

# ==================================================================================== #
# ================================ Run tests ========================================= #


bash ${this_dir}/run_test.sh
# find libstdc
# Keeps the first libstdc++.so.6 found under ./conda (relative to the cwd).
# NOTE(review): STDC_LOC is not referenced anywhere in the visible part of this
# script - confirm it is still needed.
STDC_LOC=$(find conda/ -name "libstdc++.so.6" | head -1)

# Environment switches exported for the test run.
# NOTE(review): their meaning is defined by the libraries/tests that read
# them, not by this script.
export PYTORCH_TEST_WITH_SLOW='1'
export LAZY_LEGACY_OP=False
# Print the PyTorch environment report into the job log.
python -m torch.utils.collect_env
# Avoid error: "fatal: unsafe repository"
# NOTE(review): this repeats the safe.directory / root_dir setup already done
# in the "Setup env" section of this script - presumably a leftover from the
# former standalone run_test.sh; confirm before removing.
git config --global --add safe.directory '*'
root_dir="$(git rev-parse --show-toplevel)"

export MKL_THREADING_LAYER=GNU
export CKPT_BACKEND=torch
export MAX_IDLE_COUNT=100
export BATCHED_PIPE_TIMEOUT=60

# Run pytest on the test/ directory through the coverage helper script.
# test_rlhf.py and test_distributed.py are excluded; --durations 200 reports
# the 200 slowest tests; --timeout=120 is presumably a per-test limit in
# seconds (pytest-timeout) - confirm.
python .github/unittest/helpers/coverage_run_parallel.py -m pytest test \
--instafail --durations 200 -vv --capture no --ignore test/test_rlhf.py \
--ignore test/test_distributed.py \
--timeout=120 --mp_fork_if_no_cuda

# Merge the coverage data files, then write the XML report
# (-i: ignore errors while reading source files).
coverage combine
coverage xml -i

# ==================================================================================== #
# ================================ Post-proc ========================================= #
Expand Down
25 changes: 0 additions & 25 deletions .github/unittest/linux_optdeps/scripts/run_test.sh

This file was deleted.

46 changes: 0 additions & 46 deletions .github/unittest/linux_optdeps/scripts/setup_env.sh

This file was deleted.

9 changes: 3 additions & 6 deletions .github/workflows/test-linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,8 @@ jobs:
tests-optdeps:
strategy:
matrix:
python_version: ["3.10"] # "3.9", "3.10", "3.11"
cuda_arch_version: ["12.1"] # "11.6", "11.7"
python_version: ["3.11"]
cuda_arch_version: ["12.1"]
fail-fast: false
uses: pytorch/test-infra/.github/workflows/linux_job.yml@main
with:
Expand All @@ -172,20 +172,17 @@ jobs:
# Commenting these out for now because the GPU tests are not working inside docker
export CUDA_ARCH_VERSION=${{ matrix.cuda_arch_version }}
export CU_VERSION="cu${CUDA_ARCH_VERSION:0:2}${CUDA_ARCH_VERSION:3:1}"
# Remove the following line when the GPU tests are working inside docker, and uncomment the above lines
#export CU_VERSION="cpu"
if [[ "${{ github.ref }}" =~ release/* ]]; then
export RELEASE=1
export TORCH_VERSION=stable
else
export RELEASE=0
export TORCH_VERSION=nightly
fi
export TD_GET_DEFAULTS_TO_NONE=1
echo "PYTHON_VERSION: $PYTHON_VERSION"
echo "CU_VERSION: $CU_VERSION"
export TD_GET_DEFAULTS_TO_NONE=1
## setup_env.sh
bash .github/unittest/linux_optdeps/scripts/run_all.sh
Expand Down
Loading

0 comments on commit c1c2e84

Please sign in to comment.