diff --git a/docs/index.md b/docs/index.md
index da0a4c7a..8d510daf 100644
--- a/docs/index.md
+++ b/docs/index.md
@@ -24,33 +24,25 @@ Check out the [JAX installation guide](https://github.com/google/jax#pip-install
### Installation at HEAD
-JAX-Triton and Pallas are developed at JAX and Jaxlib HEAD and close to Triton HEAD. To get a bleeding edge installation of JAX-Triton, run:
+JAX-Triton is developed at JAX and jaxlib HEAD and close to Triton HEAD. To get
+a bleeding edge installation of JAX-Triton, run:
+
```bash
$ pip install 'jax-triton @ git+https://github.com/jax-ml/jax-triton.git'
```
+
This should install compatible versions of JAX and Triton.
-JAX-Triton does depend on Jaxlib but it's usually a more stable dependency. You might be able to get away with using a recent jaxlib release:
-```bash
-$ pip install jaxlib[cuda]
-$ # or
-$ pip install jaxlib[cuda11_pip]
-$ # or
-$ pip install jaxlib[cuda12_pip]
-```
+JAX-Triton requires jaxlib with GPU support. You could install the latest stable
+release via
-If you find there are issues with the latest Jaxlib release, you can try using a Jaxlib nightly.
-To install a new jaxlib, you can find a link to a [CUDA 11 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda_releases.html) or [CUDA 12 nightly](https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html). Then install it via:
-```bash
-$ pip install 'jaxlib @ '
-```
-or to install CUDA via pip automatically, you can do:
```bash
-$ pip install 'jaxlib[cuda11_pip] @ '
-$ # or
-$ pip install 'jaxlib[cuda12_pip] @ '
+$ pip install jaxlib[cuda12]
```
+In rare cases JAX-Triton might need a nighly version of jaxlib. You can install
+it following the instructions
+[here](https://jax.readthedocs.io/en/latest/installation.html#jax-nightly-installation).
### Quickstart
diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py
index f41e5d4c..ec71c434 100644
--- a/jax_triton/triton_lib.py
+++ b/jax_triton/triton_lib.py
@@ -537,7 +537,10 @@ def triton_kernel_call_lowering(
named_args = dict(unsafe_zip(fn.arg_names, args))
if isinstance(fn, autotuner.Autotuner):
- key_idxs = [fn.arg_names.index(k) for k in fn.keys]
+ if hasattr(fn, "key_idx"):
+ key_idxs = fn.key_idx # Triton <=3.2
+ else:
+ key_idxs = [fn.arg_names.index(k) for k in fn.keys]
if any(idx not in key_idxs for idx, _, _ in scalar_args):
logging.warning(
"Auto-tuning key does not include all scalar arguments. "
diff --git a/jax_triton/version.py b/jax_triton/version.py
index 38a4bf2d..bbe5e923 100644
--- a/jax_triton/version.py
+++ b/jax_triton/version.py
@@ -12,5 +12,5 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-__version_info__ = (0, 2, 0)
+__version_info__ = (0, 3, 0)
__version__ = ".".join(str(v) for v in __version_info__)
diff --git a/pyproject.toml b/pyproject.toml
index 0202e290..468459f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -6,8 +6,8 @@ readme = "README.md"
requires-python = ">=3.10"
dependencies = [
"absl-py>=1.4.0",
- "jax>=0.4.31",
- "triton>=3.0",
+ "jax>=0.4.34",
+ "triton>=3.1",
]
[project.optional-dependencies]