Skip to content

Commit 58b645d

Browse files
Generalise PayuManager for layout generation and update ACCESSOM3Profiling
1 parent bbd111e commit 58b645d

2 files changed

Lines changed: 109 additions & 27 deletions

File tree

src/access/profiling/access_models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
from pathlib import Path
66

77
from access.config import YAMLParser
8+
from access.config.accessom3_layout_input import (
9+
OM3ConfigLayout,
10+
OM3LayoutSearchConfig,
11+
generate_om3_core_layouts_from_node_count,
12+
generate_om3_perturb_block,
13+
)
814
from access.config.esm1p6_layout_input import (
915
LayoutSearchConfig,
1016
LayoutTuple,
@@ -25,6 +31,20 @@
2531
class ACCESSOM3Profiling(PayuManager):
2632
"""Handles profiling of ACCESS-OM3 configurations."""
2733

34+
@property
35+
def model_type(self) -> str:
36+
return "access-om3"
37+
38+
def layout_key(self, layout: OM3ConfigLayout) -> tuple:
39+
"""
40+
Return stable identity for OM3 layout.
41+
"""
42+
return (
43+
int(layout.ncpus),
44+
tuple(sorted(layout.pool_ntasks.items())),
45+
tuple(sorted(layout.pool_rootpe.items())),
46+
)
47+
2848
def get_component_logs(self, path: Path) -> dict[str, ProfilingLog]:
2949
"""Returns available profiling logs for the components in ACCESS-OM3.
3050
@@ -50,6 +70,31 @@ def get_component_logs(self, path: Path) -> dict[str, ProfilingLog]:
5070

5171
return logs
5272

73+
def generate_core_layouts_from_node_count(
74+
self, num_nodes: float, cores_per_node: int, layout_search_config: OM3LayoutSearchConfig | None = None
75+
) -> list:
76+
return generate_om3_core_layouts_from_node_count(
77+
num_nodes, cores_per_node, layout_search_config=layout_search_config
78+
)
79+
80+
def generate_perturbation_block(
81+
self,
82+
layout: OM3ConfigLayout,
83+
num_nodes: float,
84+
branch_name_prefix: str,
85+
walltime_hrs: float,
86+
layout_search_config: OM3LayoutSearchConfig | None = None,
87+
block_overrides: dict | None = None,
88+
) -> dict:
89+
return generate_om3_perturb_block(
90+
layout=layout,
91+
num_nodes=num_nodes,
92+
layout_search_config=layout_search_config,
93+
branch_name_prefix=branch_name_prefix,
94+
walltime_hrs=walltime_hrs,
95+
block_overrides=block_overrides,
96+
)
97+
5398

5499
class ESM16Profiling(PayuManager):
55100
"""Handles profiling of ACCESS-ESM1.6 configurations."""

src/access/profiling/payu_manager.py

Lines changed: 64 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
import logging
55
from abc import ABC, abstractmethod
66
from collections.abc import Callable
7-
from datetime import timedelta
87
from pathlib import Path
98

109
from access.config import YAMLParser
10+
from access.config.accessom3_layout_input import OM3ConfigLayout, OM3LayoutSearchConfig
1111
from access.config.esm1p6_layout_input import LayoutSearchConfig
1212
from access.config.layout_config import LayoutTuple
1313
from experiment_generator.experiment_generator import ExperimentGenerator
@@ -19,6 +19,10 @@
1919

2020
logger = logging.getLogger(__name__)
2121

22+
# Type aliases for better readability
23+
SearchConfig = LayoutSearchConfig | OM3LayoutSearchConfig
24+
Layout = LayoutTuple | OM3ConfigLayout
25+
2226

2327
class PayuManager(ProfilingManager, ABC):
2428
"""Abstract base class to handle profiling of Payu configurations."""
@@ -47,27 +51,43 @@ def generate_core_layouts_from_node_count(
4751
self,
4852
num_nodes: float,
4953
cores_per_node: int,
50-
layout_search_config: LayoutSearchConfig | None = None,
51-
) -> list:
54+
layout_search_config: SearchConfig | None = None,
55+
) -> list[Layout]:
5256
"""Generates core layouts from the given number of nodes.
5357
5458
Args:
5559
num_nodes (float): Number of nodes.
5660
cores_per_node (int): Number of cores per node.
57-
layout_search_config (LayoutSearchConfig | None): Configuration for layout search.
61+
layout_search_config (SearchConfig | None): Configuration for layout search.
5862
"""
5963

6064
@abstractmethod
61-
def generate_perturbation_block(self, layout: LayoutTuple, branch_name_prefix: str) -> dict:
65+
def generate_perturbation_block(
66+
self,
67+
layout: Layout,
68+
num_nodes: float,
69+
branch_name_prefix: str,
70+
walltime_hrs: float,
71+
layout_search_config: SearchConfig | None = None,
72+
block_overrides: dict | None = None,
73+
) -> dict:
6274
"""Generates a perturbation block for the given layout to be passed to the experiment generator.
6375
6476
Args:
65-
layout (LayoutTuple): Core layout tuple.
77+
layout (Layout): Core layout tuple.
78+
num_nodes (float): Number of nodes.
6679
branch_name_prefix (str): Branch name prefix.
80+
walltime_hrs (float): Walltime in hours.
81+
layout_search_config (SearchConfig | None): Configuration for layout search.
82+
block_overrides (dict | None): Overrides for the perturbation block.
6783
Returns:
6884
dict: Perturbation block configuration.
6985
"""
7086

87+
@abstractmethod
88+
def layout_key(self, layout: Layout) -> tuple:
89+
"""Returns a stable key for a layout so PayuManager can deduplicate layouts."""
90+
7191
@property
7292
def nruns(self) -> int:
7393
"""Returns the number of repetitions for the Payu experiments.
@@ -121,19 +141,21 @@ def generate_scaling_experiments(
121141
num_nodes_list: list[float],
122142
control_options: dict,
123143
cores_per_node: int,
124-
tol_around_ctrl_ratio: float,
125-
max_wasted_ncores_frac: float | Callable[[float], float],
144+
branch_name_prefix: str,
126145
walltime: float | Callable[[float], float],
146+
layout_search_config_builder: Callable[[float], SearchConfig] | None = None,
147+
block_overrides_builder: Callable[[Layout, float], dict] | None = None,
127148
) -> None:
128-
"""Generates scaling experiments using the ExperimentGenerator.
149+
"""Generates scaling experiments using ExperimentGenerator.
129150
130151
Args:
131152
num_nodes_list (list[int]): List of number of nodes to generate experiments for.
132153
control_options (dict): Options for the control experiment.
133154
cores_per_node (int): Number of cores per node.
134-
tol_around_ctrl_ratio (float): Tolerance around control core ratio for layout generation.
135-
max_wasted_ncores_frac (float | Callable[[float], float]): Maximum fraction of wasted cores allowed.
155+
branch_name_prefix (str): Branch name prefix for the generated experiments.
136156
walltime (float | Callable[[float], float]): Walltime in hours for each experiment.
157+
layout_search_config_builder: optional function (num_nodes)->OM3LayoutSearchConfig
158+
block_overrides_builder: optional function (layout, num_nodes)->dict.
137159
"""
138160

139161
generator_config = {
@@ -144,37 +166,52 @@ def generate_scaling_experiments(
144166
"repository_directory": self._repository_directory,
145167
"control_branch_name": "ctrl",
146168
"Control_Experiment": control_options,
169+
"Perturbation_Experiment": {},
147170
}
148171

149172
seen_layouts = set()
150173
seqnum = 1
151-
generator_config["Perturbation_Experiment"] = {}
174+
152175
for num_nodes in num_nodes_list:
153-
mwf = max_wasted_ncores_frac(num_nodes) if callable(max_wasted_ncores_frac) else max_wasted_ncores_frac
154-
layout_config = LayoutSearchConfig(tol_around_ctrl_ratio=tol_around_ctrl_ratio, max_wasted_ncores_frac=mwf)
176+
layout_search_config = layout_search_config_builder(num_nodes) if layout_search_config_builder else None
155177
layouts = self.generate_core_layouts_from_node_count(
156178
num_nodes,
157179
cores_per_node=cores_per_node,
158-
layout_search_config=layout_config,
180+
layout_search_config=layout_search_config,
159181
)
182+
160183
if not layouts:
161184
logger.warning(f"No layouts found for {num_nodes} nodes")
162185
continue
163186

164-
layouts = [x for x in layouts if x not in seen_layouts]
165-
seen_layouts.update(layouts)
166-
logger.info(f"Generated {len(layouts)} layouts for {num_nodes} nodes. Layouts: {layouts}")
167-
168-
# TODO: the branch name needs to be simpler and model agnostic
169-
branch_name = f"layout-unused-cores-to-cice-{layout_config.allocate_unused_cores_to_ice}"
170-
walltime_hrs = walltime(num_nodes) if callable(walltime) else walltime
171-
187+
unique_layouts: list[Layout] = []
172188
for layout in layouts:
173-
pert_config = self.generate_perturbation_block(layout=layout, branch_name_prefix=branch_name)
174-
branch = pert_config["branches"][0]
175-
pert_config["config.yaml"]["walltime"] = str(timedelta(hours=walltime_hrs))
189+
key = self.layout_key(layout)
190+
if key in seen_layouts:
191+
continue
192+
seen_layouts.add(key)
193+
unique_layouts.append(layout)
194+
195+
logger.info(
196+
f"Generated {len(layouts)} layouts for {num_nodes} nodes "
197+
f"with {len(unique_layouts)} unique layouts: {unique_layouts}"
198+
)
176199

177-
generator_config["Perturbation_Experiment"][f"Experiment_{seqnum}"] = pert_config
200+
for layout in unique_layouts:
201+
walltime_hrs = walltime(num_nodes) if callable(walltime) else walltime
202+
block_overrides = block_overrides_builder(layout, num_nodes) if block_overrides_builder else None
203+
perturb_block = self.generate_perturbation_block(
204+
layout=layout,
205+
num_nodes=num_nodes,
206+
branch_name_prefix=branch_name_prefix,
207+
walltime_hrs=walltime_hrs,
208+
layout_search_config=layout_search_config,
209+
block_overrides=block_overrides,
210+
)
211+
212+
generator_config["Perturbation_Experiment"][f"Experiment_{seqnum}"] = perturb_block
213+
214+
branch = perturb_block["branches"][0]
178215
self.experiments[branch] = ProfilingExperiment(self.work_dir / branch / self._repository_directory)
179216

180217
seqnum += 1

0 commit comments

Comments
 (0)