Installing JAX

This page describes how to install JAX with GPU support in a Python virtual environment.

Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs

interact -q gpu -g 1 -f ampere -m 20g -n 4

Here, -f requests a node feature (in this case, Ampere GPUs). The environment only needs to be built once, on an Ampere node.

Step 2: Once your session has started on a compute node, run nvidia-smi to verify the GPU and then load the appropriate modules

nvidia-smi
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn
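If you prefer to script the GPU check, nvidia-smi also has a query mode that prints only the fields you ask for. The sketch below assumes nvidia-smi is on your PATH on the compute node; the helper name gpu_names is hypothetical, and it returns an empty list when nvidia-smi is missing or exits with an error.

```python
import shutil
import subprocess


def gpu_names():
    """Return GPU model names reported by nvidia-smi (hypothetical helper).

    Returns an empty list if nvidia-smi is not on PATH or exits with an error.
    """
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=False,  # don't raise; treat a failure as "no GPUs visible"
    )
    if result.returncode != 0:
        return []
    # nvidia-smi prints one line per visible GPU in this mode.
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


if __name__ == "__main__":
    names = gpu_names()
    print(names if names else "No GPU visible to nvidia-smi")
```

On a node requested with -g 1 you would typically see a single Ampere-generation model listed.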

Step 3: Create and activate the virtual environment

python -m venv jax.venv
source jax.venv/bin/activate
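To confirm that the environment is active, that is, that python now resolves to jax.venv rather than a system or module Python, you can compare sys.prefix with sys.base_prefix; the venv module points the former at the environment directory. This is a minimal sketch, and in_virtualenv is a hypothetical helper name.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when running inside a venv (hypothetical helper)."""
    # Inside a venv, sys.prefix is the environment directory while
    # sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("venv active:", in_virtualenv())
    print("prefix:", sys.prefix)
```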

Step 4: Install the required packages

pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
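After the install finishes, one quick way to see which JAX packages pip placed in the environment is the standard library's importlib.metadata. This sketch only reports the installed versions of jax and jaxlib (installed_versions is a hypothetical helper name); it does not check GPU support, which Step 5 covers.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None  # not installed in the active environment
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```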

Step 5: Test that JAX is able to detect GPUs

python
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
gpu

If this prints gpu, JAX is detecting the GPU correctly. You can now install any other packages you need into the same environment.
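The same check can also be written with the public jax.default_backend() and jax.devices() functions, which avoid importing from jax.lib. The sketch below assumes a reasonably recent JAX release and also runs a small dot product to confirm that a computation actually executes on the default device. The helper name jax_gpu_report is hypothetical, and the script reports, rather than fails, when JAX cannot be found.

```python
import importlib.util


def jax_gpu_report():
    """Describe the active JAX backend (hypothetical helper).

    Returns None if JAX cannot be found in the current environment.
    """
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    x = jnp.arange(4.0)  # [0., 1., 2., 3.] on the default device
    return {
        "backend": jax.default_backend(),  # expected "gpu" on the GPU node
        "devices": [str(d) for d in jax.devices()],
        "dot": float(jnp.dot(x, x)),  # 0 + 1 + 4 + 9 = 14.0
    }


if __name__ == "__main__":
    report = jax_gpu_report()
    if report is None:
        print("JAX is not installed in this environment")
    else:
        for key, value in report.items():
            print(f"{key}: {value}")
```

Saved as, for example, check_jax.py (a hypothetical file name) and run with python check_jax.py inside the activated environment, the backend line should read gpu when the install is working.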

Modify batch file: See the example batch file below, which uses the environment created above.
