[doc] Add a JaxTrainer template #59842

Merged
matthewdeng merged 15 commits into ray-project:master from liulehui:jaxtrainer
Feb 2, 2026

Conversation

@liulehui
Contributor

@liulehui liulehui commented Jan 5, 2026

Description

  1. This PR adds a workspace template that walks users through how to use Ray Train's JaxTrainer.
  2. The purpose of this template is to walk users through using the JaxTrainer to train a GPT-2-style model on both GPUs and TPUs.

For a high-level overview, this template covers:

  • A hands-on example of training a GPT-2 model using JAX/Flax
  • A sample ScalingConfig for both GPU and TPU
  • Simple integration with Ray Data to read the pre-tokenized OpenWebText dataset for training

Testing:
Tested in an Anyscale workspace.

Additional information

  1. https://console.anyscale-staging.com/cld_kvedZWag2qA8i5BjxUevf5i7/prj_g7p6lsu6r8g7garwbxifppyz23/workspaces/expwrk_bw8izpdi59293i5e73h6biwkak/train/train/46607266f3cb454aa9e7f7929b3aaae3/workers/fbe9a9b91fdf90e7476bb12a0d000000?workspace-tab=code&command-history-section=application_logs&file=%252Fmnt%252Fcluster_storage%252Fgpt2%252F0104_raydata%252Fcheckpoints%252F1x1&storage=cluster

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a valuable JaxTrainer template for training a GPT-2 style model on GPUs and TPUs. The notebook and markdown file provide a comprehensive walkthrough. My review focuses on improving code correctness, clarity, and maintainability. I've identified a critical bug in the metric reporting logic that could cause training runs to hang, and I've provided a fix. Additionally, I've made several suggestions to enhance the example's robustness and readability, including correcting command-line syntax, simplifying data iteration, and removing redundant code.
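The review's note on simplifying data iteration refers to the template's Ray Data loop. A common pattern for consuming a Ray Data shard inside a Ray Train worker looks like the sketch below; the `"train"` dataset name, the `input_ids` column name, and the batch size are assumptions for illustration:

```python
def train_loop_per_worker(config):
    # Inside a Ray Train worker, fetch this worker's shard of the
    # dataset that was passed to the trainer via `datasets={"train": ...}`.
    from ray import train

    shard = train.get_dataset_shard("train")

    # Iterate NumPy batches; "input_ids" is an assumed column name
    # for the pre-tokenized text.
    for batch in shard.iter_batches(batch_size=8, batch_format="numpy"):
        tokens = batch["input_ids"]
        # ... run one JAX training step on `tokens` ...
```

Letting `iter_batches` handle batching and formatting keeps the loop free of manual index bookkeeping, which is the kind of simplification the review suggests.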

@liulehui liulehui marked this pull request as ready for review January 5, 2026 17:00
@liulehui liulehui requested review from a team as code owners January 5, 2026 17:00
@ray-gardener ray-gardener bot added the docs and train labels Jan 5, 2026
Contributor

@JasonLi1909 JasonLi1909 left a comment


Awesome template! Left some comments, reminder to do the other steps (add to ci, compute configs, etc). Would be good to get a pass from @angelinalg at some point. Thanks!

@liulehui liulehui requested a review from JasonLi1909 January 13, 2026 22:49
@liulehui liulehui added the go (add ONLY when ready to merge, run all tests) label Jan 15, 2026
@github-actions

This pull request has been automatically marked as stale because it has not had
any activity for 14 days. It will be closed in another 14 days if no further activity occurs.
Thank you for your contributions.

You can always ask for help on our discussion forum or Ray's public slack channel.

If you'd like to keep this open, just leave any comment, and the stale label will be removed.

@github-actions github-actions bot added the stale label Jan 30, 2026
Contributor

@matthewdeng matthewdeng left a comment


very cool

@github-actions github-actions bot added the unstale label and removed the stale label Jan 30, 2026
Signed-off-by: Lehui Liu <lehui@anyscale.com>

@cursor cursor bot left a comment


Cursor Bugbot has reviewed your changes and found 3 potential issues.

@matthewdeng matthewdeng enabled auto-merge (squash) February 2, 2026 22:20
@matthewdeng matthewdeng merged commit 1490183 into ray-project:master Feb 2, 2026
4 of 7 checks passed
rayhhome pushed a commit to rayhhome/ray that referenced this pull request Feb 4, 2026
elliot-barn pushed a commit that referenced this pull request Feb 9, 2026
elliot-barn pushed a commit that referenced this pull request Feb 9, 2026
ans9868 pushed a commit to ans9868/ray that referenced this pull request Feb 18, 2026
peterxcli pushed a commit to peterxcli/ray that referenced this pull request Feb 25, 2026
peterxcli pushed a commit to peterxcli/ray that referenced this pull request Feb 25, 2026

Labels

docs: An issue or change related to documentation
go: add ONLY when ready to merge, run all tests
train: Ray Train Related Issue
unstale: A PR that has been marked unstale. It will not get marked stale again if this label is on it.

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants