Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions alphatrion/run/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Run package."""

from alphatrion.run.hooks import PostRunHooks
from alphatrion.run.hooks import PostRunHook

__all__ = ["PostRunHooks"]
__all__ = ["PostRunHook"]
4 changes: 2 additions & 2 deletions alphatrion/run/hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from alphatrion.runtime.runtime import global_runtime


class PostRunHooks:
class PostRunHook:
"""Library of built-in post-run hooks."""

@staticmethod
Expand All @@ -26,7 +26,7 @@ async def train_model():
"num_epochs": 10,
}

run = exp.run(train_model, post_run_hooks=[PostRunHooks.sync_metadata])
run = exp.run(train_model, post_run_hooks=[PostRunHook.sync_metadata])
# After completion, run metadata will contain accuracy, loss, num_epochs

:param run_id: UUID of the run
Expand Down
14 changes: 7 additions & 7 deletions tests/integration/test_run_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import alphatrion as alpha
from alphatrion.experiment import CraftExperiment, ExperimentConfig
from alphatrion.run import PostRunHooks
from alphatrion.run import PostRunHook
from alphatrion.runtime.runtime import global_runtime


Expand Down Expand Up @@ -43,7 +43,7 @@ async def train_model():

async with CraftExperiment.start("test_hook_experiment") as exp:
# Create run with sync_metadata hook
run = exp.run(train_model, post_run_hooks=[PostRunHooks.sync_metadata])
run = exp.run(train_model, post_run_hooks=[PostRunHook.sync_metadata])
await exp.wait()

# Verify run completed
Expand Down Expand Up @@ -72,7 +72,7 @@ async def task_with_string_result():

async with CraftExperiment.start("test_hook_non_dict") as exp:
run = exp.run(
task_with_string_result, post_run_hooks=[PostRunHooks.sync_metadata]
task_with_string_result, post_run_hooks=[PostRunHook.sync_metadata]
)
await exp.wait()

Expand All @@ -98,7 +98,7 @@ async def task2():
return {"task": "task2", "accuracy": 0.94}

# Configure experiment with sync_metadata hook
config = ExperimentConfig(post_run_hooks=[PostRunHooks.sync_metadata])
config = ExperimentConfig(post_run_hooks=[PostRunHook.sync_metadata])

async with CraftExperiment.start("test_exp_hooks", config=config) as exp:
run1 = exp.run(task1)
Expand Down Expand Up @@ -143,7 +143,7 @@ async def train_model():
async with CraftExperiment.start("test_custom_hook") as exp:
# Use both built-in and custom hooks
run = exp.run(
train_model, post_run_hooks=[PostRunHooks.sync_metadata, add_custom_info]
train_model, post_run_hooks=[PostRunHook.sync_metadata, add_custom_info]
)
await exp.wait()

Expand Down Expand Up @@ -172,7 +172,7 @@ async def train_model():
return {"accuracy": 0.96, "loss": 0.04}

async with CraftExperiment.start("test_merge_metadata") as exp:
run = exp.run(train_model, post_run_hooks=[PostRunHooks.sync_metadata])
run = exp.run(train_model, post_run_hooks=[PostRunHook.sync_metadata])

# Manually add some metadata before run completes
metadb = global_runtime().metadb
Expand Down Expand Up @@ -209,7 +209,7 @@ async def train_model():

async with CraftExperiment.start("test_hook_failure") as exp:
run = exp.run(
train_model, post_run_hooks=[buggy_hook, PostRunHooks.sync_metadata]
train_model, post_run_hooks=[buggy_hook, PostRunHook.sync_metadata]
)
await exp.wait()

Expand Down
10 changes: 5 additions & 5 deletions tests/unit/run/test_hooks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import pytest

from alphatrion.run.hooks import PostRunHooks
from alphatrion.run.hooks import PostRunHook
from alphatrion.storage.sqlstore import SQLStore


Expand Down Expand Up @@ -51,7 +51,7 @@ def test_sync_metadata_with_dict_result(db):

with patch("alphatrion.run.hooks.global_runtime", return_value=mock_runtime):
# Call the hook
PostRunHooks.sync_metadata(run_id, result)
PostRunHook.sync_metadata(run_id, result)

# Verify metadata was updated
run = db.get_run(run_id)
Expand Down Expand Up @@ -92,7 +92,7 @@ def test_sync_metadata_with_non_dict_result(db):

with patch("alphatrion.run.hooks.global_runtime", return_value=mock_runtime):
# Call the hook
PostRunHooks.sync_metadata(run_id, result)
PostRunHook.sync_metadata(run_id, result)

# Verify metadata was not updated
run = db.get_run(run_id)
Expand Down Expand Up @@ -136,7 +136,7 @@ def test_sync_metadata_merges_with_existing_metadata(db):

with patch("alphatrion.run.hooks.global_runtime", return_value=mock_runtime):
# Call the hook
PostRunHooks.sync_metadata(run_id, result)
PostRunHook.sync_metadata(run_id, result)

# Verify metadata was merged, not replaced
run = db.get_run(run_id)
Expand Down Expand Up @@ -193,7 +193,7 @@ def test_hook_signature():
"""Test that sync_metadata has correct signature"""
import inspect

sig = inspect.signature(PostRunHooks.sync_metadata)
sig = inspect.signature(PostRunHook.sync_metadata)
params = list(sig.parameters.keys())

# Should have exactly 2 parameters: run_id and result
Expand Down
Loading