[train][checkpoint] Add validate_function and validate_config to ray.train.report (#56360)
Conversation
justinvyu left a comment:
Great! I'll do a more thorough pass on the tests in the next round.
justinvyu left a comment:
Should be good after this!
A few comments on 747e99b: I considered the following alternate implementation methods but decided against them for various reasons:
- I reworked the …
- I'm not worrying about restoring …
Why is using the same experiment directory a requirement here? If you reuse the same storage_path+name, then you'll end up "restoring" the previous run: https://docs.ray.io/en/latest/train/user-guides/fault-tolerance.html#job-driver-fault-tolerance

The validation run shouldn't actually need to save any files or record checkpoints, so the storage path will be mostly unused. I think we should recommend something like this:

```python
def eval_only_train_func(config_dict):
    # ...
    # ray.train.report(metrics, checkpoint, NO_UPLOAD)
    # !! The previous usage of ray.train.report w/ NO_UPLOAD is a bit confusing,
    # since you're essentially checkpointing within a validation loop.
    # The only reason why we needed the dummy checkpoint previously was so
    # that we could get these `metrics` out of the training function.
    metrics = {'score': mean_valid_loss.compute().item()}
    return metrics


def validate_with_torch_trainer(checkpoint, config):
    trainer = ray.train.torch.TorchTrainer(
        eval_only_train_func,
        train_loop_config={'checkpoint': checkpoint},
        scaling_config=ray.train.ScalingConfig(num_workers=2, use_gpu=True),
        # !! Just leave the default auto-generated UUID run name.
        run_config=ray.train.RunConfig(storage_path="/mnt/cluster_storage"),
        datasets={"test": config['dataset']},
    )
    result = trainer.fit()
    return result.return_values[0]
```

Note that …
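For illustration, this recommendation could plug into the new API roughly as follows; a minimal sketch, assuming the `validate_fn(checkpoint, validate_config)` call signature implied by the parameters this PR adds to `report`:

```python
def validate_fn(checkpoint, validate_config):
    # Hypothetical glue: delegate to the eval-only TorchTrainer suggested
    # above and surface its metrics as the checkpoint's validation metrics.
    return validate_with_torch_trainer(checkpoint, validate_config)
```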
Thanks, good catch! I agree that …
Summary
The main change here is:
- Train workers report the validation function + validation config.
- The controller kicks off a validation Ray task and associates its return value with the relevant checkpoint.
- The main controller step polls workers and validations, only finishing when both are done.
In doing this, we are leveraging Ray Train's single-controller architecture to provide a global understanding of training progress. This makes it easy to add features like early stopping in the future.
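To make this flow concrete, here is a highly simplified sketch of the controller-side bookkeeping; the names mirror `_TrainingReport`, `_ValidationSpec`, and `ValidationManager` from the change summary below, but this is not the actual Ray Train implementation:

```python
from dataclasses import dataclass
from typing import Callable, Dict, Optional

import ray
from ray.train import Checkpoint


@dataclass
class _ValidationSpec:
    validate_fn: Callable
    validate_config: Optional[dict] = None


@dataclass
class _TrainingReport:
    metrics: dict
    checkpoint: Optional[Checkpoint] = None
    validation_spec: Optional[_ValidationSpec] = None


@ray.remote
def _run_validation(validate_fn, checkpoint, validate_config):
    # Runs in its own Ray task so validation never blocks training.
    return validate_fn(checkpoint, validate_config)


class ValidationManager:
    """Launches validation tasks and attaches results to checkpoints."""

    def __init__(self):
        self._pending: Dict[ray.ObjectRef, Checkpoint] = {}

    def handle_report(self, report: _TrainingReport) -> None:
        spec = report.validation_spec
        if spec is not None and report.checkpoint is not None:
            ref = _run_validation.remote(
                spec.validate_fn, report.checkpoint, spec.validate_config
            )
            self._pending[ref] = report.checkpoint

    def poll(self) -> bool:
        """Collect finished validations; True when none are pending."""
        if not self._pending:
            return True
        refs = list(self._pending)
        done, _ = ray.wait(refs, num_returns=len(refs), timeout=0)
        for ref in done:
            checkpoint = self._pending.pop(ref)
            metrics = ray.get(ref)
            # The real controller would merge `metrics` into this
            # checkpoint's entry in the CheckpointManager here.
        return not self._pending
```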
A few other notes:
- … `Result`. I added a TODO to retry and time out in the future.
- `Result.failed_validations` contains all the checkpoints that failed validation, which may or may not have been deleted.
- I'm not worrying about `CheckpointManager` restoration; right now we simply won't rerun interrupted validations.

API Examples
You can define a `validate_function` with `map_batches` or `TorchTrainer`.

Note the following about the `map_batches` and `TorchTrainer` methods:
- The `map_batches` method's `__call__` function must move the batch to the device, but the `TorchTrainer` method's forward pass does not because the dataloader automatically does this.
- The `map_batches` method uses Ray Data's metric aggregation methods, whereas the `TorchTrainer` method uses Torch's.

Either way, you need to report with the validate function as follows:
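A minimal sketch of the reporting call, assuming the `validate_fn`/`validate_config` parameter names this PR adds to `report`; `my_validate_fn`, the metric values, and the checkpointing details are illustrative:

```python
import tempfile

import ray.train


def my_validate_fn(checkpoint, validate_config):
    # Illustrative stand-in: load the model from `checkpoint`, score it
    # on the test dataset, and return the validation metrics.
    return {"score": 0.0}


def train_func(config):
    for epoch in range(config["num_epochs"]):
        # ... train one epoch, compute `loss`, save state into tmpdir ...
        with tempfile.TemporaryDirectory() as tmpdir:
            ray.train.report(
                metrics={"loss": 0.0, "epoch": epoch},
                checkpoint=ray.train.Checkpoint.from_directory(tmpdir),
                validate_fn=my_validate_fn,
                validate_config={"batch_size": 256},
            )
```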
Note that both methods pass the `test_dataset` directly as a global variable - let me know if there is a better way to do this.
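A sketch of the `map_batches` flavor under these constraints; the model loading, the `test_dataset` global, and the accuracy metric are assumptions, not the exact code from this PR:

```python
import os

import numpy as np
import torch
import ray

# Assumed to exist at module scope, per the note above (hypothetical path).
test_dataset = ray.data.read_parquet("s3://my-bucket/test")


class ValidationPredictor:
    """Stateful map_batches UDF that scores one checkpoint."""

    def __init__(self, checkpoint):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        with checkpoint.as_directory() as ckpt_dir:
            self.model = torch.load(
                os.path.join(ckpt_dir, "model.pt"), map_location=self.device
            )
        self.model.eval()

    def __call__(self, batch):
        # Unlike the TorchTrainer flavor, __call__ must move the batch
        # to the device itself.
        images = torch.as_tensor(batch["image"], device=self.device)
        with torch.no_grad():
            preds = self.model(images).argmax(dim=1).cpu().numpy()
        batch["correct"] = (preds == batch["label"]).astype(np.float32)
        return batch


def validate_fn(checkpoint, validate_config):
    scored = test_dataset.map_batches(
        ValidationPredictor,
        fn_constructor_args=(checkpoint,),
        concurrency=1,
        num_gpus=1,
    )
    # Ray Data aggregation, as opposed to Torch-side metric computation.
    return {"score": scored.mean("correct")}
```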
Testing
I tried both APIs above in an Anyscale workspace for 2 epochs:
The `map_batches` API:
- `result.best_checkpoints` at the end looks like `[(Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-34-05.089176), {'loss': 2.0584681034088135, 'epoch': 0, 'score': 0.2369}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/ray_train_run-2025-09-15_18-31-44/checkpoint_2025-09-15_18-36-07.816422), {'loss': 1.8418619632720947, 'epoch': 1, 'score': 0.3152})]`
- (time `trainer.fit` returned) - (time `train_func` exited) was 23.8961102962s, which is equal to (time taken to set up the last validation) + (time taken to perform the last validation)
The `TorchTrainer` API:
- `result.best_checkpoints` at the end looks like `[(Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-34-59.088135), {'loss': 1.7960023880004883, 'epoch': 0, 'score': 1.8543205261230469}), (Checkpoint(filesystem=local, path=/mnt/cluster_storage/06325ef3-fc29-48eb-bc21-f3fca71da238/checkpoint_2025-09-16_13-37-05.083466), {'loss': 1.615540862083435, 'epoch': 1, 'score': 1.6425774097442627})]`
I also tested the `TorchTrainer` API with autoscaling enabled. The results are basically the same as above, but interestingly: …
Note
Introduce async checkpoint validation via `validate_fn`/`validate_config` in `report()`, add `ValidationManager`, and refactor reporting to a `TrainingReport` with pending-checkpoint handling.
- `ray.train.v2.api.train_fn_utils.report` now accepts `validate_fn` and `validate_config` to run async checkpoint validation.
- `ValidationManager` runs validation tasks, polls results, and updates checkpoint metrics.
- Wires `ValidationManager` into `TrainController` and `ReportCallbackHandler`.
- Adds `_TrainingReport` and `_ValidationSpec` to carry checkpoint, metrics, and validation spec end-to-end.
- Changes the `ReportCallback.after_report` signature to receive `training_report` plus per-worker metrics.
- `WorkerStatus` now holds `training_report` (renamed from `training_result`).
- `CheckpointManager` supports pending checkpoints and updates metrics post-validation; tie-breaks equal scores by report index; persists state before deletions.
- Updates `_insert_into_sorted_list` to accept typed items and an optional tie-break map.
- `UserCallbackHandler.after_report` now sources the checkpoint from `training_report`.
- Adds tests (`test_async_checkpointing_validation*`, `test_validation_manager`).

Written by Cursor Bugbot for commit 6e1ddbc.
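To illustrate the tie-break behavior called out above, a minimal sketch of score-sorted insertion with a report-index tie-break map; the names and the ascending-score convention are assumptions, not the actual `_insert_into_sorted_list` code:

```python
import bisect
from typing import Dict, List


def insert_into_sorted_list(
    items: List[str],
    new_item: str,
    scores: Dict[str, float],
    report_index: Dict[str, int],
) -> None:
    """Insert `new_item` keeping `items` sorted by (score, report index).

    Equal scores fall back to report order, so of two checkpoints with
    the same score, the one reported earlier ranks first.
    """
    keys = [(scores[it], report_index[it]) for it in items]
    key = (scores[new_item], report_index[new_item])
    items.insert(bisect.bisect_left(keys, key), new_item)


# Example: ckpt_a and ckpt_b tie on score and are ordered by report index.
scores = {"ckpt_a": 0.5, "ckpt_b": 0.5, "ckpt_c": 0.7}
report_index = {"ckpt_a": 0, "ckpt_b": 1, "ckpt_c": 2}
ranked: List[str] = []
for ckpt in ["ckpt_c", "ckpt_b", "ckpt_a"]:
    insert_into_sorted_list(ranked, ckpt, scores, report_index)
assert ranked == ["ckpt_a", "ckpt_b", "ckpt_c"]
```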