44import logging
55from abc import ABC , abstractmethod
66from collections .abc import Callable
7- from datetime import timedelta
87from pathlib import Path
98
109from access .config import YAMLParser
10+ from access .config .accessom3_layout_input import OM3ConfigLayout , OM3LayoutSearchConfig
1111from access .config .esm1p6_layout_input import LayoutSearchConfig
1212from access .config .layout_config import LayoutTuple
1313from experiment_generator .experiment_generator import ExperimentGenerator
1919
2020logger = logging .getLogger (__name__ )
2121
22+ # Type aliases for better readability
23+ SearchConfig = LayoutSearchConfig | OM3LayoutSearchConfig
24+ Layout = LayoutTuple | OM3ConfigLayout
25+
2226
2327class PayuManager (ProfilingManager , ABC ):
2428 """Abstract base class to handle profiling of Payu configurations."""
@@ -47,27 +51,43 @@ def generate_core_layouts_from_node_count(
4751 self ,
4852 num_nodes : float ,
4953 cores_per_node : int ,
50- layout_search_config : LayoutSearchConfig | None = None ,
51- ) -> list :
54+ layout_search_config : SearchConfig | None = None ,
55+ ) -> list [ Layout ] :
5256 """Generates core layouts from the given number of nodes.
5357
5458 Args:
5559 num_nodes (float): Number of nodes.
5660 cores_per_node (int): Number of cores per node.
57- layout_search_config (LayoutSearchConfig | None): Configuration for layout search.
61+ layout_search_config (SearchConfig | None): Configuration for layout search.
5862 """
5963
6064 @abstractmethod
61- def generate_perturbation_block (self , layout : LayoutTuple , branch_name_prefix : str ) -> dict :
65+ def generate_perturbation_block (
66+ self ,
67+ layout : Layout ,
68+ num_nodes : float ,
69+ branch_name_prefix : str ,
70+ walltime_hrs : float ,
71+ layout_search_config : SearchConfig | None = None ,
72+ block_overrides : dict | None = None ,
73+ ) -> dict :
6274 """Generates a perturbation block for the given layout to be passed to the experiment generator.
6375
6476 Args:
65- layout (LayoutTuple): Core layout tuple.
77+ layout (Layout): Core layout (a LayoutTuple or an OM3ConfigLayout).
78+ num_nodes (float): Number of nodes.
6679 branch_name_prefix (str): Branch name prefix.
80+ walltime_hrs (float): Walltime in hours.
81+ layout_search_config (SearchConfig | None): Configuration for layout search.
82+ block_overrides (dict | None): Overrides for the perturbation block.
6783 Returns:
6884 dict: Perturbation block configuration.
6985 """
7086
87+ @abstractmethod
88+ def layout_key (self , layout : Layout ) -> tuple :
89+ """Returns a stable key for a layout so PayuManager can deduplicate layouts."""
90+
7191 @property
7292 def nruns (self ) -> int :
7393 """Returns the number of repetitions for the Payu experiments.
@@ -121,19 +141,21 @@ def generate_scaling_experiments(
121141 num_nodes_list : list [float ],
122142 control_options : dict ,
123143 cores_per_node : int ,
124- tol_around_ctrl_ratio : float ,
125- max_wasted_ncores_frac : float | Callable [[float ], float ],
144+ branch_name_prefix : str ,
126145 walltime : float | Callable [[float ], float ],
146+ layout_search_config_builder : Callable [[float ], SearchConfig ] | None = None ,
147+ block_overrides_builder : Callable [[Layout , float ], dict ] | None = None ,
127148 ) -> None :
128- """Generates scaling experiments using the ExperimentGenerator.
149+ """Generates scaling experiments using ExperimentGenerator.
129150
130151 Args:
131152 num_nodes_list (list[float]): List of node counts to generate experiments for.
132153 control_options (dict): Options for the control experiment.
133154 cores_per_node (int): Number of cores per node.
134- tol_around_ctrl_ratio (float): Tolerance around control core ratio for layout generation.
135- max_wasted_ncores_frac (float | Callable[[float], float]): Maximum fraction of wasted cores allowed.
155+ branch_name_prefix (str): Branch name prefix for the generated experiments.
136156 walltime (float | Callable[[float], float]): Walltime in hours for each experiment.
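157+ layout_search_config_builder (Callable[[float], SearchConfig] | None): Optional function mapping num_nodes to a layout search configuration. If omitted, None is passed as the layout search configuration.
158+ block_overrides_builder (Callable[[Layout, float], dict] | None): Optional function mapping (layout, num_nodes) to overrides for the perturbation block. If omitted, None is passed as the block overrides.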
137159 """
138160
139161 generator_config = {
@@ -144,37 +166,52 @@ def generate_scaling_experiments(
144166 "repository_directory" : self ._repository_directory ,
145167 "control_branch_name" : "ctrl" ,
146168 "Control_Experiment" : control_options ,
169+ "Perturbation_Experiment" : {},
147170 }
148171
149172 seen_layouts = set ()
150173 seqnum = 1
151- generator_config [ "Perturbation_Experiment" ] = {}
174+
152175 for num_nodes in num_nodes_list :
153- mwf = max_wasted_ncores_frac (num_nodes ) if callable (max_wasted_ncores_frac ) else max_wasted_ncores_frac
154- layout_config = LayoutSearchConfig (tol_around_ctrl_ratio = tol_around_ctrl_ratio , max_wasted_ncores_frac = mwf )
176+ layout_search_config = layout_search_config_builder (num_nodes ) if layout_search_config_builder else None
155177 layouts = self .generate_core_layouts_from_node_count (
156178 num_nodes ,
157179 cores_per_node = cores_per_node ,
158- layout_search_config = layout_config ,
180+ layout_search_config = layout_search_config ,
159181 )
182+
160183 if not layouts :
161184 logger .warning (f"No layouts found for { num_nodes } nodes" )
162185 continue
163186
164- layouts = [x for x in layouts if x not in seen_layouts ]
165- seen_layouts .update (layouts )
166- logger .info (f"Generated { len (layouts )} layouts for { num_nodes } nodes. Layouts: { layouts } " )
167-
168- # TODO: the branch name needs to be simpler and model agnostic
169- branch_name = f"layout-unused-cores-to-cice-{ layout_config .allocate_unused_cores_to_ice } "
170- walltime_hrs = walltime (num_nodes ) if callable (walltime ) else walltime
171-
187+ unique_layouts : list [Layout ] = []
172188 for layout in layouts :
173- pert_config = self .generate_perturbation_block (layout = layout , branch_name_prefix = branch_name )
174- branch = pert_config ["branches" ][0 ]
175- pert_config ["config.yaml" ]["walltime" ] = str (timedelta (hours = walltime_hrs ))
189+ key = self .layout_key (layout )
190+ if key in seen_layouts :
191+ continue
192+ seen_layouts .add (key )
193+ unique_layouts .append (layout )
194+
195+ logger .info (
196+ f"Generated { len (layouts )} layouts for { num_nodes } nodes "
197+ f"with { len (unique_layouts )} unique layouts: { unique_layouts } "
198+ )
176199
177- generator_config ["Perturbation_Experiment" ][f"Experiment_{ seqnum } " ] = pert_config
200+ for layout in unique_layouts :
201+ walltime_hrs = walltime (num_nodes ) if callable (walltime ) else walltime
202+ block_overrides = block_overrides_builder (layout , num_nodes ) if block_overrides_builder else None
203+ perturb_block = self .generate_perturbation_block (
204+ layout = layout ,
205+ num_nodes = num_nodes ,
206+ branch_name_prefix = branch_name_prefix ,
207+ walltime_hrs = walltime_hrs ,
208+ layout_search_config = layout_search_config ,
209+ block_overrides = block_overrides ,
210+ )
211+
212+ generator_config ["Perturbation_Experiment" ][f"Experiment_{ seqnum } " ] = perturb_block
213+
214+ branch = perturb_block ["branches" ][0 ]
178215 self .experiments [branch ] = ProfilingExperiment (self .work_dir / branch / self ._repository_directory )
179216
180217 seqnum += 1
0 commit comments