@@ -75,6 +75,8 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
7575 """
7676 Decorator to inject a session into the function kwargs if it is not provided.
7777
78+ Make sure to pass session as a keyword argument to the function.
79+
7880 TODO: This decorator is a temporary fix, we want to move to a DI approach instead:
7981 move db connection and session to context, and pass context to all stash methods.
8082 """
@@ -87,8 +89,9 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore
8789 def wrapper (self : "ObjectStash[StashT]" , * args : Any , ** kwargs : Any ) -> Any :
8890 if inject_session and kwargs .get ("session" ) is None :
8991 with self .sessionmaker () as session :
90- kwargs ["session" ] = session
91- return func (self , * args , ** kwargs )
92+ with session .begin ():
93+ kwargs ["session" ] = session
94+ return func (self , * args , ** kwargs )
9295 return func (self , * args , ** kwargs )
9396
9497 return wrapper # type: ignore
@@ -369,11 +372,13 @@ def set(
369372 uid = obj .id
370373
371374 # check if the object already exists
372- if self .exists (credentials , uid ) or not self .is_unique (obj ):
375+ if self .exists (credentials , uid , session = session ) or not self .is_unique (
376+ obj , session = session
377+ ):
373378 if ignore_duplicates :
374379 return obj
375380 unique_fields_str = ", " .join (self .unique_fields )
376- raise StashException (
381+ raise UniqueConstraintException (
377382 public_message = f"Duplication Key Error for { obj } .\n "
378383 f"The fields that should be unique are { unique_fields_str } ."
379384 )
@@ -399,7 +404,6 @@ def set(
399404 raise StashException (
400405 f"Error serializing object: { e } . Some fields are invalid."
401406 )
402-
403407 # create the object with the permissions
404408 stmt = self .table .insert ().values (
405409 id = uid ,
@@ -408,7 +412,6 @@ def set(
408412 storage_permissions = storage_permissions ,
409413 )
410414 session .execute (stmt )
411- session .commit ()
412415 return self .get_by_uid (credentials , uid , session = session ).unwrap ()
413416
414417 @as_result (ValidationError , AttributeError )
@@ -462,7 +465,7 @@ def update(
462465 ).unwrap ()
463466
464467 # TODO has_permission is not used
465- if not self .is_unique (obj ):
468+ if not self .is_unique (obj , session = session ):
466469 raise UniqueConstraintException (
467470 f"Some fields are not unique for { type (obj ).__name__ } and unique fields { self .unique_fields } "
468471 )
@@ -483,14 +486,12 @@ def update(
483486 f"Error serializing object: { e } . Some fields are invalid."
484487 )
485488 stmt = stmt .values (fields = fields )
486-
487489 result = session .execute (stmt )
488- session .commit ()
489490 if result .rowcount == 0 :
490491 raise NotFoundException (
491492 f"{ self .object_type .__name__ } : { obj .id } not found or no permission to update."
492493 )
493- return self .get_by_uid (credentials , obj .id ).unwrap ()
494+ return self .get_by_uid (credentials , obj .id , session = session ).unwrap ()
494495
495496 @as_result (StashException , NotFoundException )
496497 @with_session
@@ -510,7 +511,6 @@ def delete_by_uid(
510511 session = session ,
511512 )
512513 result = session .execute (stmt )
513- session .commit ()
514514 if result .rowcount == 0 :
515515 raise NotFoundException (
516516 f"{ self .object_type .__name__ } : { uid } not found or no permission to delete."
@@ -649,8 +649,6 @@ def add_permission(
649649 stmt = self .table .update ().where (self .table .c .id == permission .uid )
650650 stmt = stmt .values (permissions = list (existing_permissions ))
651651 session .execute (stmt )
652- session .commit ()
653-
654652 return None
655653
656654 @as_result (NotFoundException )
@@ -685,7 +683,6 @@ def remove_permission(
685683 .values (permissions = list (permissions ))
686684 )
687685 session .execute (stmt )
688- session .commit ()
689686 return None
690687
691688 @with_session
@@ -842,7 +839,6 @@ def remove_storage_permission(
842839 .values (storage_permissions = [str (uid ) for uid in permissions ])
843840 )
844841 session .execute (stmt )
845- session .commit ()
846842 return None
847843
848844 @as_result (StashException )
@@ -857,3 +853,26 @@ def _get_storage_permissions_for_uid(
857853 if result is None :
858854 raise NotFoundException (f"No storage permissions found for uid: { uid } " )
859855 return {UID (uid ) for uid in result .storage_permissions }
856+
857+ @with_session
858+ @as_result (StashException )
859+ def upsert (
860+ self ,
861+ credentials : SyftVerifyKey ,
862+ obj : StashT ,
863+ session : Session = None ,
864+ ) -> StashT :
865+ """Insert or update an object in the stash if it already exists.
866+ Atomic operation when using the same session for both operations.
867+ """
868+
869+ try :
870+ return self .set (
871+ credentials = credentials ,
872+ obj = obj ,
873+ session = session ,
874+ ).unwrap ()
875+ except UniqueConstraintException :
876+ return self .update (
877+ credentials = credentials , obj = obj , session = session
878+ ).unwrap ()
0 commit comments