Commit bbbb9a7

base loader: fix micro batch is_processed marking, add tests
1 parent 710c4e3 commit bbbb9a7

File tree: 4 files changed (+267 −13 lines)


src/amp/loaders/base.py

Lines changed: 22 additions & 8 deletions

@@ -484,6 +484,7 @@ def load_stream_continuous(
                         table_name,
                         connection_name,
                         response.metadata.ranges,
+                        ranges_complete=response.metadata.ranges_complete,
                     )
                 else:
                     # Non-transactional loading (separate check, load, mark)
@@ -494,6 +495,7 @@ def load_stream_continuous(
                         table_name,
                         connection_name,
                         response.metadata.ranges,
+                        ranges_complete=response.metadata.ranges_complete,
                         **filtered_kwargs,
                     )

@@ -611,6 +613,7 @@ def _process_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> LoadResult:
         """
         Process a data batch using transactional exactly-once semantics.
@@ -622,6 +625,7 @@ def _process_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             LoadResult with operation outcome
@@ -630,13 +634,15 @@ def _process_batch_transactional(
         try:
             # Delegate to loader-specific transactional implementation
             # Loaders that support transactions implement load_batch_transactional()
-            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges)
+            rows_loaded_batch = self.load_batch_transactional(batch_data, table_name, connection_name, ranges, ranges_complete)
             duration = time.time() - start_time

-            # Mark batches as processed in state store after successful transaction
-            if ranges:
+            # Mark batches as processed ONLY when microbatch is complete
+            # multiple RecordBatches can share the same microbatch ID
+            if ranges and ranges_complete:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

             return LoadResult(
                 rows_loaded=rows_loaded_batch,
@@ -648,6 +654,7 @@ def _process_batch_transactional(
                 metadata={
                     'operation': 'transactional_load' if rows_loaded_batch > 0 else 'skip_duplicate',
                     'ranges': [r.to_dict() for r in ranges],
+                    'ranges_complete': ranges_complete,
                 },
             )

@@ -670,6 +677,7 @@ def _process_batch_non_transactional(
         table_name: str,
         connection_name: str,
         ranges: Optional[List[BlockRange]],
+        ranges_complete: bool = False,
         **kwargs,
     ) -> Optional[LoadResult]:
         """
@@ -682,21 +690,25 @@ def _process_batch_non_transactional(
             table_name: Target table name
             connection_name: Connection identifier
             ranges: Block ranges for this batch (if available)
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)
             **kwargs: Additional options passed to load_batch

         Returns:
             LoadResult, or None if batch was skipped as duplicate
         """
         # Check if batch already processed (idempotency / exactly-once)
-        if ranges and self.state_enabled:
+        # For streaming: only check when ranges_complete=True (end of microbatch)
+        # Multiple RecordBatches can share the same microbatch ID, so we must wait
+        # until the entire microbatch is delivered before checking/marking as processed
+        if ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 is_duplicate = self.state_store.is_processed(connection_name, table_name, batch_ids)

                 if is_duplicate:
                     # Skip this batch - already processed
                     self.logger.info(
-                        f'Skipping duplicate batch: {len(ranges)} ranges already processed for {table_name}'
+                        f'Skipping duplicate microbatch: {len(ranges)} ranges already processed for {table_name}'
                     )
                     return LoadResult(
                         rows_loaded=0,
@@ -711,14 +723,16 @@ def _process_batch_non_transactional(
                 # BlockRange missing hash - log and continue without idempotency check
                 self.logger.warning(f'Cannot check for duplicates: {e}. Processing batch anyway.')

-        # Load batch
+        # Load batch (always load, even if part of larger microbatch)
         result = self.load_batch(batch_data, table_name, **kwargs)

-        if result.success and ranges and self.state_enabled:
-            # Mark batch as processed (for exactly-once semantics)
+        # Mark batch as processed ONLY when microbatch is complete
+        # This ensures we don't skip subsequent RecordBatches within the same microbatch
+        if result.success and ranges and self.state_enabled and ranges_complete:
             try:
                 batch_ids = [BatchIdentifier.from_block_range(br) for br in ranges]
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')
             except Exception as e:
                 self.logger.error(f'Failed to mark batches as processed: {e}')
                 # Continue anyway - state store provides resume capability
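The behavioral change above is easiest to see in isolation. Below is a minimal, self-contained sketch (not code from this repository) of the rule the commit enforces: a microbatch identified by a single BlockRange may arrive as several RecordBatches, every chunk must still be written, and the state store may only be consulted and updated once the chunk carrying ranges_complete=True has been loaded. Names such as Chunk, InMemoryStateStore, and load_chunk are illustrative stand-ins for amp's BatchIdentifier and state-store machinery.

from dataclasses import dataclass
from typing import Set, Tuple

BatchId = Tuple[str, int, int, str]  # (network, start, end, hash) identifying one microbatch


@dataclass(frozen=True)
class Chunk:
    batch_id: BatchId
    rows: int
    ranges_complete: bool  # True only on the last RecordBatch of the microbatch


class InMemoryStateStore:
    def __init__(self) -> None:
        self._processed: Set[BatchId] = set()

    def is_processed(self, batch_id: BatchId) -> bool:
        return batch_id in self._processed

    def mark_processed(self, batch_id: BatchId) -> None:
        self._processed.add(batch_id)


def load_chunk(chunk: Chunk, store: InMemoryStateStore) -> int:
    """Return rows written; 0 when the whole microbatch is a known duplicate."""
    # Deduplicate only at microbatch boundaries: checking mid-microbatch would
    # wrongly skip the later chunks that share the same batch_id.
    if chunk.ranges_complete and store.is_processed(chunk.batch_id):
        return 0
    rows_written = chunk.rows  # stand-in for the real COPY/INSERT of this chunk
    if chunk.ranges_complete:
        store.mark_processed(chunk.batch_id)  # record only once the microbatch is whole
    return rows_written


if __name__ == '__main__':
    store = InMemoryStateStore()
    rng: BatchId = ('ethereum', 100, 110, '0xabc123')
    stream = [Chunk(rng, 2, False), Chunk(rng, 2, False), Chunk(rng, 2, True)]
    assert [load_chunk(c, store) for c in stream] == [2, 2, 2]  # every chunk is written
    assert load_chunk(Chunk(rng, 2, True), store) == 0          # a replayed microbatch is skipped

With the old behavior (marking after every chunk), the second and third chunks in the stream above would have been treated as duplicates of the first; that is the failure the new integration test reproduces.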

src/amp/loaders/implementations/postgresql_loader.py

Lines changed: 10 additions & 5 deletions

@@ -119,6 +119,7 @@ def load_batch_transactional(
         table_name: str,
         connection_name: str,
         ranges: List[BlockRange],
+        ranges_complete: bool = False,
     ) -> int:
         """
         Load a batch with transactional exactly-once semantics using in-memory state.
@@ -135,6 +136,7 @@ def load_batch_transactional(
             table_name: Target table name
             connection_name: Connection identifier for tracking
             ranges: Block ranges covered by this batch
+            ranges_complete: True when this RecordBatch completes a microbatch (streaming only)

         Returns:
             Number of rows loaded (0 if duplicate)
@@ -149,24 +151,27 @@ def load_batch_transactional(
             self.logger.warning(f'Cannot create batch identifiers: {e}. Loading without duplicate check.')
             batch_ids = []

-        # Check if already processed (using in-memory state)
-        if batch_ids and self.state_store.is_processed(connection_name, table_name, batch_ids):
+        # Check if already processed ONLY when microbatch is complete
+        # Multiple RecordBatches can share the same microbatch ID (BlockRange)
+        if batch_ids and ranges_complete and self.state_store.is_processed(connection_name, table_name, batch_ids):
             self.logger.info(
                 f'Batch already processed (ranges: {[f"{r.network}:{r.start}-{r.end}" for r in ranges]}), '
                 f'skipping (state check)'
             )
             return 0

-        # Load data
+        # Load data (always load, even if part of larger microbatch)
         conn = self.pool.getconn()
         try:
             with conn.cursor() as cur:
                 self._copy_arrow_data(cur, batch, table_name)
                 conn.commit()

-            # Mark as processed after successful load
-            if batch_ids:
+            # Mark as processed ONLY when microbatch is complete
+            # This ensures we don't skip subsequent RecordBatches within the same microbatch
+            if batch_ids and ranges_complete:
                 self.state_store.mark_processed(connection_name, table_name, batch_ids)
+                self.logger.debug(f'Marked microbatch as processed: {len(batch_ids)} batch IDs')

             self.logger.debug(
                 f'Batch load committed: {batch.num_rows} rows, '
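Because the base class now invokes the hook with a fifth positional argument (self.load_batch_transactional(batch_data, table_name, connection_name, ranges, ranges_complete)), any loader that implements transactional loading has to accept it. A rough sketch of the expected signature is below; ExampleLoader and its body are hypothetical, only the parameter list mirrors the diff, and deduplication must be gated on ranges_complete exactly as in the PostgreSQL loader above.

from typing import List

import pyarrow as pa


class ExampleLoader:
    """Hypothetical loader illustrating the updated transactional hook signature."""

    def load_batch_transactional(
        self,
        batch: pa.RecordBatch,
        table_name: str,
        connection_name: str,
        ranges: List['BlockRange'],     # BlockRange comes from amp's streaming types
        ranges_complete: bool = False,  # new in this commit; True on the final chunk of a microbatch
    ) -> int:
        # 1. Skip only a complete, already-processed microbatch.
        # 2. Otherwise always write the chunk.
        # 3. Mark the microbatch processed only when ranges_complete is True.
        raise NotImplementedError('sketch only; see postgresql_loader.py for a real implementation')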

tests/integration/test_postgresql_loader.py

Lines changed: 125 additions & 0 deletions

@@ -692,3 +692,128 @@ def test_reorg_preserves_different_networks(self, postgresql_test_config, test_t

         finally:
             loader.pool.putconn(conn)
+
+    def test_microbatch_deduplication(self, postgresql_test_config, test_table_name, cleanup_tables):
+        """
+        Test that multiple RecordBatches within the same microbatch are all loaded,
+        and deduplication only happens at microbatch boundaries when ranges_complete=True.
+
+        This test verifies the fix for the critical bug where we were marking batches
+        as processed after every RecordBatch instead of waiting for ranges_complete=True.
+        """
+        from src.amp.streaming.types import BatchMetadata, BlockRange, ResponseBatch
+
+        cleanup_tables.append(test_table_name)
+
+        # Enable state management to test deduplication
+        config_with_state = {
+            **postgresql_test_config,
+            'state': {'enabled': True, 'storage': 'memory', 'store_batch_id': True},
+        }
+        loader = PostgreSQLLoader(config_with_state)
+
+        with loader:
+            # Create table first from the schema
+            batch1_data = pa.RecordBatch.from_pydict({'id': [1, 2], 'value': [100, 200]})
+            loader._create_table_from_schema(batch1_data.schema, test_table_name)
+
+            # Simulate a microbatch sent as 3 RecordBatches with the same BlockRange
+            # This happens when the server sends large microbatches in smaller chunks
+
+            # First RecordBatch of the microbatch (ranges_complete=False)
+            response1 = ResponseBatch.data_batch(
+                data=batch1_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],
+                    ranges_complete=False,  # Not the last batch in this microbatch
+                ),
+            )
+
+            # Second RecordBatch of the microbatch (ranges_complete=False)
+            batch2_data = pa.RecordBatch.from_pydict({'id': [3, 4], 'value': [300, 400]})
+            response2 = ResponseBatch.data_batch(
+                data=batch2_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=False,  # Still not the last batch
+                ),
+            )
+
+            # Third RecordBatch of the microbatch (ranges_complete=True)
+            batch3_data = pa.RecordBatch.from_pydict({'id': [5, 6], 'value': [500, 600]})
+            response3 = ResponseBatch.data_batch(
+                data=batch3_data,
+                metadata=BatchMetadata(
+                    ranges=[BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')],  # Same BlockRange!
+                    ranges_complete=True,  # Last batch in this microbatch - safe to mark as processed
+                ),
+            )
+
+            # Process the microbatch stream
+            stream = [response1, response2, response3]
+            results = list(
+                loader.load_stream_continuous(iter(stream), test_table_name, connection_name='test_connection')
+            )
+
+            # CRITICAL: All 3 RecordBatches should be loaded successfully
+            # Before the fix, only the first batch would load (the other 2 would be skipped as "duplicates")
+            assert len(results) == 3, 'All RecordBatches within microbatch should be processed'
+            assert all(r.success for r in results), 'All batches should succeed'
+            assert results[0].rows_loaded == 2, 'First batch should load 2 rows'
+            assert results[1].rows_loaded == 2, 'Second batch should load 2 rows (not skipped!)'
+            assert results[2].rows_loaded == 2, 'Third batch should load 2 rows (not skipped!)'
+
+            # Verify total rows in table (all batches loaded)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    total_count = cur.fetchone()[0]
+                    assert total_count == 6, 'All 6 rows from 3 RecordBatches should be in the table'
+
+                    # Verify the actual IDs are present
+                    cur.execute(f'SELECT id FROM {test_table_name} ORDER BY id')
+                    all_ids = [row[0] for row in cur.fetchall()]
+                    assert all_ids == [1, 2, 3, 4, 5, 6], 'All rows from all RecordBatches should be present'
+
+            finally:
+                loader.pool.putconn(conn)
+
+            # Now test that re-sending the complete microbatch is properly deduplicated
+            # This time, the first batch has ranges_complete=True (entire microbatch in one RecordBatch)
+            duplicate_batch = pa.RecordBatch.from_pydict({'id': [7, 8], 'value': [700, 800]})
+            duplicate_response = ResponseBatch.data_batch(
+                data=duplicate_batch,
+                metadata=BatchMetadata(
+                    ranges=[
+                        BlockRange(network='ethereum', start=100, end=110, hash='0xabc123')
+                    ],  # Same range as before!
+                    ranges_complete=True,  # Complete microbatch
+                ),
+            )
+
+            # Process duplicate microbatch
+            duplicate_results = list(
+                loader.load_stream_continuous(
+                    iter([duplicate_response]), test_table_name, connection_name='test_connection'
+                )
+            )
+
+            # The duplicate microbatch should be skipped (already processed)
+            assert len(duplicate_results) == 1
+            assert duplicate_results[0].success is True
+            assert duplicate_results[0].rows_loaded == 0, 'Duplicate microbatch should be skipped'
+            assert (
+                duplicate_results[0].metadata.get('operation') == 'skip_duplicate'
+            ), 'Should be marked as duplicate'
+
+            # Verify row count unchanged (duplicate was skipped)
+            conn = loader.pool.getconn()
+            try:
+                with conn.cursor() as cur:
+                    cur.execute(f'SELECT COUNT(*) FROM {test_table_name}')
+                    final_count = cur.fetchone()[0]
+                    assert final_count == 6, 'Row count should not increase after duplicate microbatch'
+
+            finally:
+                loader.pool.putconn(conn)
