Skip to content

Commit a50b5b3

Browse files
kamil-kaczmarekpeterxcli
authored andcommitted
[RLlib] Fix LearnerGroup.load_module_state() and simultaneously mark as deprecated. (ray-project#60354)
## Description * allow to pass a path with cloud filesystem (for example `gcs` or `s3`) to the `LearnerGroup.load_module_state()`. * mark `LearnerGroup.load_module_state()` as Deprecated. Users should use `Algorithm.restore_from_path(path=..., component=...)` * mark `load_state_path` field in the `RLModuleSpec` dataclass as Deprecated. Direct users to use `Algorithm.restore_from_path(path=..., component=...)`. * add unit tests for `LearnerGroup.load_module_state()` --------- Signed-off-by: Kamil Kaczmarek <kamil@anyscale.com> Signed-off-by: peterxcli <peterxcli@gmail.com>
1 parent 397bc36 commit a50b5b3

File tree

7 files changed

+57
-35
lines changed

7 files changed

+57
-35
lines changed

rllib/algorithms/algorithm_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3524,7 +3524,7 @@ def multi_agent(
35243524
policy_map_capacity: Keep this many policies in the "policy_map" (before
35253525
writing least-recently used ones to disk/S3).
35263526
policy_mapping_fn: Function mapping agent ids to policy ids. The signature
3527-
is: `(agent_id, episode, worker, **kwargs) -> PolicyID`.
3527+
is: `(agent_id, episode, **kwargs) -> PolicyID`.
35283528
policies_to_train: Determines those policies that should be updated.
35293529
Options are:
35303530
- None, for training all policies.

rllib/core/learner/learner_group.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import copy
22
import itertools
3-
import pathlib
43
from functools import partial
54
from typing import (
65
TYPE_CHECKING,
@@ -732,63 +731,75 @@ def update_from_episodes(self, episodes, **kwargs):
732731
def async_update(self, *args, **kwargs):
733732
pass
734733

735-
@Deprecated(new="LearnerGroup.load_from_path(path=..., component=...)", error=False)
734+
@Deprecated(
735+
old="LearnerGroup.load_module_state()",
736+
help="To restore RLModule or MultiRLModule state "
737+
"use LearnerGroup.restore_from_path(path=..., component=...). "
738+
"See docs for more details: "
739+
"https://docs.ray.io/en/latest/rllib/rl-modules.html#checkpointing-rlmodules",
740+
error=False,
741+
)
736742
def load_module_state(
737743
self,
738744
*,
739745
multi_rl_module_ckpt_dir: Optional[str] = None,
740746
modules_to_load: Optional[Set[str]] = None,
741747
rl_module_ckpt_dirs: Optional[Dict[ModuleID, str]] = None,
742748
) -> None:
743-
"""Load the checkpoints of the modules being trained by this LearnerGroup.
749+
"""Load the checkpoints of the modules being trained by `LearnerGroup`.
744750
745751
`load_module_state` can be used 3 ways:
746-
1. Load a checkpoint for the MultiRLModule being trained by this
747-
LearnerGroup. Limit the modules that are loaded from the checkpoint
748-
by specifying the `modules_to_load` argument.
749-
2. Load the checkpoint(s) for single agent RLModules that
750-
are in the MultiRLModule being trained by this LearnerGroup.
751-
3. Load a checkpoint for the MultiRLModule being trained by this
752-
LearnerGroup and load the checkpoint(s) for single agent RLModules
753-
that are in the MultiRLModule. The checkpoints for the single
754-
agent RLModules take precedence over the module states in the
755-
MultiRLModule checkpoint.
756-
757-
NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is
758-
must be specified. modules_to_load can only be specified if
759-
multi_rl_module_ckpt_dir is specified.
752+
1. Load a checkpoint for the `MultiRLModule` being trained by this
753+
`LearnerGroup`. Optionally, limit the modules that are loaded
754+
from the checkpoint by specifying the `modules_to_load` argument.
755+
2. Load the checkpoint(s) for single agent `RLModules` that
756+
are in the `MultiRLModule` being trained by this `LearnerGroup`.
757+
3. Load a checkpoint for the `MultiRLModule` being trained by this
758+
`LearnerGroup` and load the checkpoint(s) for single agent `RLModules`
759+
that are in the `MultiRLModule`. The checkpoints for the single
760+
agent `RLModules` take precedence over the module states in the
761+
`MultiRLModule` checkpoint.
762+
763+
At least one of `multi_rl_module_ckpt_dir` or `rl_module_ckpt_dirs`
764+
must be specified.
765+
`modules_to_load` can only be specified if `multi_rl_module_ckpt_dir`
766+
is provided.
760767
761768
Args:
762769
multi_rl_module_ckpt_dir: The path to the checkpoint for the
763-
MultiRLModule.
764-
modules_to_load: A set of module ids to load from the checkpoint.
770+
`MultiRLModule`.
771+
modules_to_load: A set of `RLModule` ids to load from the checkpoint.
765772
rl_module_ckpt_dirs: A mapping from module ids to the path to a
766-
checkpoint for a single agent RLModule.
773+
checkpoint for a single agent `RLModule`.
767774
"""
768775
if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs):
769776
raise ValueError(
770-
"At least one of `multi_rl_module_ckpt_dir` or "
771-
"`rl_module_ckpt_dirs` must be provided!"
777+
f"At least one of `multi_rl_module_ckpt_dir` or "
778+
f"`rl_module_ckpt_dirs` must be provided. "
779+
f"Got {multi_rl_module_ckpt_dir=} and {rl_module_ckpt_dirs=}."
780+
)
781+
782+
if modules_to_load and not multi_rl_module_ckpt_dir:
783+
raise ValueError(
784+
f"`modules_to_load` can only be specified if a "
785+
f"multi_rl_module_ckpt_dir is provided. "
786+
f"Got {modules_to_load=} and {multi_rl_module_ckpt_dir=}."
772787
)
773-
if multi_rl_module_ckpt_dir:
774-
multi_rl_module_ckpt_dir = pathlib.Path(multi_rl_module_ckpt_dir)
775-
if rl_module_ckpt_dirs:
776-
for module_id, path in rl_module_ckpt_dirs.items():
777-
rl_module_ckpt_dirs[module_id] = pathlib.Path(path)
778788

779789
# MultiRLModule checkpoint is provided.
780790
if multi_rl_module_ckpt_dir:
781791
# Restore the entire MultiRLModule state.
782792
if modules_to_load is None:
783793
self.restore_from_path(
784-
multi_rl_module_ckpt_dir,
794+
path=multi_rl_module_ckpt_dir,
785795
component=COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE,
786-
)
796+
),
787797
# Restore individual module IDs.
788798
else:
789799
for module_id in modules_to_load:
800+
path = multi_rl_module_ckpt_dir + "/" + module_id
790801
self.restore_from_path(
791-
multi_rl_module_ckpt_dir / module_id,
802+
path=path,
792803
component=(
793804
COMPONENT_LEARNER
794805
+ "/"
@@ -800,7 +811,7 @@ def load_module_state(
800811
if rl_module_ckpt_dirs:
801812
for module_id, path in rl_module_ckpt_dirs.items():
802813
self.restore_from_path(
803-
path,
814+
path=path,
804815
component=(
805816
COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/" + module_id
806817
),

rllib/core/rl_module/rl_module.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,12 @@ class RLModuleSpec:
6363
Note that `inference_only=True` AND `learner_only=True` is not allowed.
6464
model_config: The model config dict or default RLlib dataclass to use.
6565
catalog_class: The Catalog class to use.
66-
load_state_path: The path to the module state to load from. NOTE: This must be
67-
an absolute path.
66+
load_state_path: The path to the RLModule state to load from.
67+
Deprecated. This field will be removed in the future Ray release.
68+
To restore RLModule state use
69+
`Algorithm.restore_from_path(path=..., component=...)` instead.
70+
See docs for more details: :
71+
https://docs.ray.io/en/latest/rllib/rl-modules.html#checkpointing-rlmodules
6872
"""
6973

7074
module_class: Optional[Type["RLModule"]] = None

rllib/examples/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def run_rllib_example_script_experiment(
539539
else 1
540540
) * num_actual_learners
541541
# Define compute resources used.
542-
config.resources(num_gpus=0) # old API stack setting
542+
config.resources(num_gpus=0) # @OldAPIStack
543543
if args.num_learners is not None:
544544
config.learners(num_learners=args.num_learners)
545545

rllib/models/torch/mingpt.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# @OldAPIStack
2+
13
# LICENSE: MIT
24
"""
35
Adapted from https://github.com/karpathy/minGPT

rllib/offline/tests/test_offline_prelearner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,9 @@ def setUpClass(cls):
3636
def tearDownClass(cls):
3737
ray.shutdown()
3838

39+
# Delete the cluster address just in case.
40+
ray._common.utils.reset_ray_address()
41+
3942
def setUp(self) -> None:
4043
data_path = "offline/tests/data/cartpole/cartpole-v1_large"
4144
self.base_path = Path(__file__).parents[2]

rllib/utils/tests/run_memory_leak_tests.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# @OldAPIStack
2+
13
#!/usr/bin/env python
24
# Runs one or more memory leak tests.
35
#

0 commit comments

Comments
 (0)