Skip to content

Commit 08b4f53

Browse files
author
Aziz Berkay Yesilyurt
authored
Merge pull request #9310 from OpenMined/aziz/atomic
make stash.[set,update,upsert] atomic
2 parents 2174ec5 + 2eca5a5 commit 08b4f53

File tree

4 files changed

+62
-38
lines changed

4 files changed

+62
-38
lines changed

packages/syft/src/syft/service/api/api_service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def set(
7272
public_message="An API endpoint already exists at the given path."
7373
)
7474

75-
result = self.stash.upsert(context.credentials, endpoint=new_endpoint).unwrap()
75+
result = self.stash.upsert(context.credentials, obj=new_endpoint).unwrap()
7676
action_obj = ActionObject.from_obj(
7777
id=new_endpoint.action_object_id,
7878
syft_action_data=CustomEndpointActionObject(endpoint_id=result.id),
@@ -157,7 +157,7 @@ def update(
157157
endpoint.mock_function.view_access = view_access
158158

159159
# save changes
160-
self.stash.upsert(context.credentials, endpoint=endpoint).unwrap()
160+
self.stash.upsert(context.credentials, obj=endpoint).unwrap()
161161
return SyftSuccess(message="Endpoint successfully updated.")
162162

163163
@service_method(
@@ -218,7 +218,7 @@ def set_state(
218218
if mock and api_endpoint.mock_function:
219219
api_endpoint.mock_function.state = state
220220

221-
self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap()
221+
self.stash.upsert(context.credentials, obj=api_endpoint).unwrap()
222222
return SyftSuccess(message=f"APIEndpoint {api_path} state updated.")
223223

224224
@service_method(
@@ -248,7 +248,7 @@ def set_settings(
248248
if mock and api_endpoint.mock_function:
249249
api_endpoint.mock_function.settings = settings
250250

251-
self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap()
251+
self.stash.upsert(context.credentials, obj=api_endpoint).unwrap()
252252
return SyftSuccess(message=f"APIEndpoint {api_path} settings updated.")
253253

254254
@service_method(

packages/syft/src/syft/service/api/api_stash.py

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,22 +33,3 @@ def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool:
3333
return True
3434
except NotFoundException:
3535
return False
36-
37-
@as_result(StashException)
38-
def upsert(
39-
self,
40-
credentials: SyftVerifyKey,
41-
endpoint: TwinAPIEndpoint,
42-
has_permission: bool = False,
43-
) -> TwinAPIEndpoint:
44-
"""Upsert an endpoint."""
45-
exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap()
46-
47-
if exists:
48-
super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap()
49-
50-
return (
51-
super()
52-
.set(credentials=credentials, obj=endpoint, ignore_duplicates=False)
53-
.unwrap()
54-
)

packages/syft/src/syft/store/db/stash.py

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,8 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
7575
"""
7676
Decorator to inject a session into the function kwargs if it is not provided.
7777
78+
Make sure to pass session as a keyword argument to the function.
79+
7880
TODO: This decorator is a temporary fix, we want to move to a DI approach instead:
7981
move db connection and session to context, and pass context to all stash methods.
8082
"""
@@ -87,8 +89,9 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
8789
def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any:
8890
if inject_session and kwargs.get("session") is None:
8991
with self.sessionmaker() as session:
90-
kwargs["session"] = session
91-
return func(self, *args, **kwargs)
92+
with session.begin():
93+
kwargs["session"] = session
94+
return func(self, *args, **kwargs)
9295
return func(self, *args, **kwargs)
9396

9497
return wrapper # type: ignore
@@ -369,11 +372,13 @@ def set(
369372
uid = obj.id
370373

371374
# check if the object already exists
372-
if self.exists(credentials, uid) or not self.is_unique(obj):
375+
if self.exists(credentials, uid, session=session) or not self.is_unique(
376+
obj, session=session
377+
):
373378
if ignore_duplicates:
374379
return obj
375380
unique_fields_str = ", ".join(self.unique_fields)
376-
raise StashException(
381+
raise UniqueConstraintException(
377382
public_message=f"Duplication Key Error for {obj}.\n"
378383
f"The fields that should be unique are {unique_fields_str}."
379384
)
@@ -399,7 +404,6 @@ def set(
399404
raise StashException(
400405
f"Error serializing object: {e}. Some fields are invalid."
401406
)
402-
403407
# create the object with the permissions
404408
stmt = self.table.insert().values(
405409
id=uid,
@@ -408,7 +412,6 @@ def set(
408412
storage_permissions=storage_permissions,
409413
)
410414
session.execute(stmt)
411-
session.commit()
412415
return self.get_by_uid(credentials, uid, session=session).unwrap()
413416

414417
@as_result(ValidationError, AttributeError)
@@ -462,7 +465,7 @@ def update(
462465
).unwrap()
463466

464467
# TODO has_permission is not used
465-
if not self.is_unique(obj):
468+
if not self.is_unique(obj, session=session):
466469
raise UniqueConstraintException(
467470
f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}"
468471
)
@@ -483,14 +486,12 @@ def update(
483486
f"Error serializing object: {e}. Some fields are invalid."
484487
)
485488
stmt = stmt.values(fields=fields)
486-
487489
result = session.execute(stmt)
488-
session.commit()
489490
if result.rowcount == 0:
490491
raise NotFoundException(
491492
f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
492493
)
493-
return self.get_by_uid(credentials, obj.id).unwrap()
494+
return self.get_by_uid(credentials, obj.id, session=session).unwrap()
494495

495496
@as_result(StashException, NotFoundException)
496497
@with_session
@@ -510,7 +511,6 @@ def delete_by_uid(
510511
session=session,
511512
)
512513
result = session.execute(stmt)
513-
session.commit()
514514
if result.rowcount == 0:
515515
raise NotFoundException(
516516
f"{self.object_type.__name__}: {uid} not found or no permission to delete."
@@ -649,8 +649,6 @@ def add_permission(
649649
stmt = self.table.update().where(self.table.c.id == permission.uid)
650650
stmt = stmt.values(permissions=list(existing_permissions))
651651
session.execute(stmt)
652-
session.commit()
653-
654652
return None
655653

656654
@as_result(NotFoundException)
@@ -685,7 +683,6 @@ def remove_permission(
685683
.values(permissions=list(permissions))
686684
)
687685
session.execute(stmt)
688-
session.commit()
689686
return None
690687

691688
@with_session
@@ -842,7 +839,6 @@ def remove_storage_permission(
842839
.values(storage_permissions=[str(uid) for uid in permissions])
843840
)
844841
session.execute(stmt)
845-
session.commit()
846842
return None
847843

848844
@as_result(StashException)
@@ -857,3 +853,26 @@ def _get_storage_permissions_for_uid(
857853
if result is None:
858854
raise NotFoundException(f"No storage permissions found for uid: {uid}")
859855
return {UID(uid) for uid in result.storage_permissions}
856+
857+
@with_session
858+
@as_result(StashException)
859+
def upsert(
860+
self,
861+
credentials: SyftVerifyKey,
862+
obj: StashT,
863+
session: Session = None,
864+
) -> StashT:
865+
"""Insert or update an object in the stash if it already exists.
866+
Atomic operation when using the same session for both operations.
867+
"""
868+
869+
try:
870+
return self.set(
871+
credentials=credentials,
872+
obj=obj,
873+
session=session,
874+
).unwrap()
875+
except UniqueConstraintException:
876+
return self.update(
877+
credentials=credentials, obj=obj, session=session
878+
).unwrap()

packages/syft/tests/syft/stores/base_stash_test.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,30 @@ def test_basestash_update(
190190
assert retrieved == updated_obj
191191

192192

193+
def test_basestash_upsert(
194+
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
195+
) -> None:
196+
base_stash.set(root_verify_key, mock_object).unwrap()
197+
198+
updated_obj = mock_object.copy()
199+
updated_obj.name = faker.name()
200+
201+
retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap()
202+
assert retrieved == updated_obj
203+
204+
updated_obj.id = UID()
205+
206+
with pytest.raises(StashException):
207+
# fails because the name should be unique
208+
base_stash.upsert(root_verify_key, updated_obj).unwrap()
209+
210+
updated_obj.name = faker.name()
211+
212+
retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap()
213+
assert retrieved == updated_obj
214+
assert len(base_stash.get_all(root_verify_key).unwrap()) == 2
215+
216+
193217
def test_basestash_cannot_update_non_existent(
194218
root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker
195219
) -> None:

0 commit comments

Comments
 (0)