> For the complete documentation index, see [llms.txt](https://docs.ccv.brown.edu/oscar/llms.txt). Markdown versions of documentation pages are available by appending `.md` to page URLs; this page is available as [Markdown](https://docs.ccv.brown.edu/oscar/gpu-computing/installing-frameworks-pytorch-tensorflow-jax/installing-jax.md).

# Installing JAX

In this example, we will install **JAX** in a Python virtual environment on a GPU node.

**Step 1:** Request an interactive session on a GPU node with Ampere architecture GPUs

```bash
interact -q gpu -g 1 -f ampere -m 20g -n 4
```

Here, `-f` selects a node feature (in this case, `ampere`). The environment only needs to be built on an Ampere node once.

**Step 2:** Once your session has started on a compute node, run `nvidia-smi` to verify the GPU and then load the appropriate modules

```bash
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn
```

**Step 3:** Create and activate the virtual environment

```bash
python -m venv jax.venv
source jax.venv/bin/activate
```
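Before installing anything, it can be worth confirming that the interpreter on your path is the one inside the new environment. The snippet below is a minimal sketch using only the standard library; the helper name `in_virtualenv` is illustrative and not part of any Oscar tooling.

```python
import sys


def in_virtualenv() -> bool:
    """Return True when the running interpreter belongs to a virtual environment."""
    # Inside a venv, sys.prefix points at the environment directory,
    # while sys.base_prefix still points at the base Python installation.
    return sys.prefix != sys.base_prefix


print("interpreter:", sys.executable)
print("inside a virtual environment:", in_virtualenv())
```

If the second line prints `False`, the `source jax.venv/bin/activate` step most likely did not take effect in the current shell.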

**Step 4:** Install the required packages

```bash
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```

**Step 5:** Test that JAX is able to detect GPUs

```python
python
>>> import jax
>>> print(jax.default_backend())
gpu
```

If the command prints `gpu`, JAX can see the GPU. You are all set and can now install any other packages you need.
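Checking the platform name confirms that a GPU backend was found, but running a small computation is a slightly stronger test. The following is a minimal sketch, not an official check; the function name `check_jax` and the matrix size are arbitrary choices made for illustration.

```python
import jax
import jax.numpy as jnp


def check_jax():
    """Report the default backend and devices, then run a small matrix product."""
    backend = jax.default_backend()  # platform name of the default backend
    devices = jax.devices()          # devices available to the default backend
    x = jnp.ones((512, 512), dtype=jnp.float32)
    # block_until_ready() waits for the asynchronous computation to finish.
    y = (x @ x).block_until_ready()
    return backend, devices, float(y[0, 0])


backend, devices, value = check_jax()
print("default backend:", backend)
print("devices:", devices)
print("first entry of the product:", value)
```

Each entry of the product of two 512 x 512 matrices of ones is 512, so any other value would point to a problem with the installation.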

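## Modify batch file

The example batch file below activates a virtual environment like the one created above. Replace the path in the `source` line with the location of your own environment.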

```bash
#!/bin/bash
#SBATCH -J RBC
#SBATCH -N 1
#SBATCH --ntasks=1
#SBATCH --ntasks-per-node=1
#SBATCH --time=3:30:00
#SBATCH --mem=64GB
#SBATCH --partition=gpu
#SBATCH --gres=gpu:1
#SBATCH -o RBC_job_%j.o
#SBATCH -e RBC_job_%j.e

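# Print LD_LIBRARY_PATH before and after unsetting it to confirm it is empty
echo "$LD_LIBRARY_PATH"
unset LD_LIBRARY_PATH
echo "$LD_LIBRARY_PATH"

# Activate the virtual environment and run the script with unbuffered output (-u)
source /oscar/data/gk/psaluja/jax_env.venv/bin/activate
python3 -u kernel.py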
```
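The batch file runs a script named `kernel.py`, whose contents are not shown on this page. As a purely hypothetical stand-in for testing the batch setup, a script along these lines could be used; the function `scaled_sum` and the input values are assumptions made for illustration, not the original script.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scaled_sum(x, scale):
    """Multiply every element by `scale` and sum the result (compiled with jit)."""
    return jnp.sum(x * scale)


def main():
    # Report which backend the job actually received.
    print("backend:", jax.default_backend(), flush=True)
    x = jnp.arange(1000, dtype=jnp.float32)  # 0, 1, ..., 999
    result = scaled_sum(x, 2.0)
    result.block_until_ready()  # wait for the device computation to finish
    print("result:", float(result), flush=True)
    return float(result)


if __name__ == "__main__":
    main()
```

The backend line ends up in the job's `.o` output file, which makes it easy to confirm that the batch job ran on a GPU rather than falling back to the CPU.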
