Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/access/profiling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
with suppress(PackageNotFoundError):
__version__ = version("access-profiling")

from access.profiling.access_models import ESM16Profiling, RAM3Profiling
from access.profiling.access_models import ACCESSOM3Profiling, ESM16Profiling, RAM3Profiling
from access.profiling.cice5_parser import CICE5ProfilingParser
from access.profiling.cylc_parser import CylcDBReader, CylcProfilingParser
from access.profiling.esmf_parser import ESMFSummaryProfilingParser
Expand All @@ -25,6 +25,7 @@
"CICE5ProfilingParser",
"PayuJSONProfilingParser",
"ESMFSummaryProfilingParser",
"ACCESSOM3Profiling",
"ESM16Profiling",
"CylcProfilingParser",
"CylcDBReader",
Expand Down
74 changes: 74 additions & 0 deletions src/access/profiling/access_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@
from pathlib import Path

from access.config import YAMLParser
from access.config.accessom3_layout_input import (
OM3ConfigLayout,
OM3LayoutSearchConfig,
generate_om3_core_layouts_from_node_count,
generate_om3_perturb_block,
)
from access.config.esm1p6_layout_input import (
LayoutSearchConfig,
LayoutTuple,
Expand All @@ -22,6 +28,74 @@
logger = logging.getLogger(__name__)


class ACCESSOM3Profiling(PayuManager):
"""Handles profiling of ACCESS-OM3 configurations."""

@property
def model_type(self) -> str:
return "access-om3"

def layout_key(self, layout: OM3ConfigLayout) -> tuple:
"""
Return stable identity for OM3 layout.
"""
return (
int(layout.ncpus),
tuple(sorted(layout.pool_ntasks.items())),
tuple(sorted(layout.pool_rootpe.items())),
)

def get_component_logs(self, path: Path) -> dict[str, ProfilingLog]:
"""Returns available profiling logs for the components in ACCESS-OM3.

Args:
path (Path): Path to the output directory.
Returns:
dict[str, ProfilingLog]: Dictionary mapping component names to their ProfilingLog instances.
"""
logs = {}
parser = YAMLParser()

config_path = path / "config.yaml"
payu_config = parser.parse(config_path.read_text())
mom6_logfile = path / f"{payu_config['model']}.out"
if mom6_logfile.is_file():
logger.debug(f"Found MOM log file: {mom6_logfile}")
logs["MOM"] = ProfilingLog(mom6_logfile, FMSProfilingParser(has_hits=True))

cice6_logfile = path / "log" / "ice.log"
if cice6_logfile.is_file():
logger.debug(f"Found CICE log file: {cice6_logfile}")
logs["CICE"] = ProfilingLog(cice6_logfile, CICE5ProfilingParser())

return logs

def generate_core_layouts_from_node_count(
self, num_nodes: float, cores_per_node: int, layout_search_config: OM3LayoutSearchConfig | None = None
) -> list:
return generate_om3_core_layouts_from_node_count(
num_nodes, cores_per_node, layout_search_config=layout_search_config
)

def generate_perturbation_block(
self,
layout: OM3ConfigLayout,
num_nodes: float,
branch_name_prefix: str,
walltime_hrs: float,
layout_search_config: OM3LayoutSearchConfig | None = None,
block_overrides: dict | None = None,
) -> dict:
return generate_om3_perturb_block(
layout=layout,
num_nodes=num_nodes,
layout_search_config=layout_search_config,
branch_name_prefix=branch_name_prefix,
walltime_hrs=walltime_hrs,
block_overrides=block_overrides,
)


class ESM16Profiling(PayuManager):
"""Handles profiling of ACCESS-ESM1.6 configurations."""

Expand Down
91 changes: 64 additions & 27 deletions src/access/profiling/payu_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Callable
from datetime import timedelta
from pathlib import Path

from access.config import YAMLParser
from access.config.accessom3_layout_input import OM3ConfigLayout, OM3LayoutSearchConfig
from access.config.esm1p6_layout_input import LayoutSearchConfig
from access.config.layout_config import LayoutTuple
from experiment_generator.experiment_generator import ExperimentGenerator
Expand All @@ -19,6 +19,10 @@

logger = logging.getLogger(__name__)

# Type aliases for better readability
SearchConfig = LayoutSearchConfig | OM3LayoutSearchConfig
Layout = LayoutTuple | OM3ConfigLayout


class PayuManager(ProfilingManager, ABC):
"""Abstract base class to handle profiling of Payu configurations."""
Expand Down Expand Up @@ -47,27 +51,43 @@ def generate_core_layouts_from_node_count(
self,
num_nodes: float,
cores_per_node: int,
layout_search_config: LayoutSearchConfig | None = None,
) -> list:
layout_search_config: SearchConfig | None = None,
) -> list[Layout]:
"""Generates core layouts from the given number of nodes.

Args:
num_nodes (float): Number of nodes.
cores_per_node (int): Number of cores per node.
layout_search_config (LayoutSearchConfig | None): Configuration for layout search.
layout_search_config (SearchConfig | None): Configuration for layout search.
"""

@abstractmethod
def generate_perturbation_block(self, layout: LayoutTuple, branch_name_prefix: str) -> dict:
def generate_perturbation_block(
self,
layout: Layout,
num_nodes: float,
branch_name_prefix: str,
walltime_hrs: float,
layout_search_config: SearchConfig | None = None,
block_overrides: dict | None = None,
) -> dict:
"""Generates a perturbation block for the given layout to be passed to the experiment generator.

Args:
layout (LayoutTuple): Core layout tuple.
layout (Layout): Core layout tuple.
num_nodes (float): Number of nodes.
branch_name_prefix (str): Branch name prefix.
walltime_hrs (float): Walltime in hours.
layout_search_config (SearchConfig | None): Configuration for layout search.
block_overrides (dict | None): Overrides for the perturbation block.
Returns:
dict: Perturbation block configuration.
"""

@abstractmethod
def layout_key(self, layout: Layout) -> tuple:
"""Returns a stable key for a layout so PayuManager can deduplicate layouts."""

@property
def nruns(self) -> int:
"""Returns the number of repetitions for the Payu experiments.
Expand Down Expand Up @@ -121,19 +141,21 @@ def generate_scaling_experiments(
num_nodes_list: list[float],
control_options: dict,
cores_per_node: int,
tol_around_ctrl_ratio: float,
max_wasted_ncores_frac: float | Callable[[float], float],
branch_name_prefix: str,
walltime: float | Callable[[float], float],
layout_search_config_builder: Callable[[float], SearchConfig] | None = None,
block_overrides_builder: Callable[[Layout, float], dict] | None = None,
) -> None:
"""Generates scaling experiments using the ExperimentGenerator.
"""Generates scaling experiments using ExperimentGenerator.

Args:
num_nodes_list (list[int]): List of number of nodes to generate experiments for.
control_options (dict): Options for the control experiment.
cores_per_node (int): Number of cores per node.
tol_around_ctrl_ratio (float): Tolerance around control core ratio for layout generation.
max_wasted_ncores_frac (float | Callable[[float], float]): Maximum fraction of wasted cores allowed.
branch_name_prefix (str): Branch name prefix for the generated experiments.
walltime (float | Callable[[float], float]): Walltime in hours for each experiment.
layout_search_config_builder: optional function (num_nodes)->OM3LayoutSearchConfig
block_overrides_builder: optional function (layout, num_nodes)->dict.
"""

generator_config = {
Expand All @@ -144,37 +166,52 @@ def generate_scaling_experiments(
"repository_directory": self._repository_directory,
"control_branch_name": "ctrl",
"Control_Experiment": control_options,
"Perturbation_Experiment": {},
}

seen_layouts = set()
seqnum = 1
generator_config["Perturbation_Experiment"] = {}

for num_nodes in num_nodes_list:
mwf = max_wasted_ncores_frac(num_nodes) if callable(max_wasted_ncores_frac) else max_wasted_ncores_frac
layout_config = LayoutSearchConfig(tol_around_ctrl_ratio=tol_around_ctrl_ratio, max_wasted_ncores_frac=mwf)
layout_search_config = layout_search_config_builder(num_nodes) if layout_search_config_builder else None
layouts = self.generate_core_layouts_from_node_count(
num_nodes,
cores_per_node=cores_per_node,
layout_search_config=layout_config,
layout_search_config=layout_search_config,
)

if not layouts:
logger.warning(f"No layouts found for {num_nodes} nodes")
continue

layouts = [x for x in layouts if x not in seen_layouts]
seen_layouts.update(layouts)
logger.info(f"Generated {len(layouts)} layouts for {num_nodes} nodes. Layouts: {layouts}")

# TODO: the branch name needs to be simpler and model agnostic
branch_name = f"layout-unused-cores-to-cice-{layout_config.allocate_unused_cores_to_ice}"
walltime_hrs = walltime(num_nodes) if callable(walltime) else walltime

unique_layouts: list[Layout] = []
for layout in layouts:
pert_config = self.generate_perturbation_block(layout=layout, branch_name_prefix=branch_name)
branch = pert_config["branches"][0]
pert_config["config.yaml"]["walltime"] = str(timedelta(hours=walltime_hrs))
key = self.layout_key(layout)
if key in seen_layouts:
continue
seen_layouts.add(key)
unique_layouts.append(layout)

logger.info(
f"Generated {len(layouts)} layouts for {num_nodes} nodes "
f"with {len(unique_layouts)} unique layouts: {unique_layouts}"
)

generator_config["Perturbation_Experiment"][f"Experiment_{seqnum}"] = pert_config
for layout in unique_layouts:
walltime_hrs = walltime(num_nodes) if callable(walltime) else walltime
block_overrides = block_overrides_builder(layout, num_nodes) if block_overrides_builder else None
perturb_block = self.generate_perturbation_block(
layout=layout,
num_nodes=num_nodes,
branch_name_prefix=branch_name_prefix,
walltime_hrs=walltime_hrs,
layout_search_config=layout_search_config,
block_overrides=block_overrides,
)

generator_config["Perturbation_Experiment"][f"Experiment_{seqnum}"] = perturb_block

branch = perturb_block["branches"][0]
self.experiments[branch] = ProfilingExperiment(self.work_dir / branch / self._repository_directory)

seqnum += 1
Expand Down
Loading