Installing JAX
This page describes how to install JAX in a Python virtual environment.
The example below installs JAX with GPU support on an Ampere GPU node.
Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs
In the request command, the -f option stands for feature; it is used here to request an Ampere node. The environment only needs to be built on Ampere once.
Step 2: Once your session has started on a compute node, run nvidia-smi to verify the GPU, then load the appropriate modules
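The GPU check in this step can also be scripted, which may be handy as a pre-flight check at the top of a job. The following is a minimal sketch, not part of the original instructions: the helper name list_gpus is hypothetical, and it assumes nvidia-smi is on the PATH of the compute node. It calls nvidia-smi -L, which prints one line per visible GPU.

```python
import shutil
import subprocess


def list_gpus():
    """Return one line per GPU reported by `nvidia-smi -L`.

    Returns an empty list if nvidia-smi is missing or exits with an error.
    (Hypothetical helper, for illustration only.)
    """
    if shutil.which("nvidia-smi") is None:
        return []
    result = subprocess.run(
        ["nvidia-smi", "-L"],
        capture_output=True,
        text=True,
        check=False,
    )
    if result.returncode != 0:
        return []
    # Keep only non-empty lines; each one describes a single GPU.
    return [line.strip() for line in result.stdout.splitlines() if line.strip()]


if __name__ == "__main__":
    gpus = list_gpus()
    if gpus:
        print("\n".join(gpus))
    else:
        print("No GPU visible; check that the session landed on a GPU node.")
```

An empty result on a compute node usually means the session was not allocated a GPU, in which case it is worth re-checking the request from Step 1 before going further.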
Step 3: Create and activate the virtual environment
Step 4: Install the required packages
Step 5: Test that JAX is able to detect GPUs
If the above function returns gpu, then JAX is working correctly. You are all set and can now install any other packages you need.
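As a rough illustration of this check, the sketch below asks JAX which backend it selected by default and lists the devices it can see. It uses jax.default_backend() and jax.devices(), which may not be the same call this page originally showed; the helper names detect_backend and list_devices are hypothetical. Run it inside the activated virtual environment on the GPU node.

```python
def detect_backend():
    """Return the name of JAX's default backend, or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    # Typically "gpu" when a CUDA-enabled install finds a GPU, otherwise "cpu".
    return jax.default_backend()


def list_devices():
    """Return a readable description of every device JAX can see."""
    try:
        import jax
    except ImportError:
        return []
    return [str(device) for device in jax.devices()]


if __name__ == "__main__":
    backend = detect_backend()
    if backend is None:
        print("JAX is not importable; is the virtual environment activated?")
    else:
        print("Default backend:", backend)
        print("Devices:", list_devices())
```

If the backend is reported as cpu on a GPU node, the most likely causes are a CPU-only JAX install or missing CUDA modules, so it is worth revisiting Steps 2 and 4.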
Modify batch file: See the example batch file below, which uses the environment created above.