[doc] Add a JaxTrainer template #59842
Conversation
Code Review
This pull request introduces a valuable JaxTrainer template for training a GPT-2 style model on GPUs and TPUs. The notebook and markdown file provide a comprehensive walkthrough. My review focuses on improving code correctness, clarity, and maintainability. I've identified a critical bug in the metric reporting logic that could cause training runs to hang, and I've provided a fix. Additionally, I've made several suggestions to enhance the example's robustness and readability, including correcting command-line syntax, simplifying data iteration, and removing redundant code.
doc/source/train/examples/jax/intro_to_jax_trainer/README.ipynb (review comments resolved)
JasonLi1909
left a comment
Awesome template! Left some comments; reminder to do the other steps (add to CI, compute configs, etc.). Would be good to get a pass from @angelinalg at some point. Thanks!
This pull request has been automatically marked as stale because it has not had recent activity. You can always ask for help on our discussion forum or Ray's public Slack channel. If you'd like to keep this open, just leave any comment, and the stale label will be removed.
Signed-off-by: Lehui Liu <lehui@anyscale.com>
## Description

1. This PR adds a workspace template that walks users through how to use Ray Train's [JaxTrainer](https://docs.ray.io/en/master/train/api/doc/ray.train.v2.jax.JaxTrainer.html).
2. The purpose of this template is to show how to use the JaxTrainer to train a GPT-2 style model on both GPU and TPU. At a high level, this template covers:
   * A hands-on example of training a GPT-2 model using Jax/Flax
   * Sample `ScalingConfig` for both GPU and TPU
   * Simple integration with Ray Data to read the pre-tokenized openwebtext dataset for training

Testing: tested in an Anyscale workspace.

## Additional information

1. https://console.anyscale-staging.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_g7p6lsu6r8g7garwbxifppyz23/workspaces/expwrk_bw8izpdi59293i5e73h6biwkak/train/train/46607266f3cb454aa9e7f7929b3aaae3/workers/fbe9a9b91fdf90e7476bb12a0d000000?workspace-tab=code&command-history-section=application_logs&file=%252Fmnt%252Fcluster_storage%252Fgpt2%252F0104_raydata%252Fcheckpoints%252F1x1&storage=cluster

---------

Signed-off-by: Lehui Liu <lehui@anyscale.com>
Signed-off-by: Sirui Huang <ray.huang@anyscale.com>
Signed-off-by: elliot-barn <elliot.barnwell@anyscale.com>
Signed-off-by: Adel Nour <ans9868@nyu.edu>
Signed-off-by: peterxcli <peterxcli@gmail.com>
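For context, a `JaxTrainer` invocation along the lines the template describes might look like the following sketch. It is illustrative only, not the template's actual code: `train_func` is a placeholder for the Jax/Flax training loop, the GPU `ScalingConfig` shown is one plausible configuration, and the exact fields for a TPU `ScalingConfig` may differ from what the template ships.

```python
from ray.train import ScalingConfig
from ray.train.v2.jax import JaxTrainer


def train_func(config):
    # Placeholder for the template's Jax/Flax GPT-2 training loop:
    # build the model, shard it across devices, iterate over the
    # worker's Ray Data shard, and report metrics and checkpoints.
    ...


# GPU example: one worker per GPU. A TPU run would use a TPU-specific
# ScalingConfig instead (see the template for the exact fields).
trainer = JaxTrainer(
    train_loop_per_worker=train_func,
    scaling_config=ScalingConfig(num_workers=4, use_gpu=True),
)
result = trainer.fit()
```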
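As background on the "GPT-2 style" training the description mentions: a causal language model is typically fed batches where the targets are the inputs shifted left by one token, so the model learns to predict the next token at every position. The helper below is a hypothetical illustration of that layout over a pre-tokenized stream (such as the openwebtext token ids the template reads with Ray Data); it is not code from the PR.

```python
# Hypothetical sketch (not from the PR): next-token data layout for a
# GPT-2 style causal language model.

def make_lm_examples(token_ids, seq_len):
    """Split a pre-tokenized stream into (inputs, targets) pairs.

    Each window needs seq_len + 1 tokens, because the targets are the
    inputs shifted left by one position.
    """
    examples = []
    for start in range(0, len(token_ids) - seq_len, seq_len):
        window = token_ids[start : start + seq_len + 1]
        inputs = window[:-1]   # tokens the model sees
        targets = window[1:]   # next token at each position
        examples.append((inputs, targets))
    return examples


tokens = list(range(10))  # stand-in for pre-tokenized openwebtext ids
examples = make_lm_examples(tokens, seq_len=4)
print(examples[0])  # ([0, 1, 2, 3], [1, 2, 3, 4])
```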