diff --git a/src/experiment_runner/main.py b/src/experiment_runner/main.py index e18747e..c8589a5 100644 --- a/src/experiment_runner/main.py +++ b/src/experiment_runner/main.py @@ -1,38 +1,57 @@ import argparse +import os + from .utils import read_yaml from .experiment_runner import ExperimentRunner def main(): """ - Managing ACCESS experiments. + Managing ACCESS experiment runs. + + This script loads experiment configurations from a YAML file + and invokes the ExperimentRunner to produce the required setups. - Args: - INPUT_YAML_FILE (str, optional): - Path to the YAML file specifying parameter values for experiment runs. - Defaults to "Expts_runner.yaml". + Command-line Arguments: + -i, --input-yaml-file (str, optional): + Path to the YAML file specifying parameter values for the experiment runs. + Defaults to 'Experiment_runner.yaml' if it exists. """ + parser = argparse.ArgumentParser( - description=""" - Manage ACCESS experiments. - Latest version and help: TODO - """ + description=( + "Manage ACCESS experiments using configurable YAML input.\n" + "If no YAML file is specified, the tool will look for 'Experiment_runner.yaml' " + "in the current directory.\n" + "If that file is missing, you must specify one with -i / --input-yaml-file." + ), + formatter_class=argparse.RawTextHelpFormatter, ) + parser.add_argument( - "INPUT_YAML_FILE", + "-i", + "--input-yaml-file", type=str, - nargs="?", - default="Experiment_runner.yaml", - help="YAML file specifying parameter values for experiment runs." - "Default is Experiment_runner.yaml", + help=( + "Path to the YAML file specifying parameter values for experiment runs.\n" + "Defaults to 'Experiment_runner.yaml' if present in the current directory." 
+ ), ) args = parser.parse_args() - input_yaml = args.INPUT_YAML_FILE - indata = read_yaml(input_yaml) - generator = ExperimentRunner(indata) - generator.run() + if args.input_yaml_file: + input_yaml = args.input_yaml_file + elif os.path.exists("Experiment_runner.yaml"): + input_yaml = "Experiment_runner.yaml" + else: + parser.error( + "No YAML file specified and 'Experiment_runner.yaml' not found.\n" + "Please provide one using -i / --input-yaml-file." + ) + # Load the YAML file + indata = read_yaml(input_yaml) -if __name__ == "__main__": - main() + # Run the experiment runner + runner = ExperimentRunner(indata) + runner.run() diff --git a/src/experiment_runner/pbs_job_manager.py b/src/experiment_runner/pbs_job_manager.py index 6a878b0..8898576 100644 --- a/src/experiment_runner/pbs_job_manager.py +++ b/src/experiment_runner/pbs_job_manager.py @@ -36,30 +36,44 @@ def output_existing_pbs_jobs() -> dict: with open(current_job_status_path, "r", encoding="utf-8") as f: pbs_job_file = f.read() + def _flush_pair(): + nonlocal current_key, current_value, job_id + if current_key and job_id: + pbs_jobs[job_id][current_key] = current_value.strip() + current_key = None + current_value = "" + pbs_job_file = pbs_job_file.replace("\t", " ") for line in pbs_job_file.splitlines(): line = line.rstrip() if not line: + _flush_pair() continue + if line.startswith("Job Id:"): + _flush_pair() job_id = line.split(":", 1)[1].strip() pbs_jobs[job_id] = {} - current_key = None - current_value = "" - elif line.startswith(" ") and current_key: # 8 indents multi-line + continue + + if line.startswith(" ") and current_key: # 8 indents multi-line current_value += line.strip() - elif line.startswith(" ") and " = " in line: # 4 indents for new pair + continue + + if line.startswith(" ") and " = " in line: # 4 indents for new pair # Save the previous multi-line value - if current_key: - pbs_jobs[job_id][current_key] = current_value.strip() + _flush_pair() key, value = line.split(" = ", 1) # save 
key current_key = key.strip() current_value = value.strip() + continue + + # end of file, flush last pair + _flush_pair() # Clean up the temporary file: `current_job_status` current_job_status_path.unlink() - return pbs_jobs diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..74d8608 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,186 @@ +import pytest +from pathlib import Path +import experiment_runner.experiment_runner as exp_runner + + +class DummyBranch: + def __init__(self, name: str): + self.name = name + + +class DummyCommit: + def __init__(self, hexsha: str): + self.hexsha = hexsha + + +class DummyHead: + def __init__(self, name: str, commit: DummyCommit): + self.name = name + self.commit = commit + + +class DummyHeadContainer: + def __init__(self, commit): + self.commit = commit + + +class DummyGit: + """ + This mimics repo.git commands - checkout, pull, reset, diff etc. + """ + + def __init__(self, repo): + self._repo = repo + + def checkout(self, *args): + # checkout or checkout -b origin/ + if len(args) == 1: + target = args[0] + self._repo._checkout_existing_branch(target) + elif len(args) == 3 and args[0] == "-b" and args[2].startswith("origin/"): + target = args[1] + self._repo._create_and_checkout_branch(target) + + def pull(self, *args): + if self._repo._pull_raises: + raise self._repo._exc.GitCommandError("Simulated pull error") + if self._repo._new_commit_hash is not None: + self._repo.head.commit = DummyCommit(self._repo._new_commit_hash) + + def reset(self, *args): + if self._repo._new_commit_hash is not None: + self._repo.head.commit = DummyCommit(self._repo._new_commit_hash) + + def diff(self, *args): + return self._repo._diff_output + + +class DummyRemote: + def __init__(self, repo): + self._repo = repo + + def fetch(self, prune=False): + self._repo._fetch_called = True + + +class DummyRemotes: + def __init__(self, repo): + self.origin = DummyRemote(repo) + + +class DummyRepo: + def __init__(self, path: 
Path):
+        self.path = path
+        self.heads = {}  # name -> DummyHead
+        self.head = DummyHeadContainer(commit=DummyCommit("initial"))
+        self.remotes = DummyRemotes(self)
+        self.git = DummyGit(self)
+        self._fetch_called = False
+        self._pull_raises = False
+        self._new_commit_hash = None
+        self._diff_output = ""
+        self._fetch_called = False
+
+    def _checkout_existing_branch(self, branch_name: str):
+        if branch_name not in self.heads:
+            raise self._exc.GitCommandError(f"Branch {branch_name} is missing!")
+        # self.head.commit = self.heads[branch_name].commit
+
+    def _create_and_checkout_branch(self, branch_name: str):
+        self.heads[branch_name] = DummyHead(branch_name, self.head.commit)
+
+
+class PayuCalls:
+    def __init__(self):
+        self.clone_calls = []
+        self.list_calls = []
+
+
+def _dummy_clone(repository, directory: str, branch, **kwargs):
+    directory = Path(directory)
+    directory.mkdir(parents=True, exist_ok=True)
+    (directory / "config.yaml").write_text("queue: normal\n")
+
+    _PAYU_CALLS.clone_calls.append(
+        {
+            "repository": str(repository),
+            "directory": str(directory),
+            "branch": branch,
+            "kwargs": kwargs,
+        }
+    )
+
+
+def _dummy_list_branches(config_path: Path) -> None:
+    _PAYU_CALLS.list_calls.append(Path(config_path))
+
+
+_PAYU_CALLS = PayuCalls()
+
+
+@pytest.fixture
+def payu_calls():
+    global _PAYU_CALLS
+    _PAYU_CALLS = PayuCalls()
+    return _PAYU_CALLS
+
+
+class DummyPbsJobManager:
+    def __init__(self):
+        self.calls = []
+
+    def pbs_job_runs(self, expt: Path, nrun: int):
+        self.calls.append((expt, nrun))
+
+
+@pytest.fixture
+def pbs_job_recorder():
+    return DummyPbsJobManager()
+
+
+@pytest.fixture
+def patch_runner(monkeypatch, tmp_path, payu_calls, pbs_job_recorder):
+    """
+    Patch external calls in experiment_runner.experiment_runner.
+ - payu.branch.clone, payu.branch.list_branches + - PBSJobManager + - git.Repo + """ + monkeypatch.setattr(exp_runner, "clone", _dummy_clone, raising=True) + monkeypatch.setattr(exp_runner, "list_branches", _dummy_list_branches, raising=True) + monkeypatch.setattr( + exp_runner, "PBSJobManager", lambda: pbs_job_recorder, raising=True + ) + + class _Exc: + class GitCommandError(Exception): + pass + + monkeypatch.setattr(exp_runner.git, "exc", _Exc, raising=True) + + def _make_repo(path): + repo = DummyRepo(path) + repo._exc = exp_runner.git.exc # store the exc container here + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", _make_repo, raising=True) + + class Controls: + pass + + controls = Controls() + controls.make_repo = _make_repo + controls.payu = payu_calls + controls.pbs = pbs_job_recorder + return controls + + +@pytest.fixture +def indata(tmp_path: Path) -> dict: + return { + "test_path": tmp_path / "tests", + "repository_directory": "test_repo", + "running_branches": ["perturb_1", "perturb_2"], + "nruns": [1, 2], + "keep_uuid": True, + } diff --git a/tests/test_experiment_runner.py b/tests/test_experiment_runner.py new file mode 100644 index 0000000..91f998d --- /dev/null +++ b/tests/test_experiment_runner.py @@ -0,0 +1,178 @@ +from pathlib import Path +import pytest +import experiment_runner.experiment_runner as exp_runner + + +def test_list_branches_is_called(indata, monkeypatch, patch_runner): + exp_runner.ExperimentRunner(indata).run() + assert patch_runner.payu.list_calls + + +def test_error_when_no_running_branches(indata, monkeypatch, patch_runner): + input_data = indata + input_data["running_branches"] = [] + er = exp_runner.ExperimentRunner(input_data) + with pytest.raises(ValueError): + er.run() + + +def test_update_existing_repo_creates_branch_if_missing( + tmp_path, indata, monkeypatch, patch_runner +): + for branch in indata["running_branches"]: + dir_path = tmp_path / "tests" / branch / indata["repository_directory"] + 
dir_path.mkdir(parents=True, exist_ok=True) + (dir_path / "config.yaml").write_text("queue: normal\n") + + indata["test_path"] = tmp_path / "tests" + + base_fac = patch_runner.make_repo + + def make_repo(path): + repo = base_fac(path) + repo._new_commit_hash = "abc1234" + repo._diff_output = "config.yaml\nnuopc.runseq\n" + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", make_repo, raising=True) + + exp_runner.ExperimentRunner(indata).run() + + assert len(patch_runner.payu.clone_calls) == 0 + assert len(patch_runner.pbs.calls) == 2 + + +def test_update_existing_repo_already_up_to_date( + tmp_path, indata, monkeypatch, patch_runner, capsys +): + for branch in indata["running_branches"]: + dir_path = tmp_path / "tests" / branch / indata["repository_directory"] + dir_path.mkdir(parents=True, exist_ok=True) + (dir_path / "config.yaml").write_text("queue: normal\n") + + indata["test_path"] = tmp_path / "tests" + + base_fac = patch_runner.make_repo + + def make_repo(path): + repo = base_fac(path) + # Make branches present so code goes "checkout " rather than "-b" + for b in indata["running_branches"]: + repo.heads[b] = True + # dont move head -> current_commit == new_commit + repo._new_commit_hash = None + repo._diff_output = "" + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", make_repo, raising=True) + + exp_runner.ExperimentRunner(indata).run() + + out = capsys.readouterr().out + assert "already up to date" in out + + +def test_update_existing_repo_outer_except_returns_false_and_caller_prints( + tmp_path, indata, monkeypatch, patch_runner, capsys +): + for branch in indata["running_branches"]: + dir_path = tmp_path / "tests" / branch / indata["repository_directory"] + dir_path.mkdir(parents=True, exist_ok=True) + (dir_path / "config.yaml").write_text("queue: normal\n") + + indata["test_path"] = tmp_path / "tests" + + base_fac = patch_runner.make_repo + + def make_repo(path): + repo = base_fac(path) + # Make branches present so checkout() 
succeeds + for b in indata["running_branches"]: + repo.heads[b] = True + # Move head so the code proceeds to compute rel_path and call diff(...) + repo._new_commit_hash = "abcd123" + + # Now force .git.diff(...) to raise the SAME exception class prod code catches + def raise_gitcmderror(*args, **kwargs): + raise repo._exc.GitCommandError("boom from diff") + + repo.git.diff = raise_gitcmderror + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", make_repo, raising=True) + + exp_runner.ExperimentRunner(indata).run() + + out = capsys.readouterr().out + assert "Failed to update existing repo" in out or "leaving as it is" in out + + +def test_run_clones_and_runs_jobs(indata, monkeypatch, patch_runner): + exp_runner.ExperimentRunner(indata).run() + + assert len(patch_runner.payu.clone_calls) == len(indata["running_branches"]) + + expt1 = ( + Path(indata["test_path"]) + / indata["running_branches"][0] + / indata["repository_directory"] + ) + expt2 = ( + Path(indata["test_path"]) + / indata["running_branches"][1] + / indata["repository_directory"] + ) + assert patch_runner.pbs.calls == [(expt1, 1), (expt2, 2)] + + +def test_run_existing_dirs_update_success(tmp_path, indata, monkeypatch, patch_runner): + expt_dirs = [] + for branch in indata["running_branches"]: + dir_path = tmp_path / "tests" / branch / indata["repository_directory"] + dir_path.mkdir(parents=True, exist_ok=True) + (dir_path / "config.yaml").write_text("queue: normal\n") + expt_dirs.append(dir_path) + + base_fac = patch_runner.make_repo + + def make_repo(path): + repo = base_fac(path) + for b in indata["running_branches"]: + repo.heads[b] = True + repo._new_commit_hash = "abc1234" + repo._diff_output = "config.yaml\nnuopc.runseq\n" + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", make_repo, raising=True) + + exp_runner.ExperimentRunner(indata).run() + + assert len(patch_runner.payu.clone_calls) == 0 + assert len(patch_runner.pbs.calls) == 2 + + +def 
test_run_existing_dirs_pull_failure_uses_reset( + tmp_path, indata, monkeypatch, patch_runner +): + + for branch in indata["running_branches"]: + dir_path = tmp_path / "tests" / branch / indata["repository_directory"] + dir_path.mkdir(parents=True, exist_ok=True) + (dir_path / "config.yaml").write_text("queue: normal\n") + + base_fac = patch_runner.make_repo + + def make_repo(path): + repo = base_fac(path) + for b in indata["running_branches"]: + repo.heads[b] = True + repo._pull_raises = True + repo._new_commit_hash = "def5678" + repo._diff_output = "config.yaml\nnuopc.runseq\n" + return repo + + monkeypatch.setattr(exp_runner.git, "Repo", make_repo, raising=True) + + exp_runner.ExperimentRunner(indata).run() + + assert len(patch_runner.pbs.calls) == 2 diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..ad28397 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,93 @@ +import sys +from pathlib import Path +import experiment_runner.main as main_module +import pytest + + +def test_main_runs_with_i_flag(tmp_path, monkeypatch): + yaml = tmp_path / "example.yaml" + yaml.write_text( + f""" + test_path: {tmp_path / "custom_test_path"} + repository_directory: test_repo + running_branches: ["branch1", "branch2"] + nruns: [1, 2] + keep_uuid: True +""", + ) + + called = {} + + class DummyER: + def __init__(self, indata): + called["indata"] = indata + + def run(self): + called["run"] = True + + monkeypatch.setattr(main_module, "ExperimentRunner", DummyER, raising=True) + monkeypatch.setattr(sys, "argv", ["prog", "--input-yaml-file", yaml.as_posix()]) + + main_module.main() + + assert called.get("run") is True + assert Path(called["indata"]["test_path"]) == tmp_path / "custom_test_path" + assert called["indata"]["repository_directory"] == "test_repo" + assert called["indata"]["running_branches"] == ["branch1", "branch2"] + assert called["indata"]["nruns"] == [1, 2] + assert called["indata"]["keep_uuid"] is True + + +def 
test_main_uses_default_yaml_when_present(tmp_path, monkeypatch): + yaml = tmp_path / "Experiment_runner.yaml" + yaml.write_text( + f""" + test_path: {tmp_path / "custom_test_path"} + repository_directory: test_repo + running_branches: ["branch1", "branch2"] + nruns: [1, 2] + keep_uuid: True +""", + ) + + monkeypatch.chdir(tmp_path) + + called = {} + + class DummyER: + def __init__(self, indata): + called["indata"] = indata + + def run(self): + called["run"] = True + + monkeypatch.setattr(main_module, "ExperimentRunner", DummyER, raising=True) + monkeypatch.setattr(sys, "argv", ["prog"]) + + main_module.main() + + assert called.get("run") is True + assert Path(called["indata"]["test_path"]) == tmp_path / "custom_test_path" + assert called["indata"]["repository_directory"] == "test_repo" + assert called["indata"]["running_branches"] == ["branch1", "branch2"] + assert called["indata"]["nruns"] == [1, 2] + assert called["indata"]["keep_uuid"] is True + + +def test_main_errors_when_no_yaml_provided_and_default_missing( + tmp_path, monkeypatch, capsys +): + monkeypatch.chdir(tmp_path) + + monkeypatch.setattr(sys, "argv", ["prog"]) + + with pytest.raises(SystemExit) as exc_info: + main_module.main() + + assert exc_info.value.code != 0 + + captured = capsys.readouterr() + + err = captured.err + assert "Experiment_runner.yaml" in err + assert "-i / --input-yaml-file" in err diff --git a/tests/test_pbs_job_manager.py b/tests/test_pbs_job_manager.py new file mode 100644 index 0000000..b2e2094 --- /dev/null +++ b/tests/test_pbs_job_manager.py @@ -0,0 +1,211 @@ +from pathlib import Path +import pytest +from experiment_runner.pbs_job_manager import output_existing_pbs_jobs +from experiment_runner.pbs_job_manager import ( + _extract_current_and_parent_path, + GADI_PREFIX, +) +from experiment_runner.pbs_job_manager import PBSJobManager +import os + + +def test_output_existing_pbs_jobs_parses_jobs_correctly(tmp_path, monkeypatch): + monkeypatch.chdir(tmp_path) + + 
sample_qstat_output = ( + "Job Id: 123.gadi\n" + " Error_Path = gadi.nci.org.au:/g/data/group/parentA/expA/job.e123\n" + " multiline_value = This is a long value\n" + " that continues on the next line.\n" + " job_state = R\n" + "\n" + "Job Id: 999.gadi\n" + " Error_Path = gadi.nci.org.au:/g/data/group/parentB/expB/job.e999\n" + " job_state = Q\n" + ) + + def dummy_run(*args, **kwargs): + # simulate qstat -f > current_job_status + (tmp_path / "current_job_status").write_text(sample_qstat_output) + + monkeypatch.setattr("subprocess.run", dummy_run, raising=True) + + jobs = output_existing_pbs_jobs() + + assert "123.gadi" in jobs and "999.gadi" in jobs + assert ( + jobs["123.gadi"]["Error_Path"] + == "gadi.nci.org.au:/g/data/group/parentA/expA/job.e123" + ) + assert jobs["123.gadi"]["job_state"] == "R" + assert ( + jobs["999.gadi"]["Error_Path"] + == "gadi.nci.org.au:/g/data/group/parentB/expB/job.e999" + ) + assert jobs["999.gadi"]["job_state"] == "Q" + assert not (tmp_path / "current_job_status").exists() + + +def test_extract_current_and_parent_path(tmp_path): + current_path = tmp_path / "parentA" / "expA" / "job.e123" + current_path.parent.mkdir(parents=True, exist_ok=True) + + folder, parent = _extract_current_and_parent_path(GADI_PREFIX + str(current_path)) + + assert folder.name == "expA" + assert parent.name == "parentA" + + +def test_pbs_job_runs_not_duplicated(tmp_path, monkeypatch): + current_path = tmp_path / "parentA" / "expA" + current_path.mkdir(parents=True, exist_ok=True) + + # no relevant jobs running + monkeypatch.setattr( + "experiment_runner.pbs_job_manager.output_existing_pbs_jobs", + lambda: {"irrelevant": {"a": "b"}}, + raising=True, + ) + + # ensures non-duplicated path + seen_check = {} + + def dummy_check(self, path, jobs): + seen_check["args"] = (path, jobs) + return False + + # capture start args + called = {} + + def dummy_start(self, path, nruns, duplicated): + called["args"] = (path, nruns, duplicated) + + monkeypatch.setattr( + 
PBSJobManager, "_check_duplicated_jobs", dummy_check, raising=True + ) + monkeypatch.setattr( + PBSJobManager, "_start_experiment_runs", dummy_start, raising=True + ) + + pbs_job_manager = PBSJobManager() + pbs_job_manager.pbs_job_runs(current_path, nruns=3) + + assert seen_check["args"][0] == current_path + assert called["args"] == (current_path, 3, False) + + +def test_pbs_job_runs_with_duplicated(tmp_path, monkeypatch): + current_path = tmp_path / "parentA" / "expA" + current_path.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + "experiment_runner.pbs_job_manager.output_existing_pbs_jobs", + lambda: { + "123.gadi": { + "Error_Path": GADI_PREFIX + str(current_path / "job.e123"), + "job_state": "R", + } + }, + raising=True, + ) + + # ensures duplicated path + def dummy_check(self, path, jobs): + return True + + called_start = {} + + def dummy_start(self, path, nruns, duplicated): + called_start["args"] = (path, nruns, duplicated) + + monkeypatch.setattr( + PBSJobManager, "_check_duplicated_jobs", dummy_check, raising=True + ) + monkeypatch.setattr( + PBSJobManager, "_start_experiment_runs", dummy_start, raising=True + ) + + pbs_job_manager = PBSJobManager() + pbs_job_manager.pbs_job_runs(current_path, nruns=2) + + assert called_start["args"] == (current_path, 2, True) + + +def test_start_experiment_runs_return_early_if_duplicated( + tmp_path, monkeypatch, capsys +): + pbs_job_manager = PBSJobManager() + current_path = tmp_path / "parentA" / "expA" + current_path.mkdir(parents=True, exist_ok=True) + + pbs_job_manager._start_experiment_runs(current_path, nruns=2, duplicated=True) + + out = capsys.readouterr().out + assert "-- " not in out + + +def test_check_duplicated_jobs_detects(tmp_path, monkeypatch, capsys): + pbs_job_manager = PBSJobManager() + current_path = tmp_path / "parentA" / "expA" + current_path.mkdir(parents=True, exist_ok=True) + + jobs = { + "123.gadi": { + "Error_Path": GADI_PREFIX + str(current_path / "job.e123"), + "job_state": "R", 
+ } + } + + duplicated = pbs_job_manager._check_duplicated_jobs(current_path, jobs) + + assert duplicated is True + out = capsys.readouterr().out + assert "You have duplicated runs for" in out + + +def test_start_experiment_runs_counts_and_starts(tmp_path, monkeypatch, capsys): + pbs_job_manager = PBSJobManager() + current_path = tmp_path / "parentA" / "expA" + archive_path = current_path / "archive" + archive_path.mkdir(parents=True, exist_ok=True) + + # archive has 1 done, but in total 3 -> hence another 2 runs needed + (archive_path / "output000").mkdir() + runs = [] + + def dummy_run(cmd, *args, **kwargs): + runs.append(cmd) + + monkeypatch.setattr("subprocess.run", dummy_run, raising=True) + + pbs_job_manager._start_experiment_runs(current_path, nruns=3, duplicated=False) + assert runs and "payu run -n 2 -f" in runs[0] + + # considering all completed runs, no new runs needed + runs.clear() + (archive_path / "output001").mkdir() + (archive_path / "output002").mkdir() + pbs_job_manager._start_experiment_runs(current_path, nruns=3, duplicated=False) + assert not runs + + +def test_clean_workspace_removes_work_dir(tmp_path, monkeypatch): + pbs_job_manager = PBSJobManager() + current_path = tmp_path / "parentA" / "expA" + current_path.mkdir(parents=True, exist_ok=True) + + real_work_path = tmp_path / "work" + real_work_path.mkdir(parents=True, exist_ok=True) + + work_path_link = current_path / "work" + + os.symlink(real_work_path, work_path_link) + + calls = [] + + def dummy_run(cmd, *args, **kwargs): + calls.append(cmd) + + monkeypatch.setattr("subprocess.run", dummy_run, raising=True) + pbs_job_manager._clean_workspace(current_path) + assert calls and "payu sweep && payu setup" in calls[0]