diff --git a/src/amp/loaders/base.py b/src/amp/loaders/base.py
index 3097feb..cc8a9a9 100644
--- a/src/amp/loaders/base.py
+++ b/src/amp/loaders/base.py
@@ -484,6 +484,7 @@ def load_stream_continuous(
                         table_name,
                         connection_name,
                         response.metadata.ranges,
+                        ranges_complete=response.metadata.ranges_complete,
                     )
                 else:
                     # Non-transactional loading (separate check, load, mark)
@@ -494,6 +495,7 @@ def load_stream_continuous(
                         table_name,
                         connection_name,
                         response.metadata.ranges,
+                        ranges_complete=response.metadata.ranges_complete,
                         **filtered_kwargs,
                     )

@@ -611,6 +613,7 @@ def _process_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> LoadResult:
         """
         Process a data batch using transactional exactly-once semantics.
@@ -622,6 +625,7 @@
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             LoadResult with operation outcome
@@ -630,13 +634,17 @@
         try:
             # Delegate to loader-specific transactional implementation
             # Loaders that support transactions implement load_batch_transactional()
-            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
+            rows_loaded_batch = self.load_batch_transactional(
+                batch_data, table_name, connection_name, ranges, ranges_complete
+            )
             duration = time.time() - start_time

-            # Mark batches as processed in state store after successful transaction
-            if ranges:
+            # Mark batches as processed ONLY when the microbatch is complete:
+            # multiple RecordBatches can share the same microbatch ID
+            if ranges and ranges_complete:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

             return LoadResult(
                 rows_loaded=rows_loaded_batch,
@@ -648,6 +656,7 @@
                 metadata={
                     'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate',
                     'ranges': [r.to_dict() for r in ranges],
+                    'ranges_complete': ranges_complete,
                 },
             )

@@ -670,6 +679,7 @@ def _process_batch_non_transactional(
         table_name: str,
         connection_name: str,
         ranges: Optional[List[BlockRange]],
+        ranges_complete: bool = False,
         **kwargs,
     ) -> Optional[LoadResult]:
         """
@@ -682,13 +692,17 @@
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch (if available)
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
             **kwargs: Additional options passed to load_batch

         Returns:
             LoadResult, or None if batch was skipped as duplicate
         """
         # Check if batch already processed (idempotency / exactly-once)
-        if ranges and self.state_enabled:
+        # For streaming: only check when ranges_complete=True (end of microbatch).
+        # Multiple RecordBatches can share the same microbatch ID, so we must wait
+        # until the entire microbatch is delivered before checking/marking as processed
+        if ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids)

                 if is_duplicate:
                     # Skip this batch - already processed
                     self.logger.info(
-                        f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
+                        f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}'
                     )
                     return LoadResult(
                         rows_loaded=0,
@@ -711,14 +725,16 @@
                 # BlockRange missing hash - log and continue without idempotency check
                 self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.')

-        # Load batch
+        # Load batch (always load, even if part of a larger microbatch)
         result = self.load_batch(batch_data, table_name, **kwargs)

-        if result.success and ranges and self.state_enabled:
-            # Mark batch as processed (for exactly-once semantics)
+        # Mark batch as processed ONLY when the microbatch is complete.
+        # This ensures we don't skip subsequent RecordBatches within the same microbatch
+        if result.success and ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
             except Exception as e:
                 self.logger.error(f'Failed to mark batches as processed: {e}')
                 # Continue anyway - state store provides resume capability
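As a quick reference for reviewers, a minimal self-contained sketch of the rule the hunks above implement; the class and the string microbatch ID are illustrative stand-ins, not the amp API:

from dataclasses import dataclass, field

@dataclass
class MicrobatchTracker:
    """Illustrative only: a microbatch is recorded when its last chunk arrives."""
    processed: set = field(default_factory=set)

    def handle(self, microbatch_id: str, ranges_complete: bool) -> str:
        # A duplicate check is only meaningful once the whole microbatch has been seen
        if ranges_complete and microbatch_id in self.processed:
            return 'skip_duplicate'
        # Every RecordBatch is loaded, even mid-microbatch
        if ranges_complete:
            # Mark processed only on the final RecordBatch of the microbatch
            self.processed.add(microbatch_id)
        return 'loaded'

tracker = MicrobatchTracker()
# Three chunks share one microbatch ID; only the last carries ranges_complete=True
assert [tracker.handle('eth:100-110', c) for c in (False, False, True)] == ['loaded'] * 3
# Redelivery of the completed microbatch is now skipped
assert tracker.handle('eth:100-110', True) == 'skip_duplicate'
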
diff --git a/src/amp/loaders/implementations/postgresql_loader.py b/src/amp/loaders/implementations/postgresql_loader.py
index 6e84703..7bae9f1 100644
--- a/src/amp/loaders/implementations/postgresql_loader.py
+++ b/src/amp/loaders/implementations/postgresql_loader.py
@@ -119,6 +119,7 @@ def load_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> int:
         """
         Load a batch with transactional exactly-once semantics using in-memory state.
@@ -135,6 +136,7 @@
             table_name: Target table name
             connection_name: Connection identifier for tracking
             ranges: Block ranges covered by this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             Number of rows loaded (0 if duplicate)
@@ -149,24 +151,27 @@
             self.logger.warning(f'Cannot create batch identifiers: {e}. Loading without duplicate check.')
             batch_ids = []

-        # Check if already processed (using in-memory state)
-        if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids):
+        # Check if already processed ONLY when the microbatch is complete.
+        # Multiple RecordBatches can share the same microbatch ID (BlockRange)
+        if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
             self.logger.info(
                 f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), '
                 f'skipping (state check)'
             )
             return 0

-        # Load data
+        # Load data (always load, even if part of a larger microbatch)
         conn = self.pool.getconn()
         try:
             with conn.cursor() as cur:
                 self._copy_arrow_data(cur, batch, table_name)
                 conn.commit()

-            # Mark as processed after successful load
-            if batch_ids:
+            # Mark as processed ONLY when the microbatch is complete.
+            # This ensures we don't skip subsequent RecordBatches within the same microbatch
+            if batch_ids and ranges_complete:
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

             self.logger.debug(
                 f'Batch load committed: {batch.num_rows} rows, '
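The test updates below construct exactly this shape of stream. As a reference, a microbatch split into several RecordBatches looks like the following sketch; it assumes the streaming types from src.amp.streaming.types and pyarrow as used in the tests:

import pyarrow as pa
from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch

block_range = BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
chunks = [
    pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]}),
    pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]}),
    pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]}),
]
# All chunks share one BlockRange; only the final chunk closes the microbatch
stream = [
    ResponseBatch.data_batch(
        data=chunk,
        metadata=BatchMetadata(
            ranges=[block_range],
            ranges_complete=(i == len(chunks) - 1),
        ),
    )
    for i, chunk in enumerate(chunks)
]
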
diff --git a/tests/integration/test_deltalake_loader.py b/tests/integration/test_deltalake_loader.py
index c925e37..c19c9a9 100644
--- a/tests/integration/test_deltalake_loader.py
+++ b/tests/integration/test_deltalake_loader.py
@@ -586,15 +586,24 @@ def test_handle_reorg_single_network(self, delta_temp_config):
         # Create response batches with hashes
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -637,19 +646,31 @@ def test_handle_reorg_multi_network(self, delta_temp_config):
         # Create response batches with network-specific ranges
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response4 = ResponseBatch.data_batch(
             data=batch4,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -689,15 +710,24 @@ def test_handle_reorg_overlapping_ranges(self, delta_temp_config):
         # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150)
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -733,15 +763,24 @@ def test_handle_reorg_version_history(self, delta_temp_config):

         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=0, end=10, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=50, end=60, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -792,12 +831,18 @@ def test_streaming_with_reorg(self, delta_temp_config):
         # Create response batches using factory methods (with hashes for proper state management)
         response1 = ResponseBatch.data_batch(
             data=data1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=data2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Simulate reorg event using factory method
diff --git a/tests/integration/test_lmdb_loader.py b/tests/integration/test_lmdb_loader.py
index e7bf14b..20e2f67 100644
--- a/tests/integration/test_lmdb_loader.py
+++ b/tests/integration/test_lmdb_loader.py
@@ -411,15 +411,24 @@ def test_handle_reorg_single_network(self, lmdb_config):
         # Create response batches with hashes
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=200, end=210, hash='0x123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -468,19 +477,31 @@ def test_handle_reorg_multi_network(self, lmdb_config):
         # Create response batches with network-specific ranges
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=100, end=110, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response4 = ResponseBatch.data_batch(
             data=batch4,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=150, end=160, hash='0xddd')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -524,15 +545,24 @@ def test_handle_reorg_overlapping_ranges(self, lmdb_config):
         # Batch 3: 170-190 (after reorg, but should be deleted as 170 >= 150)
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=90, end=110, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=140, end=160, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=170, end=190, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -577,12 +607,18 @@ def test_streaming_with_reorg(self, lmdb_config):
         # Create response batches using factory methods (with hashes for proper state management)
         response1 = ResponseBatch.data_batch(
             data=data1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=data2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=160, hash='0xdef456')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Simulate reorg event using factory method
diff --git a/tests/integration/test_postgresql_loader.py b/tests/integration/test_postgresql_loader.py
index 8b68186..649868c 100644
--- a/tests/integration/test_postgresql_loader.py
+++ b/tests/integration/test_postgresql_loader.py
@@ -541,19 +541,31 @@ def test_handle_reorg_deletion(self, postgresql_test_config, test_table_name, cl
         # Create response batches with hashes
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response4 = ResponseBatch.data_batch(
             data=batch4,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=107, end=108, hash='0xddd')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -605,7 +617,10 @@ def test_reorg_with_overlapping_ranges(self, postgresql_test_config, test_table_

         response = ResponseBatch.data_batch(
             data=batch,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -658,11 +673,17 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t

         response_eth = ResponseBatch.data_batch(
             data=batch_eth,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response_poly = ResponseBatch.data_batch(
             data=batch_poly,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load both batches via streaming API
@@ -692,3 +713,126 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t
         finally:
             loader.pool.putconn(conn)
+
+    def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """
+        Test that multiple RecordBatches within the same microbatch are all loaded,
+        and deduplication only happens at microbatch boundaries when ranges_complete=True.
+
+        This test verifies the fix for the critical bug where we were marking batches
+        as processed after every RecordBatch instead of waiting for ranges_complete=True.
+        """
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        # Enable state management to test deduplication
+        config_with_state = {
+            **postgresql_test_config,
+            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
+        }
+        loader = PostgreSQLLoader(config_with_state)
+
+        with loader:
+            # Create table first from the schema
+            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+            loader._create_table_from_schema(batch1_data.schema, test_table_name)
+
+            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
+            # This happens when the server sends large microbatches in smaller chunks
+
+            # First RecordBatch of the microbatch (ranges_complete=False)
+            response1 = ResponseBatch.data_batch(
+                data=batch1_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                    ranges_complete=False,  # Not the last batch in this microbatch
+                ),
+            )
+
+            # Second RecordBatch of the microbatch (ranges_complete=False)
+            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+            response2 = ResponseBatch.data_batch(
+                data=batch2_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=False,  # Still not the last batch
+                ),
+            )
+
+            # Third RecordBatch of the microbatch (ranges_complete=True)
+            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+            response3 = ResponseBatch.data_batch(
+                data=batch3_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
+                ),
+            )
+
+            # Process the microbatch stream
+            stream = [response1, response2, response3]
+            results = list(
+                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
+            )
+
+            # CRITICAL: All 3 RecordBatches should be loaded successfully
+            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
+            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
+            assert all(r.success for r in results), 'All batches should succeed'
+            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
+            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
+            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
+
+            # Verify total rows in table (all batches loaded)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    total_count = cur.fetchone()[0]
+                    assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
+
+                    # Verify the actual IDs are present
+                    cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+                    all_ids = [row[0] for row in cur.fetchall()]
+                    assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
+
+            finally:
+                loader.pool.putconn(conn)
+
+            # Now test that re-sending the complete microbatch is properly deduplicated
+            # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
+            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
+            duplicate_response = ResponseBatch.data_batch(
+                data=duplicate_batch,
+                metadata=BatchMetadata(
+                    ranges=[
+                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
+                    ],  # Same range as before!
+                    ranges_complete=True,  # Complete microbatch
+                ),
+            )
+
+            # Process duplicate microbatch
+            duplicate_results = list(
+                loader.load_stream_continuous(
+                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
+                )
+            )
+
+            # The duplicate microbatch should be skipped (already processed)
+            assert len(duplicate_results) == 1
+            assert duplicate_results[0].success is True
+            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
+            assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate'
+
+            # Verify row count unchanged (duplicate was skipped)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    final_count = cur.fetchone()[0]
+                    assert final_count == 6, 'Row count should not increase after duplicate microbatch'
+
+            finally:
+                loader.pool.putconn(conn)
diff --git a/tests/integration/test_redis_loader.py b/tests/integration/test_redis_loader.py
index bf7dc9a..ce23da7 100644
--- a/tests/integration/test_redis_loader.py
+++ b/tests/integration/test_redis_loader.py
@@ -724,15 +724,24 @@ def test_handle_reorg_deletion(self, redis_test_config, cleanup_redis):
         # Create response batches with hashes
         response1 = ResponseBatch.data_batch(
             data=batch1,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=102, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response2 = ResponseBatch.data_batch(
             data=batch2,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=103, end=104, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response3 = ResponseBatch.data_batch(
             data=batch3,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=105, end=106, hash='0xccc')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -789,7 +798,10 @@ def test_reorg_with_overlapping_ranges(self, redis_test_config, cleanup_redis):

         response = ResponseBatch.data_batch(
             data=batch,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=150, end=175, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load via streaming API
@@ -856,11 +868,17 @@ def test_reorg_preserves_different_networks(self, redis_test_config, cleanup_red

         response_eth = ResponseBatch.data_batch(
             data=batch_eth,
-            metadata=BatchMetadata(ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='ethereum', start=100, end=100, hash='0xaaa')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )
         response_poly = ResponseBatch.data_batch(
             data=batch_poly,
-            metadata=BatchMetadata(ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')]),
+            metadata=BatchMetadata(
+                ranges=[BlockRange(network='polygon', start=100, end=100, hash='0xbbb')],
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
         )

         # Load both batches via streaming API
@@ -921,7 +939,13 @@ def test_streaming_with_string_data_structure(self, redis_test_config, cleanup_r
         block_ranges = [BlockRange(network='polygon', start=200, end=202, hash='0xabc')]

         # Load via streaming API
-        response = ResponseBatch.data_batch(data=batch, metadata=BatchMetadata(ranges=block_ranges))
+        response = ResponseBatch.data_batch(
+            data=batch,
+            metadata=BatchMetadata(
+                ranges=block_ranges,
+                ranges_complete=True,  # Mark as complete so it gets tracked in state store
+            ),
+        )
         results = list(loader.load_stream_continuous(iter([response]), table_name))
         assert len(results) == 1
         assert results[0].success == True
diff --git a/tests/integration/test_snowflake_loader.py b/tests/integration/test_snowflake_loader.py
index 78c2c17..50f96c7 100644
--- a/tests/integration/test_snowflake_loader.py
+++ b/tests/integration/test_snowflake_loader.py
@@ -1105,3 +1105,111 @@ def test_streaming_error_handling(self, snowflake_streaming_config, test_table_n
         # and ignores columns that don't exist in the table
         assert result.success is True
         assert result.rows_loaded == 2
+
+    def test_microbatch_deduplication(self, snowflake_config, test_table_name, cleanup_tables):
+        """
+        Test that multiple RecordBatches within the same microbatch are all loaded,
+        and deduplication only happens at microbatch boundaries when ranges_complete=True.
+
+        This test verifies the fix for the critical bug where we were marking batches
+        as processed after every RecordBatch instead of waiting for ranges_complete=True.
+        """
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        # Enable state management to test deduplication
+        config_with_state = {
+            **snowflake_config,
+            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
+        }
+        loader = SnowflakeLoader(config_with_state)
+
+        with loader:
+            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
+            # This happens when the server sends large microbatches in smaller chunks
+
+            # First RecordBatch of the microbatch (ranges_complete=False)
+            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+            response1 = ResponseBatch.data_batch(
+                data=batch1_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                    ranges_complete=False,  # Not the last batch in this microbatch
+                ),
+            )
+
+            # Second RecordBatch of the microbatch (ranges_complete=False)
+            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+            response2 = ResponseBatch.data_batch(
+                data=batch2_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=False,  # Still not the last batch
+                ),
+            )
+
+            # Third RecordBatch of the microbatch (ranges_complete=True)
+            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+            response3 = ResponseBatch.data_batch(
+                data=batch3_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
+                ),
+            )
+
+            # Process the microbatch stream
+            stream = [response1, response2, response3]
+            results = list(
+                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
+            )
+
+            # CRITICAL: All 3 RecordBatches should be loaded successfully
+            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
+            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
+            assert all(r.success for r in results), 'All batches should succeed'
+            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
+            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
+            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
+
+            # Verify total rows in table (all batches loaded)
+            loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}')
+            total_count = loader.cursor.fetchone()['COUNT']
+            assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
+
+            # Verify the actual IDs are present
+            loader.cursor.execute(f'SELECT "id" FROM {test_table_name} ORDER BY "id"')
+            all_ids = [row['id'] for row in loader.cursor.fetchall()]
+            assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
+
+            # Now test that re-sending the complete microbatch is properly deduplicated
+            # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
+            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
+            duplicate_response = ResponseBatch.data_batch(
+                data=duplicate_batch,
+                metadata=BatchMetadata(
+                    ranges=[
+                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
+                    ],  # Same range as before!
+                    ranges_complete=True,  # Complete microbatch
+                ),
+            )
+
+            # Process duplicate microbatch
+            duplicate_results = list(
+                loader.load_stream_continuous(
+                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
+                )
+            )
+
+            # The duplicate microbatch should be skipped (already processed)
+            assert len(duplicate_results) == 1
+            assert duplicate_results[0].success is True
+            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
+            assert duplicate_results[0].metadata.get('operation') == 'skip_duplicate', 'Should be marked as duplicate'
+
+            # Verify row count unchanged (duplicate was skipped)
+            loader.cursor.execute(f'SELECT COUNT(*) as count FROM {test_table_name}')
+            final_count = loader.cursor.fetchone()['COUNT']
+            assert final_count == 6, 'Row count should not increase after duplicate microbatch'
diff --git a/tests/unit/test_streaming_helpers.py b/tests/unit/test_streaming_helpers.py
index a7c4c8c..55cf46f 100644
--- a/tests/unit/test_streaming_helpers.py
+++ b/tests/unit/test_streaming_helpers.py
@@ -171,7 +171,7 @@ def test_successful_transactional_load(self, mock_loader, sample_batch, sample_r

         # Verify method call (no batch_hash in current implementation)
         mock_loader.load_batch_transactional.assert_called_once_with(
-            sample_batch, 'test_table', 'test_conn', sample_ranges
+            sample_batch, 'test_table', 'test_conn', sample_ranges, False
         )

     def test_transactional_duplicate_detection(self, mock_loader, sample_batch, sample_ranges):
@@ -217,7 +217,7 @@ class TestProcessBatchNonTransactional:
     """Test _process_batch_non_transactional helper method"""

     def test_successful_non_transactional_load(self, mock_loader, sample_batch, sample_ranges):
-        """Test successful non-transactional batch load"""
+        """Test successful non-transactional batch processing"""
         # Setup - mock state store for new unified system
         mock_loader.state_store.is_processed = Mock(return_value=False)
         mock_loader.state_store.mark_processed = Mock()
@@ -228,12 +228,13 @@ def test_successful_non_transactional_load(self, mock_loader, sample_batch, samp
         )
         mock_loader.load_batch = Mock(return_value=success_result)

-        # Execute
+        # Execute with ranges_complete=True to trigger duplicate check
         result = mock_loader._process_batch_non_transactional(
             batch_data=sample_batch,
             table_name='test_table',
             connection_name='test_conn',
             ranges=sample_ranges,
+            ranges_complete=True,  # Must be True for duplicate check and mark_processed
             batch_hash='hash123',
         )

@@ -252,12 +253,13 @@ def test_duplicate_detection_returns_skip_result(self, mock_loader, sample_batch
         mock_loader.state_store.is_processed = Mock(return_value=True)
         mock_loader.load_batch = Mock()  # Should not be called

-        # Execute
+        # Execute with ranges_complete=True to trigger duplicate check
         result = mock_loader._process_batch_non_transactional(
             batch_data=sample_batch,
             table_name='test_table',
             connection_name='test_conn',
             ranges=sample_ranges,
+            ranges_complete=True,  # Must be True for duplicate check
             batch_hash='hash123',
         )

@@ -307,6 +309,7 @@ def test_mark_processed_failure_continues(self, mock_loader, sample_batch, sampl
             table_name='test_table',
             connection_name='test_conn',
             ranges=sample_ranges,
+            ranges_complete=True,  # Must be True for mark_processed to be called
             batch_hash='hash123',
         )
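
For completeness, a runnable end-to-end sketch of the loader-side contract the tests above exercise; InMemoryStateStore and SketchLoader are illustrative stand-ins rather than the project's classes, and the string range IDs stand in for BatchIdentifier values:

from typing import Any, List, Set, Tuple

class InMemoryStateStore:
    """Stand-in state store, just enough to show the gating."""
    def __init__(self) -> None:
        self._seen: Set[Tuple[str, str, str]] = set()
    def is_processed(self, conn: str, table: str, batch_ids: List[str]) -> bool:
        return all((conn, table, b) in self._seen for b in batch_ids)
    def mark_processed(self, conn: str, table: str, batch_ids: List[str]) -> None:
        self._seen.update((conn, table, b) for b in batch_ids)

class SketchLoader:
    """Hypothetical loader showing where ranges_complete changes behaviour."""
    def __init__(self) -> None:
        self.state_store = InMemoryStateStore()
        self.rows: List[Any] = []

    def load_batch_transactional(self, batch: List[Any], table_name: str,
                                 connection_name: str, ranges: List[str],
                                 ranges_complete: bool = False) -> int:
        batch_ids = list(ranges)
        # Only a completed microbatch can be recognised as a duplicate
        if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
            return 0
        self.rows.extend(batch)  # stand-in for the real COPY / INSERT path
        # Record the microbatch only after its final RecordBatch has been written
        if batch_ids and ranges_complete:
            self.state_store.mark_processed(connection_name, table_name, batch_ids)
        return len(batch)

loader = SketchLoader()
chunks = [([1, 2], False), ([3, 4], False), ([5, 6], True)]
assert [loader.load_batch_transactional(r, 't', 'c', ['eth:100-110'], done) for r, done in chunks] == [2, 2, 2]
assert loader.load_batch_transactional([7, 8], 't', 'c', ['eth:100-110'], True) == 0  # duplicate microbatch skipped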