Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions alphatrion/server/graphql/resolvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
ArtifactContent,
ArtifactRepository,
ArtifactTag,
CreateExperimentInput,
CreateTeamInput,
CreateUserInput,
DailyCostUsage,
Expand Down Expand Up @@ -1614,6 +1615,70 @@ def remove_user_from_team(
# Remove user from team (deletes TeamMember entry)
return metadb.remove_user_from_team(user_id=user_id, team_id=team_id)

@staticmethod
def create_experiment(
info: Info[GraphQLContext, None], input: CreateExperimentInput
) -> Experiment:
"""Create a new experiment."""

user_id = uuid.UUID(info.context.user_id)
org_id = uuid.UUID(info.context.org_id)
team_id = uuid.UUID(input.team_id)

metadb = runtime.storage_runtime().metadb

# Verify user has access to the team
if not metadb.team_is_accessible_to_user(
team_id=team_id, user_id=user_id, org_id=org_id
):
raise RuntimeError(
"Not allowed to create experiments in team that user does not belong to"
)

# Check if experiment with same name already exists in the team
existing_exp = metadb.get_exp_by_name(
name=input.name, team_id=team_id, include_deleted=True
)
if existing_exp:
raise RuntimeError(
f"Experiment with name '{input.name}' already exists in this team"
)

# Create experiment
experiment_id = metadb.create_experiment(
name=input.name,
org_id=org_id,
team_id=team_id,
user_id=user_id,
description=input.description,
labels=input.labels,
tags=input.tags,
meta=input.meta,
params=input.params,
status=Status.PENDING,
)

# Get the created experiment
exp = metadb.get_experiment(experiment_id=experiment_id)
if not exp:
raise RuntimeError("Failed to create experiment")

return Experiment(
id=exp.uuid,
org_id=exp.org_id,
team_id=exp.team_id,
user_id=exp.user_id,
name=exp.name,
description=exp.description,
meta=exp.meta,
params=exp.params,
duration=exp.duration,
status=GraphQLStatusEnum[Status(exp.status).name],
kind=GraphQLExperimentTypeEnum[GraphQLExperimentType(exp.kind).name],
created_at=exp.created_at,
updated_at=exp.updated_at,
)

@staticmethod
def delete_experiment(
info: Info[GraphQLContext, None], experiment_id: strawberry.ID
Expand Down
7 changes: 7 additions & 0 deletions alphatrion/server/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ArtifactFile,
ArtifactRepository,
ArtifactTag,
CreateExperimentInput,
CreateTeamInput,
CreateUserInput,
DailyCostUsage,
Expand Down Expand Up @@ -269,5 +270,11 @@ def delete_datasets(
) -> bool:
return GraphQLMutations.delete_datasets(info=info, dataset_ids=dataset_ids)

@strawberry.mutation
def create_experiment(
self, input: CreateExperimentInput, info: Info[GraphQLContext, None]
) -> Experiment:
return GraphQLMutations.create_experiment(info=info, input=input)


schema = strawberry.Schema(query=Query, mutation=Mutation)
12 changes: 12 additions & 0 deletions alphatrion/server/graphql/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,18 @@ class RemoveUserFromTeamInput:
team_id: strawberry.ID


@strawberry.input
class CreateExperimentInput:
name: str
team_id: strawberry.ID
description: str | None = None
labels: str | None = None
tags: list[str] | None = None
meta: JSON | None = None
params: JSON | None = None
kind: GraphQLExperimentTypeEnum = GraphQLExperimentTypeEnum.CRAFT_EXPERIMENT


# Artifact types
@strawberry.type
class ArtifactRepository:
Expand Down
111 changes: 111 additions & 0 deletions tests/integration/server/test_graphql_mutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1103,3 +1103,114 @@ def test_delete_experiments_all_running(
assert exp_2 is not None
assert exp_1.status == Status.RUNNING
assert exp_2.status == Status.RUNNING


def test_create_experiment_mutation(
execute_graphql, test_org_id, test_user_id, test_team_id
):
"""Test creating an experiment via GraphQL mutation"""
runtime.init()
metadb = runtime.storage_runtime().metadb

mutation = f"""
mutation {{
createExperiment(input: {{
name: "Test Experiment"
teamId: "{test_team_id}"
description: "An experiment created via mutation"
tags: ["ml", "training"]
meta: {{model: "gpt-4", version: "1.0"}}
params: {{learningRate: 0.001, batchSize: 32}}
}}) {{
id
name
description
status
kind
meta
params
tags
createdAt
updatedAt
}}
}}
"""
response = execute_graphql(
query=mutation,
org_id=test_org_id,
user_id=test_user_id,
)
assert response.errors is None
assert response.data["createExperiment"]["name"] == "Test Experiment"
assert (
response.data["createExperiment"]["description"]
== "An experiment created via mutation"
)
assert response.data["createExperiment"]["status"] == "PENDING"
assert response.data["createExperiment"]["kind"] == "CRAFT_EXPERIMENT"
assert response.data["createExperiment"]["meta"] == {
"model": "gpt-4",
"version": "1.0",
}
assert response.data["createExperiment"]["params"] == {
"learningRate": 0.001,
"batchSize": 32,
}
assert response.data["createExperiment"]["tags"] == ["ml", "training"]

# Verify experiment was actually created in database
new_exp_id = uuid.UUID(response.data["createExperiment"]["id"])
exp = metadb.get_experiment(experiment_id=new_exp_id)
assert exp is not None
assert exp.name == "Test Experiment"
assert exp.status == Status.PENDING


def test_create_experiment_duplicate_name(
execute_graphql, test_org_id, test_user_id, test_team_id
):
"""Test that creating an experiment with duplicate name fails"""
runtime.init()

# Create first experiment
mutation1 = f"""
mutation {{
createExperiment(input: {{
name: "Duplicate Name Test"
teamId: "{test_team_id}"
description: "First experiment"
}}) {{
id
name
}}
}}
"""
response1 = execute_graphql(
query=mutation1,
org_id=test_org_id,
user_id=test_user_id,
)
assert response1.errors is None
assert response1.data["createExperiment"]["name"] == "Duplicate Name Test"

# Try to create second experiment with same name
mutation2 = f"""
mutation {{
createExperiment(input: {{
name: "Duplicate Name Test"
teamId: "{test_team_id}"
description: "Second experiment with same name"
}}) {{
id
name
}}
}}
"""
response2 = execute_graphql(
query=mutation2,
org_id=test_org_id,
user_id=test_user_id,
)
# Should return an error
assert response2.errors is not None
assert "already exists" in str(response2.errors[0])
Loading