Skip to content

Commit

Permalink
Don't use manylinux docker
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Aug 12, 2023
1 parent 0cdad87 commit 157af69
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 17 deletions.
21 changes: 7 additions & 14 deletions .github/workflows/build_wheels.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,21 +36,14 @@ jobs:
name: Build Wheel
needs: setup_release
runs-on: ${{ matrix.os }}
container:
image: quay.io/pypa/manylinux2014_x86_64

strategy:
fail-fast: false
matrix:
os: [ubuntu-20.04]
# python-version: ['3.7', '3.8', '3.9', '3.10']
python-version: ['3.9']
# torch-version: ['1.12.1', '1.13.1', '2.0.1']
# cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.1']
# torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230613']
torch-version: ['1.12.1']
# cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0']
cuda-version: ['11.6.2', '11.7.1']
python-version: ['3.7', '3.8', '3.9', '3.10']
torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.0.dev20230613']
cuda-version: ['11.6.2', '11.7.1', '11.8.0', '12.1.0']
exclude:
# Pytorch >= 2.0 only supports Python >= 3.8
- torch-version: '2.0.1'
Expand Down Expand Up @@ -86,10 +79,10 @@ jobs:
echo "MATRIX_CUDA_VERSION=$(echo ${{ matrix.cuda-version }} | awk -F \. {'print $1 $2'})" >> $GITHUB_ENV
echo "MATRIX_TORCH_VERSION=$(echo ${{ matrix.torch-version }} | awk -F \. {'print $1 "." $2'})" >> $GITHUB_ENV
# - name: Free up disk space
# if: ${{ runner.os == 'Linux' }}
# run: |
# sudo rm -rf /usr/share/dotnet
- name: Free up disk space
if: ${{ runner.os == 'Linux' }}
run: |
sudo rm -rf /usr/share/dotnet
- name: Install CUDA ${{ matrix.cuda-version }}
if: ${{ matrix.cuda-version != 'cpu' }}
Expand Down
2 changes: 1 addition & 1 deletion flash_attn/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "2.0.8.post7"
__version__ = "2.0.8.post8"

from flash_attn.flash_attn_interface import flash_attn_func
from flash_attn.flash_attn_interface import flash_attn_kvpacked_func
Expand Down
7 changes: 5 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,12 +243,15 @@ def run(self):
raise_if_cuda_home_none("flash_attn")

# Determine the version numbers that will be used to determine the correct wheel
_, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
flash_version = get_package_version()
cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
cuda_version = f"{torch_cuda_version.major}{torch_cuda_version.minor}"
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"

# Determine wheel URL based on CUDA version, torch version, python version and OS
Expand Down

0 comments on commit 157af69

Please sign in to comment.