diff --git a/syft/workers/base.py b/syft/workers/base.py index b2205b5a34a..53edc211cdd 100644 --- a/syft/workers/base.py +++ b/syft/workers/base.py @@ -84,6 +84,9 @@ class BaseWorker(AbstractWorker, ObjectStorage): primarily a development/testing feature. auto_add: Determines whether to automatically add this worker to the list of known workers. + message_pending_time (optional): A number of seconds to delay the messages to be sent. + The argument may be a floating point number for subsecond + precision. """ def __init__( @@ -95,6 +98,7 @@ def __init__( log_msgs: bool = False, verbose: bool = False, auto_add: bool = True, + message_pending_time: Union[int, float] = 0, ): """Initializes a BaseWorker.""" super().__init__() @@ -105,6 +109,7 @@ def __init__( self.log_msgs = log_msgs self.verbose = verbose self.auto_add = auto_add + self._message_pending_time = message_pending_time self.msg_history = list() # For performance, we cache all possible message types @@ -975,6 +980,29 @@ def _get_msg(self, index): return sy.serde.deserialize(self.msg_history[index], worker=self) + @property + def message_pending_time(self): + """ + Returns: + The pending time in seconds for messaging between virtual workers. + """ + return self._message_pending_time + + @message_pending_time.setter + def message_pending_time(self, seconds: Union[int, float]) -> None: + """Sets the pending time to send messaging between workers. + + Args: + seconds: A number of seconds to delay the messages to be sent. + The argument may be a floating point number for subsecond + precision. + + """ + if self.verbose: + print(f"Set message pending time to {seconds} seconds.") + + self._message_pending_time = seconds + @staticmethod def create_message_execute_command( command_name: str, command_owner=None, return_ids=None, *args, **kwargs diff --git a/syft/workers/virtual.py b/syft/workers/virtual.py index 68b3e303f26..153883beb8e 100644 --- a/syft/workers/virtual.py +++ b/syft/workers/virtual.py @@ -1,9 +1,16 @@ +from time import sleep + from syft.workers.base import BaseWorker from syft.federated.federated_client import FederatedClient class VirtualWorker(BaseWorker, FederatedClient): def _send_msg(self, message: bin, location: BaseWorker) -> bin: + if self.message_pending_time > 0: + if self.verbose: + print(f"pending time of {self.message_pending_time} seconds to send message...") + sleep(self.message_pending_time) + return location._recv_msg(message) def _recv_msg(self, message: bin) -> bin: diff --git a/test/workers/test_virtual.py b/test/workers/test_virtual.py index 6155e13404a..c844717b0a6 100644 --- a/test/workers/test_virtual.py +++ b/test/workers/test_virtual.py @@ -25,6 +25,9 @@ def test_send_msg(): # get pointer to local worker me = sy.torch.hook.local_worker + # pending time to simulate lantency (optional) + me.message_pending_time = 0.1 + # create a new worker (to send the object to) worker_id = sy.ID_PROVIDER.pop() bob = VirtualWorker(sy.torch.hook, id=f"bob{worker_id}") @@ -34,10 +37,14 @@ def test_send_msg(): obj_id = obj.id # Send data to bob + start_time = time() me.send_msg(ObjectMessage(obj), bob) + elapsed_time = time() - start_time # ensure that object is now on bob's machine assert obj_id in bob._objects + # ensure that object was sent 0.1 secs later + assert abs(elapsed_time - me.message_pending_time) < 0.1 def test_send_msg_using_tensor_api():