This repository was archived by the owner on Aug 30, 2025. It is now read-only.

NiklasKappel/nnx-lightning


nnx-lightning

An example of training a Flax NNX model with PyTorch data loaders and the Lightning trainer.

Todo:

  • Is it a problem that LightningModule.log forces a blocking call that waits for metric values, because it cannot accept JAX arrays that are still being computed asynchronously (effectively futures)?
  • Should we convert data to NumPy/JAX arrays directly, or to torch tensors first and to NumPy/JAX arrays later?
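Both questions can be illustrated with a small sketch, which is one possible approach and not a resolution of either Todo. The hypothetical `numpy_collate` stacks NumPy samples into NumPy batch arrays without ever creating torch tensors, and could be passed as `collate_fn` to `torch.utils.data.DataLoader`. The hypothetical `mean_loss_deferred` keeps per-step losses as JAX arrays, which JAX returns immediately thanks to asynchronous dispatch. It then pays for a single blocking conversion to a Python float at the end, instead of one on every logging call.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_collate(batch):
    """Stack a list of samples into NumPy arrays (hypothetical helper).

    Handles samples that are arrays/scalars or tuples/lists of them, e.g.
    [(x0, y0), (x1, y1), ...] -> (stacked_x, stacked_y).
    """
    first = batch[0]
    if isinstance(first, (tuple, list)):
        # Transpose the batch and collate each field separately.
        return type(first)(numpy_collate(list(items)) for items in zip(*batch))
    return np.stack([np.asarray(sample) for sample in batch])


@jax.jit
def loss_fn(w, x, y):
    return jnp.mean((x @ w - y) ** 2)


def mean_loss_deferred(w, batches):
    # Each call returns a jax.Array right away; its value may still be
    # computing on the device, so nothing in this loop blocks.
    losses = [loss_fn(w, x, y) for x, y in batches]
    # One blocking device-to-host transfer, e.g. once per epoch.
    return float(jnp.mean(jnp.stack(losses)))
```

The trade-off is that metrics reach the logger less often. Whether that fits Lightning's per-step `self.log` model is exactly the open question above.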

Notes:

  • Hyperparameters that are given to the LightningModule on initialization and used in the step functions must be passed explicitly through the JIT boundary (possibly in a step_config PyTree).
  • Lightning warns about data loaders that use too few worker processes (num_workers), but JAX warns that forking the process, as Python multiprocessing does by default on Linux, is incompatible with JAX's multithreading and will likely lead to a deadlock.
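One reading of the first note: a value that a jitted function reads from an enclosing object such as `self.hparams` is baked in as a constant when the function is traced. Passing it as an explicit argument makes it a traced input instead. The sketch below uses plain JAX with a hypothetical `StepConfig` NamedTuple (NamedTuples are PyTrees) and a toy least-squares SGD step. Neither is taken from the repository.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class StepConfig(NamedTuple):
    # Hypothetical hyperparameters a LightningModule might receive in
    # __init__ and need inside its jitted step function.
    learning_rate: float
    weight_decay: float


@jax.jit
def sgd_step(w, x, y, step_config):
    # step_config is an ordinary PyTree argument, so its fields cross the
    # JIT boundary as traced values rather than captured constants.
    def loss_fn(w):
        return jnp.mean((x @ w - y) ** 2)

    loss, grad = jax.value_and_grad(loss_fn)(w)
    grad = grad + step_config.weight_decay * w
    return w - step_config.learning_rate * grad, loss
```

Because the fields are traced rather than static, calling `sgd_step` with a different learning rate of the same type reuses the compiled function instead of recompiling it.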

See also:
