Installing JAX
This page describes how to install JAX in a Python virtual environment.
In this example, we install JAX with CUDA 12 GPU support.
Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs
interact -q gpu -g 1 -f ampere -m 20g -n 4
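Here, -f selects a node feature (ampere), so the session starts on a node with Ampere GPUs; -g 1 requests one GPU, -m 20g requests 20 GB of memory, and -n 4 requests four cores. You only need to build the environment on an Ampere node once.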
Step 2: Once your session has started on a compute node, run nvidia-smi to verify the GPU, then load the appropriate modules:
nvidia-smi
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn
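To run the same checks from Python, here is a minimal sketch. The helper names (parse_gpu_csv, library_path_entries) are hypothetical. It queries nvidia-smi with its standard --query-gpu and --format options and lists the directories currently on LD_LIBRARY_PATH. After the module load step, these typically include the library paths added by the cuda and cudnn modules.

```python
import os
import shutil
import subprocess


def parse_gpu_csv(text):
    """Parse output of `nvidia-smi --query-gpu=name,memory.total --format=csv,noheader`.

    Returns one (name, memory) tuple per non-empty line.
    """
    gpus = []
    for line in text.strip().splitlines():
        if not line.strip():
            continue
        # Split on the last comma so the memory field is always the final column.
        name, memory = [field.strip() for field in line.rsplit(",", 1)]
        gpus.append((name, memory))
    return gpus


def library_path_entries(value=None):
    """Return the non-empty directories in LD_LIBRARY_PATH (or in `value`)."""
    if value is None:
        value = os.environ.get("LD_LIBRARY_PATH", "")
    # LD_LIBRARY_PATH is colon-separated on Linux.
    return [entry for entry in value.split(":") if entry]


def main():
    if shutil.which("nvidia-smi") is None:
        print("nvidia-smi not found; this does not look like a GPU node")
    else:
        result = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,memory.total", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        )
        for name, memory in parse_gpu_csv(result.stdout):
            print(f"GPU: {name} ({memory})")
    for entry in library_path_entries():
        print("LD_LIBRARY_PATH entry:", entry)


if __name__ == "__main__":
    main()
```

An Ampere node should report an Ampere-generation GPU name (for example an A100) in the GPU lines.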
Step 3: Create and activate the virtual environment
python -m venv jax.venv
source jax.venv/bin/activate
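To confirm that the activated environment is the one Python is actually using, you can compare sys.prefix with sys.base_prefix. Inside a venv created with python -m venv, the two differ. The function name below is illustrative only.

```python
import sys


def in_virtualenv():
    """Return True when the running interpreter belongs to a venv.

    Inside a venv, sys.prefix points at the environment directory
    while sys.base_prefix still points at the base Python installation.
    """
    return sys.prefix != sys.base_prefix


if __name__ == "__main__":
    print("interpreter:", sys.executable)
    print("inside a virtual environment:", in_virtualenv())
```

With jax.venv active, the interpreter path should end in jax.venv/bin/python.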
Step 4: Install the required packages
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
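After installation, you may want to record which JAX packages pip actually installed, for example when reporting a problem. The sketch below uses importlib.metadata from the standard library. The helper name installed_versions is hypothetical, and it only checks the core jax and jaxlib distributions.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages):
    """Map each distribution name to its installed version, or None if it is absent."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    # Report the core JAX packages installed by the pip command above.
    for name, ver in installed_versions(["jax", "jaxlib"]).items():
        print(f"{name}: {ver if ver else 'not installed'}")
```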
Step 5: Test that JAX is able to detect GPUs
python
>>> import jax
>>> print(jax.default_backend())
gpu
If this prints gpu, JAX is working correctly. You are all set, and you can now install any other packages you need.
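Detecting the backend does not exercise the GPU. A slightly stronger check is to run a small computation on it. The sketch below is an assumption-level example with a hypothetical function name. It reports the default backend and devices, multiplies two 512 x 512 matrices of ones on the default device, and returns the sum of the result. That sum should equal 512 cubed, because every entry of the product is 512.

```python
import importlib.util


def jax_gpu_report():
    """Return (backend, devices, checksum) from a tiny JAX computation, or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax
    import jax.numpy as jnp

    backend = jax.default_backend()  # "gpu" when a CUDA device is usable
    devices = [str(d) for d in jax.devices()]
    # Multiply two matrices of ones on the default device; float() waits for the result.
    x = jnp.ones((512, 512), dtype=jnp.float32)
    checksum = float((x @ x).sum())
    return backend, devices, checksum


if __name__ == "__main__":
    report = jax_gpu_report()
    if report is None:
        print("JAX is not installed in this environment")
    else:
        backend, devices, checksum = report
        print("backend:", backend)
        print("devices:", devices)
        print("checksum:", checksum)
```

If the backend printed is cpu rather than gpu, JAX has fallen back to the CPU. In that case, recheck the modules, LD_LIBRARY_PATH, and the pip install step.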
Modify your batch file: below is an example batch script that activates the environment created above and runs a Python script on one GPU.
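#!/bin/bash
#SBATCH -J RBC
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --time=3:30:00
#SBATCH --mem=64GB
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH -o RBC_job_%j.o
#SBATCH -e RBC_job_%j.e

# Print LD_LIBRARY_PATH, clear it, and print it again so the job log shows that
# CUDA libraries from loaded modules cannot shadow the ones pip installed
echo "$LD_LIBRARY_PATH"
unset LD_LIBRARY_PATH
echo "$LD_LIBRARY_PATH"

# Activate the virtual environment (replace this path with the location of your own venv)
source /oscar/data/gk/psaluja/jax_env.venv/bin/activate

# -u disables output buffering so print() output reaches RBC_job_%j.o while the job runs
python3 -u kernel.py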
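If JAX cannot find the GPU inside a batch job, it may silently run on the CPU and use up the requested time. One option is to fail fast at the top of your script. The sketch below is hypothetical: neither backend_error nor require_backend comes from JAX or the cluster. You would call require_backend() before doing any real work in kernel.py.

```python
import sys


def backend_error(actual, expected="gpu"):
    """Return an error message if the backend differs from the expected one, else None."""
    if actual != expected:
        return (f"JAX is using the '{actual}' backend, expected '{expected}'. "
                "Check the virtual environment path and LD_LIBRARY_PATH in the batch script.")
    return None


def require_backend(expected="gpu"):
    """Exit with status 1 (message on stderr) unless JAX's default backend matches."""
    import jax  # deferred so backend_error() can be used without JAX installed

    error = backend_error(jax.default_backend(), expected)
    if error is not None:
        # stderr goes to the RBC_job_%j.e file set by #SBATCH -e
        print(error, file=sys.stderr, flush=True)
        sys.exit(1)
    print("JAX devices:", jax.devices(), flush=True)
```

Exiting with a nonzero status also marks the job as failed in Slurm's accounting, which makes misconfigured runs easy to spot.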