11import copy
22import itertools
3- import pathlib
43from functools import partial
54from typing import (
65 TYPE_CHECKING ,
@@ -732,63 +731,75 @@ def update_from_episodes(self, episodes, **kwargs):
732731 def async_update (self , * args , ** kwargs ):
733732 pass
734733
735- @Deprecated (new = "LearnerGroup.load_from_path(path=..., component=...)" , error = False )
734+ @Deprecated (
735+ old = "LearnerGroup.load_module_state()" ,
736+ help = "To restore RLModule or MultiRLModule state "
737+ "use LearnerGroup.restore_from_path(path=..., component=...). "
738+ "See docs for more details: "
739+ "https://docs.ray.io/en/latest/rllib/rl-modules.html#checkpointing-rlmodules" ,
740+ error = False ,
741+ )
736742 def load_module_state (
737743 self ,
738744 * ,
739745 multi_rl_module_ckpt_dir : Optional [str ] = None ,
740746 modules_to_load : Optional [Set [str ]] = None ,
741747 rl_module_ckpt_dirs : Optional [Dict [ModuleID , str ]] = None ,
742748 ) -> None :
743- """Load the checkpoints of the modules being trained by this LearnerGroup.
749+ """Load the checkpoints of the modules being trained by ` LearnerGroup` .
744750
745751 `load_module_state` can be used 3 ways:
746- 1. Load a checkpoint for the MultiRLModule being trained by this
747- LearnerGroup. Limit the modules that are loaded from the checkpoint
748- by specifying the `modules_to_load` argument.
749- 2. Load the checkpoint(s) for single agent RLModules that
750- are in the MultiRLModule being trained by this LearnerGroup.
751- 3. Load a checkpoint for the MultiRLModule being trained by this
752- LearnerGroup and load the checkpoint(s) for single agent RLModules
753- that are in the MultiRLModule. The checkpoints for the single
754- agent RLModules take precedence over the module states in the
755- MultiRLModule checkpoint.
756-
757- NOTE: At lease one of multi_rl_module_ckpt_dir or rl_module_ckpt_dirs is
758- must be specified. modules_to_load can only be specified if
759- multi_rl_module_ckpt_dir is specified.
752+ 1. Load a checkpoint for the `MultiRLModule` being trained by this
753+ `LearnerGroup`. Optionally, limit the modules that are loaded
754+ from the checkpoint by specifying the `modules_to_load` argument.
755+ 2. Load the checkpoint(s) for single agent `RLModules` that
756+ are in the `MultiRLModule` being trained by this `LearnerGroup`.
757+ 3. Load a checkpoint for the `MultiRLModule` being trained by this
758+ `LearnerGroup` and load the checkpoint(s) for single agent `RLModules`
759+ that are in the `MultiRLModule`. The checkpoints for the single
760+ agent `RLModules` take precedence over the module states in the
761+ `MultiRLModule` checkpoint.
762+
763+ At least one of `multi_rl_module_ckpt_dir` or `rl_module_ckpt_dirs`
764+ must be specified.
765+ `modules_to_load` can only be specified if `multi_rl_module_ckpt_dir`
766+ is provided.
760767
761768 Args:
762769 multi_rl_module_ckpt_dir: The path to the checkpoint for the
763- MultiRLModule.
764- modules_to_load: A set of module ids to load from the checkpoint.
770+ ` MultiRLModule` .
771+ modules_to_load: A set of `RLModule` ids to load from the checkpoint.
765772 rl_module_ckpt_dirs: A mapping from module ids to the path to a
766- checkpoint for a single agent RLModule.
773+ checkpoint for a single agent ` RLModule` .
767774 """
768775 if not (multi_rl_module_ckpt_dir or rl_module_ckpt_dirs ):
769776 raise ValueError (
770- "At least one of `multi_rl_module_ckpt_dir` or "
771- "`rl_module_ckpt_dirs` must be provided!"
777+ f"At least one of `multi_rl_module_ckpt_dir` or "
778+ f"`rl_module_ckpt_dirs` must be provided. "
779+ f"Got { multi_rl_module_ckpt_dir = } and { rl_module_ckpt_dirs = } ."
780+ )
781+
782+ if modules_to_load and not multi_rl_module_ckpt_dir :
783+ raise ValueError (
784+ f"`modules_to_load` can only be specified if a "
785+ f"multi_rl_module_ckpt_dir is provided. "
786+ f"Got { modules_to_load = } and { multi_rl_module_ckpt_dir = } ."
772787 )
773- if multi_rl_module_ckpt_dir :
774- multi_rl_module_ckpt_dir = pathlib .Path (multi_rl_module_ckpt_dir )
775- if rl_module_ckpt_dirs :
776- for module_id , path in rl_module_ckpt_dirs .items ():
777- rl_module_ckpt_dirs [module_id ] = pathlib .Path (path )
778788
779789 # MultiRLModule checkpoint is provided.
780790 if multi_rl_module_ckpt_dir :
781791 # Restore the entire MultiRLModule state.
782792 if modules_to_load is None :
783793 self .restore_from_path (
784- multi_rl_module_ckpt_dir ,
794+ path = multi_rl_module_ckpt_dir ,
785795 component = COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE ,
786- )
796+ ),
787797 # Restore individual module IDs.
788798 else :
789799 for module_id in modules_to_load :
800+ path = multi_rl_module_ckpt_dir + "/" + module_id
790801 self .restore_from_path (
791- multi_rl_module_ckpt_dir / module_id ,
802+ path = path ,
792803 component = (
793804 COMPONENT_LEARNER
794805 + "/"
@@ -800,7 +811,7 @@ def load_module_state(
800811 if rl_module_ckpt_dirs :
801812 for module_id , path in rl_module_ckpt_dirs .items ():
802813 self .restore_from_path (
803- path ,
814+ path = path ,
804815 component = (
805816 COMPONENT_LEARNER + "/" + COMPONENT_RL_MODULE + "/" + module_id
806817 ),
0 commit comments