1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- from typing import NotRequired , TypedDict
15+ from typing import Literal , NotRequired , TypedDict
1616
1717
1818# TODO: split this typed dict up so it can be PreferenceDataConfig | ResponseDataConfig | etc
@@ -31,7 +31,7 @@ class DataConfig(TypedDict):
3131 add_generation_prompt : NotRequired [bool ]
3232 add_system_prompt : NotRequired [bool ]
3333 split : NotRequired [str | None ]
34- shuffle : NotRequired [ bool ]
34+ shuffle : bool
3535 seed : NotRequired [int | None ]
3636 download_dir : NotRequired [str ]
3737 train_data_path : NotRequired [str ]
@@ -43,9 +43,110 @@ class DataConfig(TypedDict):
4343 num_workers : NotRequired [int ]
4444
4545
46- # TODO: split this typed dict up so it can be MMLUConfig | AIMEConfig | etc
47- # so that we can type check the configs more rigorously as opposed to saying everything
48- # is not required.
49- class MathDataConfig (DataConfig ):
50- problem_key : NotRequired [str ]
51- solution_key : NotRequired [str ]
46+ # ===============================================================================
47+ # Eval Dataset Configs
48+ # ===============================================================================
49+ # These configs correspond to the eval datasets in data/datasets/eval_datasets/
50+ # Note: TypedDict doesn't allow narrowing types in child classes, so each config
51+ # is defined independently with common fields repeated.
52+
53+
54+ class MMLUEvalDataConfig (TypedDict ):
55+ """Config for MMLU and multilingual MMLU datasets.
56+
57+ Supports dataset_name: "mmlu" or "mmlu_{language}" where language is one of:
58+ AR-XY, BN-BD, DE-DE, EN-US, ES-LA, FR-FR, HI-IN, ID-ID, IT-IT, JA-JP,
59+ KO-KR, PT-BR, ZH-CN, SW-KE, YO-NG
60+ """
61+
62+ max_input_seq_length : int
63+ dataset_name : Literal [
64+ "mmlu" ,
65+ "mmlu_AR-XY" ,
66+ "mmlu_BN-BD" ,
67+ "mmlu_DE-DE" ,
68+ "mmlu_EN-US" ,
69+ "mmlu_ES-LA" ,
70+ "mmlu_FR-FR" ,
71+ "mmlu_HI-IN" ,
72+ "mmlu_ID-ID" ,
73+ "mmlu_IT-IT" ,
74+ "mmlu_JA-JP" ,
75+ "mmlu_KO-KR" ,
76+ "mmlu_PT-BR" ,
77+ "mmlu_ZH-CN" ,
78+ "mmlu_SW-KE" ,
79+ "mmlu_YO-NG" ,
80+ ]
81+ shuffle : NotRequired [bool ]
82+ prompt_file : NotRequired [str | None ]
83+ system_prompt_file : NotRequired [str | None ]
84+
85+
86+ class MMLUProEvalDataConfig (TypedDict ):
87+ """Config for MMLU Pro dataset."""
88+
89+ max_input_seq_length : int
90+ dataset_name : Literal ["mmlu_pro" ]
91+ shuffle : NotRequired [bool ]
92+ prompt_file : NotRequired [str | None ]
93+ system_prompt_file : NotRequired [str | None ]
94+
95+
96+ class AIMEEvalDataConfig (TypedDict ):
97+ """Config for AIME datasets."""
98+
99+ max_input_seq_length : int
100+ dataset_name : Literal ["aime2024" , "aime2025" ]
101+ shuffle : NotRequired [bool ]
102+ prompt_file : NotRequired [str | None ]
103+ system_prompt_file : NotRequired [str | None ]
104+
105+
106+ class GPQAEvalDataConfig (TypedDict ):
107+ """Config for GPQA datasets."""
108+
109+ max_input_seq_length : int
110+ dataset_name : Literal ["gpqa" , "gpqa_diamond" ]
111+ shuffle : NotRequired [bool ]
112+ prompt_file : NotRequired [str | None ]
113+ system_prompt_file : NotRequired [str | None ]
114+
115+
116+ class MathEvalDataConfig (TypedDict ):
117+ """Config for Math datasets."""
118+
119+ max_input_seq_length : int
120+ dataset_name : Literal ["math" , "math500" ]
121+ shuffle : NotRequired [bool ]
122+ prompt_file : NotRequired [str | None ]
123+ system_prompt_file : NotRequired [str | None ]
124+
125+
126+ class LocalMathEvalDataConfig (TypedDict ):
127+ """Config for local math datasets loaded from files.
128+
129+ dataset_name can be a URL or local file path.
130+ Requires additional fields: problem_key, solution_key, file_format, split.
131+ """
132+
133+ max_input_seq_length : int
134+ dataset_name : str # URL or file path
135+ problem_key : str
136+ solution_key : str
137+ file_format : Literal ["csv" , "json" ]
138+ split : NotRequired [str | None ]
139+ shuffle : NotRequired [bool ]
140+ prompt_file : NotRequired [str | None ]
141+ system_prompt_file : NotRequired [str | None ]
142+
143+
144+ # Union type for all eval dataset configs
145+ EvalDataConfigType = (
146+ MMLUEvalDataConfig
147+ | MMLUProEvalDataConfig
148+ | AIMEEvalDataConfig
149+ | GPQAEvalDataConfig
150+ | MathEvalDataConfig
151+ | LocalMathEvalDataConfig
152+ )
0 commit comments