Skip to content

Commit 723c4b3

Browse files
committed
fix eval configs
Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent e30d658 commit 723c4b3

2 files changed

Lines changed: 111 additions & 10 deletions

File tree

nemo_rl/data/__init__.py

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from typing import NotRequired, TypedDict
15+
from typing import Literal, NotRequired, TypedDict
1616

1717

1818
# TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc
@@ -31,7 +31,7 @@ class DataConfig(TypedDict):
3131
add_generation_prompt: NotRequired[bool]
3232
add_system_prompt: NotRequired[bool]
3333
split: NotRequired[str | None]
34-
shuffle: NotRequired[bool]
34+
shuffle: bool
3535
seed: NotRequired[int | None]
3636
download_dir: NotRequired[str]
3737
train_data_path: NotRequired[str]
@@ -43,9 +43,110 @@ class DataConfig(TypedDict):
4343
num_workers: NotRequired[int]
4444

4545

46-
# TODO: split this typed dict up so it can be MMLUConfig | AIMEConfig | etc
47-
# so that we can type check the configs more rigorously as opposed to saying everything
48-
# is not required.
49-
class MathDataConfig(DataConfig):
50-
problem_key: NotRequired[str]
51-
solution_key: NotRequired[str]
46+
# ===============================================================================
47+
# Eval Dataset Configs
48+
# ===============================================================================
49+
# These configs correspond to the eval datasets in data/datasets/eval_datasets/
50+
# Note: TypedDict doesn't allow narrowing types in child classes, so each config
51+
# is defined independently with common fields repeated.
52+
53+
54+
class MMLUEvalDataConfig(TypedDict):
55+
"""Config for MMLU and multilingual MMLU datasets.
56+
57+
Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of:
58+
AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP,
59+
KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG
60+
"""
61+
62+
max_input_seq_length: int
63+
dataset_name: Literal[
64+
"mmlu",
65+
"mmlu_AR-XY",
66+
"mmlu_BN-BD",
67+
"mmlu_DE-DE",
68+
"mmlu_EN-US",
69+
"mmlu_ES-LA",
70+
"mmlu_FR-FR",
71+
"mmlu_HI-IN",
72+
"mmlu_ID-ID",
73+
"mmlu_IT-IT",
74+
"mmlu_JA-JP",
75+
"mmlu_KO-KR",
76+
"mmlu_PT-BR",
77+
"mmlu_ZH-CN",
78+
"mmlu_SW-KE",
79+
"mmlu_YO-NG",
80+
]
81+
shuffle: NotRequired[bool]
82+
prompt_file: NotRequired[str | None]
83+
system_prompt_file: NotRequired[str | None]
84+
85+
86+
class MMLUProEvalDataConfig(TypedDict):
87+
"""Config for MMLU Pro dataset."""
88+
89+
max_input_seq_length: int
90+
dataset_name: Literal["mmlu_pro"]
91+
shuffle: NotRequired[bool]
92+
prompt_file: NotRequired[str | None]
93+
system_prompt_file: NotRequired[str | None]
94+
95+
96+
class AIMEEvalDataConfig(TypedDict):
97+
"""Config for AIME datasets."""
98+
99+
max_input_seq_length: int
100+
dataset_name: Literal["aime2024", "aime2025"]
101+
shuffle: NotRequired[bool]
102+
prompt_file: NotRequired[str | None]
103+
system_prompt_file: NotRequired[str | None]
104+
105+
106+
class GPQAEvalDataConfig(TypedDict):
107+
"""Config for GPQA datasets."""
108+
109+
max_input_seq_length: int
110+
dataset_name: Literal["gpqa", "gpqa_diamond"]
111+
shuffle: NotRequired[bool]
112+
prompt_file: NotRequired[str | None]
113+
system_prompt_file: NotRequired[str | None]
114+
115+
116+
class MathEvalDataConfig(TypedDict):
117+
"""Config for Math datasets."""
118+
119+
max_input_seq_length: int
120+
dataset_name: Literal["math", "math500"]
121+
shuffle: NotRequired[bool]
122+
prompt_file: NotRequired[str | None]
123+
system_prompt_file: NotRequired[str | None]
124+
125+
126+
class LocalMathEvalDataConfig(TypedDict):
127+
"""Config for local math datasets loaded from files.
128+
129+
dataset_name can be a URL or local file path.
130+
Requires additional fields: problem_key, solution_key, file_format, split.
131+
"""
132+
133+
max_input_seq_length: int
134+
dataset_name: str # URL or file path
135+
problem_key: str
136+
solution_key: str
137+
file_format: Literal["csv", "json"]
138+
split: NotRequired[str | None]
139+
shuffle: NotRequired[bool]
140+
prompt_file: NotRequired[str | None]
141+
system_prompt_file: NotRequired[str | None]
142+
143+
144+
# Union type for all eval dataset configs
145+
EvalDataConfigType = (
146+
MMLUEvalDataConfig
147+
| MMLUProEvalDataConfig
148+
| AIMEEvalDataConfig
149+
| GPQAEvalDataConfig
150+
| MathEvalDataConfig
151+
| LocalMathEvalDataConfig
152+
)

nemo_rl/evals/eval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
from transformers import AutoTokenizer
2626

2727
from nemo_rl.algorithms.utils import set_seed
28-
from nemo_rl.data import MathDataConfig
28+
from nemo_rl.data import EvalDataConfigType
2929
from nemo_rl.data.collate_fn import eval_collate_fn
3030
from nemo_rl.data.datasets import AllTaskProcessedDataset
3131
from nemo_rl.data.llm_message_utils import get_keys_from_message_log
@@ -58,7 +58,7 @@ class MasterConfig(TypedDict):
5858
eval: EvalConfig
5959
generation: GenerationConfig # Fixed: was 'generate'
6060
tokenizer: TokenizerConfig # Added missing tokenizer key
61-
data: MathDataConfig
61+
data: EvalDataConfigType
6262
env: _PassThroughMathConfig
6363
cluster: ClusterConfig
6464

0 commit comments

Comments
 (0)