Installing JAX
This page describes how to install JAX with GPU support in a Python virtual environment.
Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs
interact -q gpu -g 1 -f ampere -m 20g -n 4
Here, -f requests a node feature (in this case, ampere). The environment only needs to be built once, on an Ampere node.
Step 2: Once your session has started on a compute node, run nvidia-smi to verify that a GPU is available, then load the appropriate modules
nvidia-smi
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn
Step 3: Create and activate the virtual environment
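A minimal sketch of this step, assuming the python3 on your path after the module setup includes the standard-library venv module. The location $HOME/envs/jax is a hypothetical path; any directory you own will work. If your cluster provides Python through a module, you may need to load it first.

```shell
# Hypothetical location for the environment; change it to suit your setup
ENV_DIR="$HOME/envs/jax"

# Create the virtual environment with the standard-library venv module
python3 -m venv "$ENV_DIR"

# Activate it so that python and pip now refer to the environment's copies
source "$ENV_DIR/bin/activate"
```

The environment must be activated again in every new shell or job before you use JAX.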
Step 4: Install the required packages
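One possible way to do this, assuming the cuda and cudnn modules loaded in Step 2 provide CUDA 12 and a cuDNN version that JAX supports. The jax[cuda12_local] extra tells pip to use the locally installed CUDA and cuDNN instead of downloading them; check the JAX installation guide for the extra that matches your CUDA version.

```shell
# Upgrade pip inside the active virtual environment first
pip install --upgrade pip

# Install JAX built against a locally installed CUDA 12 toolkit and cuDNN
# (assumed here to be the versions provided by "module load cuda cudnn")
pip install --upgrade "jax[cuda12_local]"
```

If you would rather not depend on the cluster's CUDA modules, the jax[cuda12] extra installs CUDA and cuDNN from pip wheels instead.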
Step 5: Test that JAX is able to detect GPUs
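A quick check, run with the virtual environment still active. jax.default_backend() reports the platform JAX selected by default, and jax.devices() lists the devices it can see.

```shell
# Print the backend JAX selected by default; "gpu" means a GPU was detected
python -c "import jax; print(jax.default_backend())"

# List the devices JAX can see, to confirm which GPU(s) it is using
python -c "import jax; print(jax.devices())"
```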
If jax.default_backend() returns gpu, JAX has detected the GPU and is working correctly. You are all set; you can now install any other packages you need.
Modify batch file: when running jobs in batch mode, load the same modules and activate the environment created above in your batch script.
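A sketch of such a batch script, assuming the cluster uses the Slurm scheduler. The partition name, resource requests and time limit below are illustrative guesses modeled on the interact options used earlier, not values taken from this page; my_script.py and $HOME/envs/jax are hypothetical names.

```shell
#!/bin/bash
# Assumed Slurm directives; adjust partition, GPUs, memory and time for your cluster
#SBATCH --job-name=jax-job
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=20G
#SBATCH --time=01:00:00

# Recreate the module setup used when the environment was built
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn

# Activate the virtual environment (hypothetical path from the install step)
source "$HOME/envs/jax/bin/activate"

# Run your program; my_script.py is a placeholder for your own code
python my_script.py
```

No Ampere feature constraint is requested here, since the environment only needs to be built on an Ampere node; add one if your jobs must run on Ampere GPUs.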