This package is modelled on the CosmoPower and CosmoPower JAX emulators. It trains emulators similar to CosmoPower, but does so entirely in JAX. It is therefore independent of TensorFlow and can easily be used for gradient-based inference, such as Hamiltonian Monte Carlo (HMC).
We recommend using a dedicated conda environment for clean dependency management.
git clone https://github.com/pburger112/cosmoemuJAX.git
cd cosmoemuJAX
conda create -n cosmoemu_JAX_env python=3.10
conda activate cosmoemu_JAX_env

Install the package and its dependencies using pip:
pip install ".[cpu]"

If you need editable/development mode:

pip install -e ".[cpu]"

To use the GPU of an Apple Silicon Mac (M1, M2, M3), install the GPU variant instead:

pip install ".[gpu]"

If you need editable/development mode:

pip install -e ".[gpu]"

The core CPU dependencies (installed automatically) include:
- numpy
- jax
- scipy
- optax
- gdown
- matplotlib
- tqdm
- getdist
For the GPU variant, the core dependencies are installed automatically with the following version constraints:
- numpy<2.0
- jax==0.4.20
- jax-metal==0.0.5
- scipy==1.11.4
- optax==0.1.7
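After installing either variant, a quick way to check that JAX is working is to list the devices it sees and run one jit-compiled computation. This is a generic JAX sanity check, not part of `cosmoemu_jax`, and the exact device names printed depend on the installed backend.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can see (CPU, or a Metal device with the ".[gpu]" extra).
devices = jax.devices()
print(jax.__version__, devices)

# A tiny jit-compiled computation to confirm the backend actually executes code.
x = jnp.arange(4.0)
result = jax.jit(lambda v: jnp.sum(v ** 2))(x)
print(float(result))  # 14.0
```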

The repository is organized as follows:
- cosmoemu_jax/: Python source of the JAX emulator
- outputs/: folder where data are downloaded and trained emulators are saved
- cosmoemu_jax_3times2pt.ipynb: notebook demonstrating how to use the emulator
- HMC_example.ipynb: notebook demonstrating the emulator in a Hamiltonian Monte Carlo (HMC) inference
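Gradient-based samplers such as HMC need the gradient of the log-posterior, which JAX provides through `jax.grad` for any model written in JAX. As a sketch of why this matters, here is a minimal HMC sampler written directly in JAX. The linear toy `model` stands in for a trained emulator. The Gaussian likelihood, flat prior, step size and number of leapfrog steps are assumptions chosen for illustration, and this is not the sampler used in `HMC_example.ipynb`.

```python
import jax
import jax.numpy as jnp

grid = jnp.linspace(0.0, 1.0, 20)
STEP_SIZE = 0.01   # leapfrog step size (assumed, tuned for this toy problem)
N_LEAPFROG = 20    # leapfrog steps per HMC proposal (assumed)


def model(theta):
    # Stand-in for a trained emulator: any differentiable JAX function works here.
    return theta[0] * jnp.sin(3.0 * grid) + theta[1] * grid


theta_true = jnp.array([1.0, 0.5])
sigma = 0.1
data = model(theta_true) + sigma * jax.random.normal(jax.random.PRNGKey(1), grid.shape)


def log_post(theta):
    # Gaussian likelihood with a flat prior (up to an additive constant).
    resid = (data - model(theta)) / sigma
    return -0.5 * jnp.sum(resid ** 2)


grad_log_post = jax.grad(log_post)


def leapfrog(theta, p):
    # Half momentum kick, alternating drifts and full kicks, final half kick.
    p = p + 0.5 * STEP_SIZE * grad_log_post(theta)

    def body(_, state):
        theta, p = state
        theta = theta + STEP_SIZE * p
        p = p + STEP_SIZE * grad_log_post(theta)
        return theta, p

    theta, p = jax.lax.fori_loop(0, N_LEAPFROG - 1, body, (theta, p))
    theta = theta + STEP_SIZE * p
    p = p + 0.5 * STEP_SIZE * grad_log_post(theta)
    return theta, p


@jax.jit
def hmc_step(key, theta):
    k_mom, k_acc = jax.random.split(key)
    p0 = jax.random.normal(k_mom, theta.shape)  # unit mass matrix
    theta_new, p_new = leapfrog(theta, p0)
    # Metropolis correction on the total energy H = -log_post + |p|^2 / 2.
    h_old = -log_post(theta) + 0.5 * jnp.sum(p0 ** 2)
    h_new = -log_post(theta_new) + 0.5 * jnp.sum(p_new ** 2)
    accept = jnp.log(jax.random.uniform(k_acc)) < h_old - h_new
    return jnp.where(accept, theta_new, theta), accept


def run_chain(key, theta0, n_samples):
    def step(theta, k):
        theta, acc = hmc_step(k, theta)
        return theta, (theta, acc)

    _, (samples, accepts) = jax.lax.scan(step, theta0, jax.random.split(key, n_samples))
    return samples, accepts


samples, accepts = run_chain(jax.random.PRNGKey(42), jnp.array([0.8, 0.3]), 2000)
samples = samples[500:]  # discard burn-in
print(samples.mean(axis=0), accepts.mean())
```

Replacing `model` with a trained JAX emulator leaves the sampler unchanged, since `jax.grad` differentiates through the emulator in the same way.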
For questions, please contact: Pierre Burger – pierre.burger@uwaterloo.ca
