Installing JAX
This page describes how to install JAX with GPU support in a Python virtual environment.
Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs
interact -q gpu -g 1 -f ampere -m 20g -n 4
Here, -f specifies a node feature, in this case ampere. The environment only needs to be built once, on an Ampere node.
Step 2: Once your session has started on a compute node, run nvidia-smi to verify the GPU and then load the appropriate modules 
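If you want a scripted version of the nvidia-smi check before loading the modules, the sketch below asks nvidia-smi for each GPU's name and compute capability and tests whether every GPU is an Ampere part (compute capability 8.0, 8.6 or 8.7). It assumes a driver recent enough to support the compute_cap query field, and the function names are hypothetical, not part of any cluster tool.

```python
"""Minimal sketch: check that the GPUs reported by nvidia-smi are Ampere.

Assumes an nvidia-smi version that supports the compute_cap query field.
"""
import subprocess

# Ampere compute capabilities (e.g. A100 = 8.0, A40/A10 = 8.6, Orin = 8.7).
# 8.9 is Ada Lovelace, so it is deliberately excluded.
AMPERE_CAPS = {"8.0", "8.6", "8.7"}


def parse_gpus(csv_text):
    """Parse 'name, compute_cap' lines into (name, compute_cap) tuples."""
    gpus = []
    for line in csv_text.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so the capability is always the final field.
        name, cap = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, cap))
    return gpus


def all_ampere(gpus):
    """True if at least one GPU is listed and every GPU is Ampere."""
    return bool(gpus) and all(cap in AMPERE_CAPS for _, cap in gpus)


def query_gpus():
    """Run nvidia-smi on the current node and return the parsed GPU list."""
    result = subprocess.run(
        ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
        capture_output=True,
        text=True,
        check=True,
    )
    return parse_gpus(result.stdout)
```

On the compute node, call gpus = query_gpus() and then all_ampere(gpus); a False result means the session did not land on an Ampere node.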
module load python/3.11.0 openssl/3.0.0 cuda/11.7.1 cudnn/8.2.0
Step 3: Create and activate the virtual environment
virtualenv -p python3 jax.venv
source jax.venv/bin/activate
Step 4: Install the required packages
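Before running pip, it can be worth confirming that the environment from Step 3 is active, so that packages do not land in the system Python. This minimal sketch compares sys.prefix with sys.base_prefix, and also checks the legacy sys.real_prefix attribute set by virtualenv releases older than 20; the helper name in_virtualenv is illustrative.

```python
"""Minimal sketch: confirm the interpreter is running inside a virtual environment."""
import sys


def in_virtualenv():
    # venv and virtualenv >= 20 point sys.prefix at the environment directory
    # while sys.base_prefix keeps pointing at the base Python install.
    # Legacy virtualenv (< 20) set sys.real_prefix instead.
    return hasattr(sys, "real_prefix") or sys.prefix != sys.base_prefix


print("Interpreter:", sys.executable)
print("Inside a virtual environment:", in_virtualenv())
```

With jax.venv activated, sys.executable should point inside jax.venv/bin.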
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
Step 5: Test that JAX is able to detect GPUs
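Before starting the interpreter, you can check which jax and jaxlib versions pip actually installed, using the standard-library importlib.metadata module (Python 3.8+). This is a minimal sketch; the helper installed_versions is hypothetical, and a package reported as missing means the pip step did not install it into the active environment.

```python
"""Minimal sketch: report which JAX packages are installed in the active environment."""
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


for pkg, ver in installed_versions().items():
    print(f"{pkg}: {ver if ver else 'not installed'}")
```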
python
>>> import jax
>>> print(jax.default_backend())
gpu
If the command above prints gpu, JAX can see the GPU. You are all set; you can now install any other packages you need.
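For a slightly more thorough check than reading the backend name, the sketch below lists the devices JAX can see and runs a small matrix multiply on the default device. It uses only public JAX calls (jax.devices, jax.default_backend, jax.numpy); the wrapper function jax_gpu_report is a hypothetical name. On a working GPU install the backend should be gpu; cpu means JAX is not using the GPU.

```python
"""Minimal sketch: list the devices JAX sees and run a small computation on them."""


def jax_gpu_report():
    """Return backend and device info, or None if JAX is not importable."""
    try:
        import jax
        import jax.numpy as jnp
    except ImportError:
        return None

    devices = jax.devices()  # devices of the default backend
    x = jnp.ones((1024, 1024))
    # JAX dispatches asynchronously; block_until_ready() waits for the result.
    y = (x @ x).block_until_ready()
    return {
        "backend": jax.default_backend(),
        "devices": [str(d) for d in devices],
        "checksum": float(y[0, 0]),  # every entry of x @ x is 1024.0
    }


report = jax_gpu_report()
if report is None:
    print("JAX is not installed in this environment")
else:
    print("Backend:", report["backend"])
    print("Devices:", report["devices"])
    print("(x @ x)[0, 0] =", report["checksum"])
```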