Installing JAX
This page describes how to install JAX in a Python virtual environment.
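
Request an interactive session on a GPU node:
interact -q gpu -g 1 -f ampere -m 20g -n 4

Clear the module environment and load CUDA and cuDNN:
module purge
unset LD_LIBRARY_PATH
module load cuda cudnn

Create and activate a virtual environment:
python -m venv jax.venv
source jax.venv/bin/activate

Upgrade pip and install JAX with CUDA 12 support:
pip install --upgrade pip
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Start Python and check that JAX detects the GPU:
python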
>>> from jax.lib import xla_bridge
>>> print(xla_bridge.get_backend().platform)
gpu

Modify batch file: See the example batch file below, which uses the created environment.
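
A batch job cannot use the interactive >>> prompt, so the GPU check has to live in a script file. The following is a minimal sketch of such a script, which a batch file could run with `python check_jax.py` after loading the same modules and activating jax.venv. The file name check_jax.py and the helper name jax_backend_report are illustrative assumptions, not part of the original instructions. The sketch uses the public calls jax.default_backend() and jax.devices() rather than going through jax.lib.xla_bridge.

```python
"""check_jax.py: report which backend JAX selected (hypothetical helper script)."""
import importlib.util


def jax_backend_report():
    """Return (platform, device_names); platform is None if JAX is not importable."""
    if importlib.util.find_spec("jax") is None:
        return None, []

    import jax  # imported lazily so the script still runs without JAX

    # jax.default_backend() returns the platform name, e.g. "cpu" or "gpu".
    platform = jax.default_backend()
    # jax.devices() lists the devices of the default backend.
    devices = [str(d) for d in jax.devices()]
    return platform, devices


if __name__ == "__main__":
    platform, devices = jax_backend_report()
    if platform is None:
        print("JAX is not installed in this environment")
    else:
        print(f"backend: {platform}")
        for name in devices:
            print(f"device: {name}")
```

If the job log shows a backend of cpu instead of gpu, the job most likely did not get a GPU allocation, or the cuda and cudnn modules were not loaded before the environment was activated.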