@@ -87,25 +87,32 @@ def __init__(
8787 self ._instance_name in hs .config .worker .writers .to_device
8888 )
8989
90- self ._device_inbox_id_gen : AbstractStreamIdGenerator = (
90+ self ._to_device_msg_id_gen : AbstractStreamIdGenerator = (
9191 MultiWriterIdGenerator (
9292 db_conn = db_conn ,
9393 db = database ,
9494 notifier = hs .get_replication_notifier (),
9595 stream_name = "to_device" ,
9696 instance_name = self ._instance_name ,
97- tables = [("device_inbox" , "instance_name" , "stream_id" )],
97+ tables = [
98+ ("device_inbox" , "instance_name" , "stream_id" ),
99+ ("device_federation_outbox" , "instance_name" , "stream_id" ),
100+ ],
98101 sequence_name = "device_inbox_sequence" ,
99102 writers = hs .config .worker .writers .to_device ,
100103 )
101104 )
102105 else :
103106 self ._can_write_to_device = True
104- self ._device_inbox_id_gen = StreamIdGenerator (
105- db_conn , hs .get_replication_notifier (), "device_inbox" , "stream_id"
107+ self ._to_device_msg_id_gen = StreamIdGenerator (
108+ db_conn ,
109+ hs .get_replication_notifier (),
110+ "device_inbox" ,
111+ "stream_id" ,
112+ extra_tables = [("device_federation_outbox" , "stream_id" )],
106113 )
107114
108- max_device_inbox_id = self ._device_inbox_id_gen .get_current_token ()
115+ max_device_inbox_id = self ._to_device_msg_id_gen .get_current_token ()
109116 device_inbox_prefill , min_device_inbox_id = self .db_pool .get_cache_dict (
110117 db_conn ,
111118 "device_inbox" ,
@@ -145,8 +152,8 @@ def process_replication_rows(
145152 ) -> None :
146153 if stream_name == ToDeviceStream .NAME :
147154 # If replication is happening than postgres must be being used.
148- assert isinstance (self ._device_inbox_id_gen , MultiWriterIdGenerator )
149- self ._device_inbox_id_gen .advance (instance_name , token )
155+ assert isinstance (self ._to_device_msg_id_gen , MultiWriterIdGenerator )
156+ self ._to_device_msg_id_gen .advance (instance_name , token )
150157 for row in rows :
151158 if row .entity .startswith ("@" ):
152159 self ._device_inbox_stream_cache .entity_has_changed (
@@ -162,11 +169,11 @@ def process_replication_position(
162169 self , stream_name : str , instance_name : str , token : int
163170 ) -> None :
164171 if stream_name == ToDeviceStream .NAME :
165- self ._device_inbox_id_gen .advance (instance_name , token )
172+ self ._to_device_msg_id_gen .advance (instance_name , token )
166173 super ().process_replication_position (stream_name , instance_name , token )
167174
168175 def get_to_device_stream_token (self ) -> int :
169- return self ._device_inbox_id_gen .get_current_token ()
176+ return self ._to_device_msg_id_gen .get_current_token ()
170177
171178 async def get_messages_for_user_devices (
172179 self ,
@@ -801,7 +808,7 @@ def add_messages_txn(
801808 msg .get (EventContentFields .TO_DEVICE_MSGID ),
802809 )
803810
804- async with self ._device_inbox_id_gen .get_next () as stream_id :
811+ async with self ._to_device_msg_id_gen .get_next () as stream_id :
805812 now_ms = self ._clock .time_msec ()
806813 await self .db_pool .runInteraction (
807814 "add_messages_to_device_inbox" , add_messages_txn , now_ms , stream_id
@@ -813,7 +820,7 @@ def add_messages_txn(
813820 destination , stream_id
814821 )
815822
816- return self ._device_inbox_id_gen .get_current_token ()
823+ return self ._to_device_msg_id_gen .get_current_token ()
817824
818825 async def add_messages_from_remote_to_device_inbox (
819826 self ,
@@ -857,7 +864,7 @@ def add_messages_txn(
857864 txn , stream_id , local_messages_by_user_then_device
858865 )
859866
860- async with self ._device_inbox_id_gen .get_next () as stream_id :
867+ async with self ._to_device_msg_id_gen .get_next () as stream_id :
861868 now_ms = self ._clock .time_msec ()
862869 await self .db_pool .runInteraction (
863870 "add_messages_from_remote_to_device_inbox" ,
0 commit comments