Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 73 additions & 13 deletions src/experiment_runner/experiment_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from payu.branch import clone, list_branches
from .base_experiment import BaseExperiment
from .pbs_job_manager import PBSJobManager
import subprocess
import git


class ExperimentRunner(BaseExperiment):
Expand Down Expand Up @@ -42,19 +44,77 @@ def _create_cloned_directory(self) -> None:
for clone_dir, branch in zip(all_cloned_directories, self.running_branches):
if clone_dir.exists():
print(f"-- Test dir: {clone_dir} already exists, skipping cloning.")
if not self._update_existing_repo(clone_dir, branch):
print(
f"Failed to update existing repo {clone_dir}, leaving as it is."
)
else:
print(f"-- Cloning branch '{branch}' into {clone_dir}...")
clone(
repository=self.base_directory,
directory=clone_dir,
branch=branch,
keep_uuid=self.keep_uuid,
model_type=self.model_type,
config_path=self.config_path,
lab_path=self.lab_path,
new_branch_name=self.new_branch_name,
restart_path=self.restart_path,
parent_experiment=self.parent_experiment,
start_point=self.start_point,
)
self._do_clone(clone_dir, branch)

return all_cloned_directories

def _do_clone(self, clone_dir: Path, branch: str):
clone(
repository=self.base_directory,
directory=clone_dir,
branch=branch,
keep_uuid=self.keep_uuid,
model_type=self.model_type,
config_path=self.config_path,
lab_path=self.lab_path,
new_branch_name=self.new_branch_name,
restart_path=self.restart_path,
parent_experiment=self.parent_experiment,
start_point=self.start_point,
)

def _update_existing_repo(self, clone_dir: Path, target_ref: str) -> bool:
"""
Updates the repo without removing the dir or untracked files
target_ref: branch to checkout
"""

try:
repo = git.Repo(str(clone_dir))
remote = repo.remotes.origin
remote.fetch(prune=True)

# save current HEAD
current_commit = repo.head.commit.hexsha

# ensure branch exists
if target_ref in repo.heads:
repo.git.checkout(target_ref)
else:
repo.git.checkout("-b", target_ref, f"origin/{target_ref}")

# try pulling with rebase
try:
repo.git.pull("--rebase", "--autostash", "origin", target_ref)
except git.exc.GitCommandError as e:
repo.git.reset("--keep", f"origin/{target_ref}")

# save new HEAD after update
new_commit = repo.head.commit.hexsha

rel_path = clone_dir.relative_to(self.test_path)

if current_commit == new_commit:
print(f"-- Repo {rel_path} is already up to date with {target_ref}.")
else:
print(
f"-- Repo {rel_path} updated from {current_commit[:7]} to {new_commit[:7]} on branch {target_ref}."
)
changed = repo.git.diff(
"--name-only", current_commit, new_commit
).splitlines()
if changed:
print("-- Changed files:")
for file in changed:
print(f" -- {file}")

return True
except git.exc.GitCommandError as e:
print(f"Failed updating existing repo {rel_path}: {e}")
return False