@@ -705,9 +705,13 @@ def freeze_moe_router(megatron_model):
705705 if use_peft :
706706 peft_cfg = policy_cfg ["megatron_cfg" ].get ("peft" , {})
707707 if "dim" not in peft_cfg or peft_cfg ["dim" ] is None :
708- raise ValueError ("If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg" )
708+ raise ValueError (
709+ "If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
710+ )
709711 if "alpha" not in peft_cfg or peft_cfg ["alpha" ] is None :
710- raise ValueError ("If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg" )
712+ raise ValueError (
713+ "If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
714+ )
711715 peft = LoRA (
712716 target_modules = peft_cfg .get ("target_modules" , []),
713717 exclude_modules = peft_cfg .get ("exclude_modules" , []),
@@ -875,13 +879,17 @@ def setup_reference_model_state(
875879
876880 ref_pre_wrap_hooks = []
877881 use_peft = config ["megatron_cfg" ].get ("peft" , {}).get ("enabled" , False )
878-
882+
879883 if use_peft :
880884 peft_cfg = config ["megatron_cfg" ].get ("peft" , {})
881885 if "dim" not in peft_cfg or peft_cfg ["dim" ] is None :
882- raise ValueError ("If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg" )
886+ raise ValueError (
887+ "If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
888+ )
883889 if "alpha" not in peft_cfg or peft_cfg ["alpha" ] is None :
884- raise ValueError ("If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg" )
890+ raise ValueError (
891+ "If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
892+ )
885893 peft = LoRA (
886894 target_modules = peft_cfg .get ("target_modules" , []),
887895 exclude_modules = peft_cfg .get ("exclude_modules" , []),
@@ -931,7 +939,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
931939 ref_megatron_cfg .checkpoint .finetune = False
932940
933941 print ("Loading the Reference Model" )
934-
942+
935943 if should_load_checkpoint :
936944 load_checkpoint (
937945 ref_state ,
0 commit comments