Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 130 additions & 0 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from alphatrion.storage.sql_models import (
AgentType,
Status,
StatusMap,
)

from .types import (
Expand Down Expand Up @@ -43,6 +44,7 @@
Team,
TraceEvent,
TraceLink,
UpdateExperimentInput,
UpdateOrganizationInput,
UpdateUserInput,
User,
Expand Down Expand Up @@ -1679,6 +1681,73 @@ def create_experiment(
updated_at=exp.updated_at,
)

@staticmethod
def update_experiment(
info: Info[GraphQLContext, None], input: UpdateExperimentInput
) -> Experiment:
"""Update an existing experiment."""

user_id = uuid.UUID(info.context.user_id)
experiment_id = uuid.UUID(input.id)

metadb = runtime.storage_runtime().metadb

# Verify user has access to the experiment
if not metadb.experiment_is_accessible_to_user(
experiment_id=experiment_id, user_id=user_id
):
raise RuntimeError(
"Not allowed to update experiment that user does not have access to"
)

# Get the experiment to check if it exists
exp = metadb.get_experiment(experiment_id=experiment_id)
if not exp:
raise RuntimeError(f"Experiment with id '{input.id}' not found")

# Build update kwargs
update_kwargs = {}
if input.description is not None:
update_kwargs["description"] = input.description
if input.meta is not None:
update_kwargs["meta"] = input.meta
if input.params is not None:
update_kwargs["params"] = input.params
if input.labels is not None:
update_kwargs["labels"] = input.labels
if input.tags is not None:
update_kwargs["tags"] = input.tags

# Update experiment
if update_kwargs:
metadb.update_experiment(
experiment_id=experiment_id,
**update_kwargs,
)

# Get the updated experiment
updated_exp = metadb.get_experiment(experiment_id=experiment_id)
if not updated_exp:
raise RuntimeError("Failed to retrieve updated experiment")

return Experiment(
id=updated_exp.uuid,
org_id=updated_exp.org_id,
team_id=updated_exp.team_id,
user_id=updated_exp.user_id,
name=updated_exp.name,
description=updated_exp.description,
meta=updated_exp.meta,
params=updated_exp.params,
duration=updated_exp.duration,
status=GraphQLStatusEnum[Status(updated_exp.status).name],
kind=GraphQLExperimentTypeEnum[
GraphQLExperimentType(updated_exp.kind).name
],
created_at=updated_exp.created_at,
updated_at=updated_exp.updated_at,
)

@staticmethod
def delete_experiment(
info: Info[GraphQLContext, None], experiment_id: strawberry.ID
Expand Down Expand Up @@ -1740,3 +1809,64 @@ def delete_datasets(
for id in dataset_ids:
GraphQLMutations.delete_dataset(info=info, dataset_id=id)
return True

@staticmethod
def abort_experiment(
info: Info[GraphQLContext, None], experiment_id: strawberry.ID
) -> Experiment:
"""Abort an experiment by changing its status to ABORTED.
Only works if the experiment is in PENDING status."""

user_id = uuid.UUID(info.context.user_id)
experiment_id_uuid = uuid.UUID(experiment_id)

metadb = runtime.storage_runtime().metadb

# Verify user has access to the experiment
if not metadb.experiment_is_accessible_to_user(
experiment_id=experiment_id_uuid, user_id=user_id
):
raise RuntimeError(
"Not allowed to abort experiment that user does not have access to"
)

# Get the experiment to check if it exists
exp = metadb.get_experiment(experiment_id=experiment_id_uuid)
if not exp:
raise RuntimeError(f"Experiment with id '{experiment_id}' not found")

# Only abort if experiment is in PENDING status
if exp.status != Status.PENDING:
raise RuntimeError(
f"Cannot abort experiment with status '{StatusMap[Status(exp.status)]}'. "
"Only experiments in PENDING status can be aborted."
)

# Update status to ABORTED
metadb.update_experiment(
experiment_id=experiment_id_uuid,
status=Status.ABORTED,
)

# Get the updated experiment
updated_exp = metadb.get_experiment(experiment_id=experiment_id_uuid)
if not updated_exp:
raise RuntimeError("Failed to retrieve aborted experiment")

return Experiment(
id=updated_exp.uuid,
org_id=updated_exp.org_id,
team_id=updated_exp.team_id,
user_id=updated_exp.user_id,
name=updated_exp.name,
description=updated_exp.description,
meta=updated_exp.meta,
params=updated_exp.params,
duration=updated_exp.duration,
status=GraphQLStatusEnum[Status(updated_exp.status).name],
kind=GraphQLExperimentTypeEnum[
GraphQLExperimentType(updated_exp.kind).name
],
created_at=updated_exp.created_at,
updated_at=updated_exp.updated_at,
)
13 changes: 13 additions & 0 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
Session,
Span,
Team,
UpdateExperimentInput,
UpdateOrganizationInput,
UpdateUserInput,
User,
Expand Down Expand Up @@ -276,5 +277,17 @@ def create_experiment(
) -> Experiment:
return GraphQLMutations.create_experiment(info=info, input=input)

@strawberry.mutation
def update_experiment(
self, input: UpdateExperimentInput, info: Info[GraphQLContext, None]
) -> Experiment:
return GraphQLMutations.update_experiment(info=info, input=input)

@strawberry.mutation
def abort_experiment(
self, experiment_id: strawberry.ID, info: Info[GraphQLContext, None]
) -> Experiment:
return GraphQLMutations.abort_experiment(info=info, experiment_id=experiment_id)


schema = strawberry.Schema(query=Query, mutation=Mutation)
11 changes: 11 additions & 0 deletions alphatrion/server/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ class GraphQLStatus(Enum):
CANCELLED = "CANCELLED"
COMPLETED = "COMPLETED"
FAILED = "FAILED"
ABORTED = "ABORTED"


GraphQLStatusEnum = strawberry.enum(GraphQLStatus)
Expand Down Expand Up @@ -444,6 +445,16 @@ class CreateExperimentInput:
kind: GraphQLExperimentTypeEnum = GraphQLExperimentTypeEnum.CRAFT_EXPERIMENT


@strawberry.input
class UpdateExperimentInput:
id: strawberry.ID
description: str | None = None
labels: str | None = None
tags: list[str] | None = None
meta: JSON | None = None
params: JSON | None = None


# Artifact types
@strawberry.type
class ArtifactRepository:
Expand Down
4 changes: 3 additions & 1 deletion alphatrion/storage/sql_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class Status(enum.IntEnum):
COMPLETED = 9
CANCELLED = 10
FAILED = 11
ABORTED = 12


StatusMap = {
Expand All @@ -35,9 +36,10 @@ class Status(enum.IntEnum):
Status.CANCELLED: "CANCELLED",
Status.COMPLETED: "COMPLETED",
Status.FAILED: "FAILED",
Status.ABORTED: "ABORTED",
}

FINISHED_STATUS = [Status.COMPLETED, Status.FAILED, Status.CANCELLED]
FINISHED_STATUS = [Status.COMPLETED, Status.FAILED, Status.CANCELLED, Status.ABORTED]


class Organization(Base):
Expand Down
52 changes: 52 additions & 0 deletions alphatrion/storage/sqlstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,65 @@ def update_experiment(self, experiment_id: uuid.UUID, **kwargs) -> None:
.first()
)
if exp:
# Handle labels separately
labels = kwargs.pop("labels", None)
tags = kwargs.pop("tags", None)

# Update experiment fields
for key, value in kwargs.items():
if key == "meta" and isinstance(value, dict):
if exp.meta is None:
exp.meta = {}
exp.meta.update(value)
else:
setattr(exp, key, value)

# Update labels if provided
if labels is not None:
# Delete existing labels
session.query(ExperimentLabel).filter(
ExperimentLabel.experiment_id == experiment_id
).delete(synchronize_session=False)

# Add new labels
if labels:
label_pairs = labels.rstrip().split(",")
for pair in label_pairs:
if ":" in pair:
label_name, label_value = pair.split(":", 1)
elif "=" in pair:
label_name, label_value = pair.split("=", 1)
else:
continue # skip invalid label

exp_label = ExperimentLabel(
org_id=exp.org_id,
team_id=exp.team_id,
experiment_id=experiment_id,
label_name=label_name.strip(),
label_value=label_value.strip(),
)
session.add(exp_label)

# Update tags if provided
if tags is not None:
# Delete existing tags
session.query(ExperimentTag).filter(
ExperimentTag.experiment_id == experiment_id
).delete(synchronize_session=False)

# Add new tags
if tags:
for tag in [t.strip() for t in tags]:
if tag:
exp_tag = ExperimentTag(
org_id=exp.org_id,
team_id=exp.team_id,
experiment_id=experiment_id,
tag=tag,
)
session.add(exp_tag)

session.commit()
session.close()

Expand Down
Loading
Loading