Skip to content

Commit 9c7e4e2

Browse files
rootvadam5
authored andcommitted
Ran linter fixes
1 parent c799cdf commit 9c7e4e2

4 files changed

Lines changed: 17 additions & 9 deletions

File tree

examples/nemo_gym/run_grpo_nemo_gym.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323

2424
import ray
2525
from omegaconf import OmegaConf
26-
from wandb import Table
2726

2827
from nemo_rl.algorithms.grpo import (
2928
ColocatablePolicyInterface,
@@ -54,6 +53,7 @@
5453
register_omegaconf_resolvers,
5554
)
5655
from nemo_rl.utils.logger import get_next_experiment_dir
56+
from wandb import Table
5757

5858

5959
def parse_args() -> tuple[argparse.Namespace, list[str]]:

nemo_rl/experience/rollouts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
import ray
2828
import torch
2929
from transformers import PreTrainedTokenizerBase
30-
from wandb import Histogram, Table
3130

3231
from nemo_rl.data.interfaces import (
3332
DatumSpec,
@@ -50,6 +49,7 @@
5049
GenerationOutputSpec,
5150
)
5251
from nemo_rl.utils.timer import Timer
52+
from wandb import Histogram, Table
5353

5454
TokenizerType = PreTrainedTokenizerBase
5555

nemo_rl/models/megatron/setup.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -705,9 +705,13 @@ def freeze_moe_router(megatron_model):
705705
if use_peft:
706706
peft_cfg = policy_cfg["megatron_cfg"].get("peft", {})
707707
if "dim" not in peft_cfg or peft_cfg["dim"] is None:
708-
raise ValueError("If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg")
708+
raise ValueError(
709+
"If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
710+
)
709711
if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
710-
raise ValueError("If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg")
712+
raise ValueError(
713+
"If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
714+
)
711715
peft = LoRA(
712716
target_modules=peft_cfg.get("target_modules", []),
713717
exclude_modules=peft_cfg.get("exclude_modules", []),
@@ -875,13 +879,17 @@ def setup_reference_model_state(
875879

876880
ref_pre_wrap_hooks = []
877881
use_peft = config["megatron_cfg"].get("peft", {}).get("enabled", False)
878-
882+
879883
if use_peft:
880884
peft_cfg = config["megatron_cfg"].get("peft", {})
881885
if "dim" not in peft_cfg or peft_cfg["dim"] is None:
882-
raise ValueError("If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg")
886+
raise ValueError(
887+
"If megtatron_cfg.peft.enabled is True, dim must be set in peft_cfg"
888+
)
883889
if "alpha" not in peft_cfg or peft_cfg["alpha"] is None:
884-
raise ValueError("If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg")
890+
raise ValueError(
891+
"If megtatron_cfg.peft.enabled is True, alpha must be set in peft_cfg"
892+
)
885893
peft = LoRA(
886894
target_modules=peft_cfg.get("target_modules", []),
887895
exclude_modules=peft_cfg.get("exclude_modules", []),
@@ -931,7 +939,7 @@ def composed_peft_hook(model: list[MegatronModule]) -> list[MegatronModule]:
931939
ref_megatron_cfg.checkpoint.finetune = False
932940

933941
print("Loading the Reference Model")
934-
942+
935943
if should_load_checkpoint:
936944
load_checkpoint(
937945
ref_state,

nemo_rl/utils/logger.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
import requests
3131
import swanlab
3232
import torch
33-
import wandb
3433
from matplotlib import pyplot as plt
3534
from prometheus_client.parser import text_string_to_metric_families
3635
from prometheus_client.samples import Sample
@@ -40,6 +39,7 @@
4039
from rich.panel import Panel
4140
from torch.utils.tensorboard import SummaryWriter
4241

42+
import wandb
4343
from nemo_rl.data.interfaces import LLMMessageLogType
4444
from nemo_rl.distributed.batched_data_dict import BatchedDataDict
4545

0 commit comments

Comments
 (0)