Installing JAX

This page describes how to install JAX in a Python virtual environment.

Step 1: Request an interactive session on a GPU node with Ampere architecture GPUs

interact -q gpu -g 1 -f ampere -m 20g -n 4

Here, -f requests a node feature, in this case ampere. You only need to build the environment once, on an Ampere node.

Step 2: Once your session has started on a compute node, run nvidia-smi to verify the GPU and then load the appropriate modules

module purge 
unset LD_LIBRARY_PATH
module load cuda cudnn
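To confirm from Python that the session landed on Ampere GPUs, you can parse the output of nvidia-smi. This is a minimal sketch. It assumes an NVIDIA driver recent enough for nvidia-smi to support the compute_cap query field, and it treats compute capability 8.0, 8.6 or 8.7 as Ampere. The function names are hypothetical.

```python
"""Sketch: check that the GPUs visible to this session are Ampere (compute capability 8.x)."""
import subprocess

# Assumption: Ampere parts report compute capability 8.0 (A100), 8.6 (A10, A40, RTX 30xx)
# or 8.7 (Jetson Orin). 8.9 is Ada Lovelace, so it is deliberately excluded.
AMPERE_CAPABILITIES = {"8.0", "8.6", "8.7"}


def parse_gpus(csv_text: str) -> list[tuple[str, str]]:
    """Parse `nvidia-smi --query-gpu=name,compute_cap --format=csv,noheader` output."""
    gpus = []
    for line in csv_text.strip().splitlines():
        # Split on the last comma so GPU names containing commas stay intact.
        name, cap = (field.strip() for field in line.rsplit(",", 1))
        gpus.append((name, cap))
    return gpus


def all_ampere(gpus: list[tuple[str, str]]) -> bool:
    """True only if at least one GPU was found and every GPU is Ampere."""
    return bool(gpus) and all(cap in AMPERE_CAPABILITIES for _, cap in gpus)


def main() -> None:
    try:
        out = subprocess.run(
            ["nvidia-smi", "--query-gpu=name,compute_cap", "--format=csv,noheader"],
            capture_output=True, text=True, check=True,
        ).stdout
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        # No nvidia-smi on PATH (e.g. a CPU-only node) or the query field is unsupported.
        print(f"Could not query GPUs with nvidia-smi: {err}")
        return
    gpus = parse_gpus(out)
    for name, cap in gpus:
        print(f"{name}: compute capability {cap}")
    print("All GPUs are Ampere:", all_ampere(gpus))


if __name__ == "__main__":
    main()
```

Filtering on compute capability rather than the GPU name avoids keeping a list of model names up to date.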

Step 3: Create and activate the virtual environment

python -m venv jax.venv
source jax.venv/bin/activate

Step 4: Install the required packages

pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Step 5: Test that JAX is able to detect GPUs

python
>>> import jax
>>> print(jax.default_backend())
gpu

If this prints gpu, JAX is using the GPU and the installation is working. You can now install any other packages you need into the same environment.

Modify the batch file: the example batch file below runs a job with the virtual environment created above. Replace the path in the source line with the location of your own environment.
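#!/bin/bash
#SBATCH -J RBC
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --time=3:30:00
#SBATCH --mem=64GB
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH -o RBC_job_%j.o
#SBATCH -e RBC_job_%j.e

# Print LD_LIBRARY_PATH before and after clearing it, so the job log shows it is empty
echo $LD_LIBRARY_PATH
unset LD_LIBRARY_PATH
echo $LD_LIBRARY_PATH

# Activate the JAX virtual environment, then run the script with unbuffered output (-u)
source /oscar/data/gk/psaluja/jax_env.venv/bin/activate
python3 -u kernel.py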
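If you want a job to report setup problems early instead of quietly running JAX on the CPU, you could call a small pre-flight check at the start of kernel.py. The sketch below is hypothetical: the function names are illustrative and are not part of JAX or Oscar. It mirrors what the batch script does. It checks that a virtual environment is active, that LD_LIBRARY_PATH is empty, and that jax.default_backend() reports gpu.

```python
"""Hypothetical pre-flight check for a JAX batch job (names are illustrative)."""
import os
import sys
from typing import List, Optional


def in_virtualenv() -> bool:
    # Inside a venv, sys.prefix points at the venv while sys.base_prefix points at the base Python.
    return sys.prefix != sys.base_prefix


def ld_library_path_cleared() -> bool:
    # The batch script runs `unset LD_LIBRARY_PATH`; an empty value is treated the same as unset.
    return not os.environ.get("LD_LIBRARY_PATH")


def jax_backend() -> Optional[str]:
    # Import lazily so the check can still report a missing JAX instead of crashing.
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend()


def preflight() -> List[str]:
    """Return a list of human-readable problems; an empty list means everything looks fine."""
    problems = []
    if not in_virtualenv():
        problems.append("no virtual environment is active")
    if not ld_library_path_cleared():
        problems.append("LD_LIBRARY_PATH is still set")
    backend = jax_backend()
    if backend is None:
        problems.append("JAX is not importable")
    elif backend != "gpu":
        problems.append(f"JAX default backend is {backend!r}, expected 'gpu'")
    return problems


if __name__ == "__main__":
    issues = preflight()
    for issue in issues:
        print("WARNING:", issue)
    if not issues:
        print("Environment looks ready for JAX on GPU.")
```

Because the warnings are printed, they end up in the RBC_job_%j.o log alongside the LD_LIBRARY_PATH echoes.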