@@ -692,3 +692,128 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t
 
             finally:
                 loader.pool.putconn(conn)
+
+    def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """
+        Test that multiple RecordBatches within the same microbatch are all loaded,
+        and deduplication only happens at microbatch boundaries when ranges_complete=True.
+
+        This test verifies the fix for the critical bug where we were marking batches
+        as processed after every RecordBatch instead of waiting for ranges_complete=True.
+        """
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        # Enable state management to test deduplication
+        config_with_state = {
+            **postgresql_test_config,
+            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
+        }
+        loader = PostgreSQLLoader(config_with_state)
+
+        with loader:
+            # Create the table first from the schema
+            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+            loader._create_table_from_schema(batch1_data.schema, test_table_name)
+
+            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange.
+            # This happens when the server sends large microbatches in smaller chunks.
+
+            # First RecordBatch of the microbatch (ranges_complete=False)
+            response1 = ResponseBatch.data_batch(
+                data=batch1_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                    ranges_complete=False,  # Not the last batch in this microbatch
+                ),
+            )
+
+            # Second RecordBatch of the microbatch (ranges_complete=False)
+            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+            response2 = ResponseBatch.data_batch(
+                data=batch2_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=False,  # Still not the last batch
+                ),
+            )
+
+            # Third RecordBatch of the microbatch (ranges_complete=True)
+            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+            response3 = ResponseBatch.data_batch(
+                data=batch3_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
+                ),
+            )
+
+            # Process the microbatch stream
+            stream = [response1, response2, response3]
+            results = list(
+                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
+            )
+
+            # CRITICAL: all 3 RecordBatches should be loaded successfully.
+            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates").
+            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
+            assert all(r.success for r in results), 'All batches should succeed'
+            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
+            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
+            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
+
+            # Verify total rows in the table (all batches loaded)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    total_count = cur.fetchone()[0]
+                    assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
+
+                    # Verify the actual IDs are present
+                    cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+                    all_ids = [row[0] for row in cur.fetchall()]
+                    assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
+
+            finally:
+                loader.pool.putconn(conn)
+
+            # Now test that re-sending the complete microbatch is properly deduplicated.
+            # This time the batch has ranges_complete=True (entire microbatch in one RecordBatch).
+            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
+            duplicate_response = ResponseBatch.data_batch(
+                data=duplicate_batch,
+                metadata=BatchMetadata(
+                    ranges=[
+                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
+                    ],  # Same range as before!
+                    ranges_complete=True,  # Complete microbatch
+                ),
+            )
+
+            # Process the duplicate microbatch
+            duplicate_results = list(
+                loader.load_stream_continuous(
+                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
+                )
+            )
+
+            # The duplicate microbatch should be skipped (already processed)
+            assert len(duplicate_results) == 1
+            assert duplicate_results[0].success is True
+            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
+            assert (
+                duplicate_results[0].metadata.get('operation') == 'skip_duplicate'
+            ), 'Should be marked as duplicate'
+
+            # Verify the row count is unchanged (the duplicate was skipped)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    final_count = cur.fetchone()[0]
+                    assert final_count == 6, 'Row count should not increase after duplicate microbatch'
+
+            finally:
+                loader.pool.putconn(conn)
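
For context, the rule this test exercises can be summarised in a minimal sketch. This is an illustration under assumptions, not the actual PostgreSQLLoader code: the helper names (should_skip, mark_if_complete) and the in-memory set are hypothetical, and only the underlying rule (record a BlockRange as processed when ranges_complete=True, never after an individual RecordBatch) comes from the test above.

# Minimal sketch, assuming an in-memory set of completed
# (network, start, end, hash) tuples; not the real loader state store.
processed = set()

def should_skip(ranges):
    # A batch is a duplicate only if every one of its block ranges was already
    # recorded as part of a completed microbatch.
    return all((r.network, r.start, r.end, r.hash) in processed for r in ranges)

def mark_if_complete(ranges, ranges_complete):
    # Advance the dedup state only at the microbatch boundary, so RecordBatches
    # belonging to a still-open microbatch are never skipped as duplicates.
    if ranges_complete:
        processed.update((r.network, r.start, r.end, r.hash) for r in ranges)

Under this rule the three partial RecordBatches above all load (nothing is recorded until the third, complete one arrives), while the later re-send of the same range is skipped.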