diff --git a/docs/index.md b/docs/index.md index da0a4c7a..8d510daf 100644 --- a/docs/index.md +++ b/docs/index.md @@ -24,33 +24,25 @@ Check out the [JAX installation guide](https://github.com/google/jax#pip-install ### Installation at HEAD -JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run: +JAX-Triton is developed at JAX and jaxlib HEAD and close to Triton HEAD. To get +a bleeding edge installation of JAX-Triton, run: + ```bash $ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git' ``` + This should install compatible versions of JAX and Triton. -JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release: -```bash -$ pip install jaxlib[cuda] -$ # or -$ pip install jaxlib[cuda11_pip] -$ # or -$ pip install jaxlib[cuda12_pip] -``` +JAX-Triton requires jaxlib with GPU support. You could install the latest stable +release via -If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly. -To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via: -```bash -$ pip install 'jaxlib @ ' -``` -or to install CUDA via pip automatically, you can do: ```bash -$ pip install 'jaxlib[cuda11_pip] @ ' -$ # or -$ pip install 'jaxlib[cuda12_pip] @ ' +$ pip install jaxlib[cuda12] ``` +In rare cases JAX-Triton might need a nighly version of jaxlib. You can install +it following the instructions +[here](https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation). ### Quickstart diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index f41e5d4c..ec71c434 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -537,7 +537,10 @@ def triton_kernel_call_lowering( named_args = dict(unsafe_zip(fn.arg_names, args)) if isinstance(fn, autotuner.Autotuner): - key_idxs = [fn.arg_names.index(k) for k in fn.keys] + if hasattr(fn, "key_idx"): + key_idxs = fn.key_idx # Triton <=3.2 + else: + key_idxs = [fn.arg_names.index(k) for k in fn.keys] if any(idx not in key_idxs for idx, _, _ in scalar_args): logging.warning( "Auto-tuning key does not include all scalar arguments. " diff --git a/jax_triton/version.py b/jax_triton/version.py index 38a4bf2d..bbe5e923 100644 --- a/jax_triton/version.py +++ b/jax_triton/version.py @@ -12,5 +12,5 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version_info__ = (0, 2, 0) +__version_info__ = (0, 3, 0) __version__ = ".".join(str(v) for v in __version_info__) diff --git a/pyproject.toml b/pyproject.toml index 0202e290..468459f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ readme = "README.md" requires-python = ">=3.10" dependencies = [ "absl-py>=1.4.0", - "jax>=0.4.31", - "triton>=3.0", + "jax>=0.4.34", + "triton>=3.1", ] [project.optional-dependencies]