diff --git a/internal/impl/aws/input_kinesis.go b/internal/impl/aws/input_kinesis.go index aaa21998de..60f203e831 100644 --- a/internal/impl/aws/input_kinesis.go +++ b/internal/impl/aws/input_kinesis.go @@ -37,8 +37,24 @@ const ( kiFieldRebalancePeriod = "rebalance_period" kiFieldStartFromOldest = "start_from_oldest" kiFieldBatching = "batching" + kiFieldEnhancedFanOut = "enhanced_fan_out" + + // Enhanced Fan Out Fields + kiEFOFieldEnabled = "enabled" + kiEFOFieldConsumerName = "consumer_name" + kiEFOFieldConsumerARN = "consumer_arn" + kiEFOFieldRecordBufferCap = "record_buffer_cap" + kiEFOFieldMaxPendingRecordsGlobal = "max_pending_records" ) +type kiEFOConfig struct { + Enabled bool + ConsumerName string + ConsumerARN string + RecordBufferCap int + MaxPendingRecordsGlobal int +} + type kiConfig struct { Streams []string DynamoDB kiddbConfig @@ -47,6 +63,7 @@ type kiConfig struct { LeasePeriod string RebalancePeriod string StartFromOldest bool + EnhancedFanOut *kiEFOConfig } func kinesisInputConfigFromParsed(pConf *service.ParsedConfig) (conf kiConfig, err error) { @@ -73,6 +90,34 @@ func kinesisInputConfigFromParsed(pConf *service.ParsedConfig) (conf kiConfig, e if conf.StartFromOldest, err = pConf.FieldBool(kiFieldStartFromOldest); err != nil { return } + if pConf.Contains(kiFieldEnhancedFanOut) { + efoConf := &kiEFOConfig{} + efoNs := pConf.Namespace(kiFieldEnhancedFanOut) + if efoConf.Enabled, err = efoNs.FieldBool(kiEFOFieldEnabled); err != nil { + return + } + if efoConf.ConsumerName, err = efoNs.FieldString(kiEFOFieldConsumerName); err != nil { + return + } + if efoConf.ConsumerARN, err = efoNs.FieldString(kiEFOFieldConsumerARN); err != nil { + return + } + if efoConf.RecordBufferCap, err = efoNs.FieldInt(kiEFOFieldRecordBufferCap); err != nil { + return + } + if efoConf.RecordBufferCap < 0 { + err = errors.New("enhanced_fan_out.record_buffer_cap must be at least 0") + return + } + if efoConf.MaxPendingRecordsGlobal, err = 
efoNs.FieldInt(kiEFOFieldMaxPendingRecordsGlobal); err != nil { + return + } + if efoConf.MaxPendingRecordsGlobal < 1 { + err = errors.New("enhanced_fan_out.max_pending_records must be at least 1") + return + } + conf.EnhancedFanOut = efoConf + } return } @@ -141,9 +186,36 @@ Use the `+"`batching`"+` fields to configure an optional [batching policy](/docs service.NewBoolField(kiFieldStartFromOldest). Description("Whether to consume from the oldest message when a sequence does not yet exist for the stream."). Default(true), + service.NewObjectField(kiFieldEnhancedFanOut, + service.NewBoolField(kiEFOFieldEnabled). + Description("Enable Enhanced Fan Out mode for push-based streaming with dedicated throughput."). + Default(false), + service.NewStringField(kiEFOFieldConsumerName). + Description("Consumer name for EFO registration. Auto-generated if empty: bento-clientID."). + Default(""), + service.NewStringField(kiEFOFieldConsumerARN). + Description("Existing consumer ARN to use. If provided, skips registration."). + Default(""). + Advanced(), + service.NewIntField(kiEFOFieldRecordBufferCap). + Description("Buffer capacity for the internal records channel per shard. Lower values reduce memory usage when processing many shards. Set to 0 for unbuffered channel (minimal memory footprint)."). + Default(0). + Advanced(), + service.NewIntField(kiEFOFieldMaxPendingRecordsGlobal). + Description("Maximum total number of records to buffer across all shards before applying backpressure to Kinesis subscriptions. This provides a global memory bound regardless of shard count. Higher values improve throughput by allowing shards to continue receiving data while processing, but increase memory usage. Total memory usage is approximately max_pending_records × average_record_size."). + Default(50000). + Advanced(), + ). + Description("Enhanced Fan Out configuration for push-based streaming. Provides dedicated 2 MB/sec throughput per consumer per shard and lower latency (~70ms). 
Note: EFO incurs per shard-hour charges."). + Version("1.16.0"). + Optional(). + Advanced(), ). Fields(config.SessionFields()...). - Field(service.NewBatchPolicyField(kiFieldBatching)) + Field(service.NewBatchPolicyField(kiFieldBatching)). + LintRule(`root = match { +this.` + kiFieldEnhancedFanOut + `.` + kiEFOFieldConsumerName + ` != "" && this.` + kiFieldEnhancedFanOut + `.` + kiEFOFieldConsumerARN + ` != "" => ["cannot specify both ` + kiEFOFieldConsumerName + ` and ` + kiEFOFieldConsumerARN + ` in ` + kiFieldEnhancedFanOut + ` config"] +}`) return spec } @@ -174,6 +246,7 @@ type streamInfo struct { explicitShards []string id string // Either a name or arn, extracted from config and used for balancing shards arn string + efoManager *kinesisEFOManager // Enhanced Fan Out manager (if EFO is enabled) } type kinesisReader struct { @@ -187,8 +260,10 @@ type kinesisReader struct { boffPool sync.Pool - svc *kinesis.Client - checkpointer *awsKinesisCheckpointer + svc *kinesis.Client + checkpointer *awsKinesisCheckpointer + efoEnabled bool + globalPendingPool *globalPendingPool streams []*streamInfo @@ -319,6 +394,14 @@ func newKinesisReaderFromConfig(conf kiConfig, batcher service.BatchPolicy, sess if k.rebalancePeriod, err = time.ParseDuration(k.conf.RebalancePeriod); err != nil { return nil, fmt.Errorf("failed to parse rebalance period string: %v", err) } + + // Check if Enhanced Fan Out is enabled + if k.conf.EnhancedFanOut != nil && k.conf.EnhancedFanOut.Enabled { + k.efoEnabled = true + k.globalPendingPool = newGlobalPendingPool(k.conf.EnhancedFanOut.MaxPendingRecordsGlobal) + k.log.Debugf("Enhanced Fan Out enabled with global pending pool max: %d", k.conf.EnhancedFanOut.MaxPendingRecordsGlobal) + } + return &k, nil } @@ -657,9 +740,9 @@ func (k *kinesisReader) runBalancedShards() { for { for _, info := range k.streams { allShards, err := collectShards(k.ctx, info.arn, k.svc) - var clientClaims map[string][]awsKinesisClientClaim + var checkpointData 
*awsKinesisCheckpointData if err == nil { - clientClaims, err = k.checkpointer.AllClaims(k.ctx, info.id) + checkpointData, err = k.checkpointer.GetCheckpointsAndClaims(k.ctx, info.id) } if err != nil { if k.ctx.Err() != nil { @@ -669,11 +752,18 @@ func (k *kinesisReader) runBalancedShards() { continue } + clientClaims := checkpointData.ClientClaims + shardsWithCheckpoints := checkpointData.ShardsWithCheckpoints + totalShards := len(allShards) unclaimedShards := make(map[string]string, totalShards) for _, s := range allShards { - if !isShardFinished(s) { - unclaimedShards[*s.ShardId] = "" + // Include shard if: + // 1. It's not finished (still open), OR + // 2. It's finished but has a checkpoint (meaning it hasn't been fully consumed yet) + shardID := *s.ShardId + if !isShardFinished(s) || shardsWithCheckpoints[shardID] { + unclaimedShards[shardID] = "" } } for clientID, claims := range clientClaims { @@ -700,7 +790,12 @@ func (k *kinesisReader) runBalancedShards() { continue } wg.Add(1) - if err = k.runConsumer(&wg, *info, shardID, sequence); err != nil { + if k.efoEnabled { + err = k.runEFOConsumer(&wg, *info, shardID, sequence) + } else { + err = k.runConsumer(&wg, *info, shardID, sequence) + } + if err != nil { k.log.Errorf("Failed to start consumer: %v\n", err) } } @@ -749,7 +844,12 @@ func (k *kinesisReader) runBalancedShards() { info.id, randomShard, clientID, k.clientID, ) wg.Add(1) - if err = k.runConsumer(&wg, *info, randomShard, sequence); err != nil { + if k.efoEnabled { + err = k.runEFOConsumer(&wg, *info, randomShard, sequence) + } else { + err = k.runConsumer(&wg, *info, randomShard, sequence) + } + if err != nil { k.log.Errorf("Failed to start consumer: %v\n", err) } else { // If we successfully stole the shard then that's enough @@ -790,7 +890,11 @@ func (k *kinesisReader) runExplicitShards() { sequence, err := k.checkpointer.Claim(k.ctx, id, shardID, "") if err == nil { wg.Add(1) - err = k.runConsumer(&wg, info, shardID, sequence) + if k.efoEnabled 
{ + err = k.runEFOConsumer(&wg, info, shardID, sequence) + } else { + err = k.runConsumer(&wg, info, shardID, sequence) + } } if err != nil { if k.ctx.Err() != nil { @@ -868,6 +972,26 @@ func (k *kinesisReader) Connect(ctx context.Context) error { return err } + // Initialize Enhanced Fan Out if enabled + if k.efoEnabled { + for _, stream := range k.streams { + // Create EFO manager for this stream + efoMgr, err := newKinesisEFOManager(k.conf.EnhancedFanOut, stream.arn, k.clientID, k.svc, k.log) + if err != nil { + return fmt.Errorf("failed to create EFO manager for stream %s: %w", stream.id, err) + } + + // Register consumer and wait for ACTIVE status + consumerARN, err := efoMgr.ensureConsumerRegistered(ctx) + if err != nil { + return fmt.Errorf("failed to register EFO consumer for stream %s: %w", stream.id, err) + } + + stream.efoManager = efoMgr + k.log.Debugf("Enhanced Fan Out consumer registered for stream %s with ARN: %s", stream.id, consumerARN) + } + } + if len(k.streams[0].explicitShards) > 0 { go k.runExplicitShards() } else { diff --git a/internal/impl/aws/input_kinesis_checkpointer.go b/internal/impl/aws/input_kinesis_checkpointer.go index 5dcf78cf2b..a28f719b6b 100644 --- a/internal/impl/aws/input_kinesis_checkpointer.go +++ b/internal/impl/aws/input_kinesis_checkpointer.go @@ -180,54 +180,91 @@ type awsKinesisClientClaim struct { LeaseTimeout time.Time } -// AllClaims returns a map of client IDs to shards claimed by that client, -// including the lease timeout of the claim. 
-func (k *awsKinesisCheckpointer) AllClaims(ctx context.Context, streamID string) (map[string][]awsKinesisClientClaim, error) { - clientClaims := make(map[string][]awsKinesisClientClaim) - var scanErr error - - scanRes, err := k.svc.Scan(ctx, &dynamodb.ScanInput{ - TableName: aws.String(k.conf.Table), - FilterExpression: aws.String("StreamID = :stream_id"), +// awsKinesisCheckpointData contains both the set of all shards with checkpoints +// and the map of client claims, retrieved in a single DynamoDB query. +type awsKinesisCheckpointData struct { + // ShardsWithCheckpoints is a set of all shard IDs that have checkpoint records + ShardsWithCheckpoints map[string]bool + // ClientClaims is a map of client IDs to shards claimed by that client + ClientClaims map[string][]awsKinesisClientClaim +} + +// GetCheckpointsAndClaims retrieves all checkpoint data for a stream. +// +// Returns: +// - ShardsWithCheckpoints: set of all shard IDs that have checkpoint records +// - ClientClaims: map of client IDs to their claimed shards (excludes entries without ClientID) +func (k *awsKinesisCheckpointer) GetCheckpointsAndClaims(ctx context.Context, streamID string) (*awsKinesisCheckpointData, error) { + result := &awsKinesisCheckpointData{ + ShardsWithCheckpoints: make(map[string]bool), + ClientClaims: make(map[string][]awsKinesisClientClaim), + } + + input := &dynamodb.QueryInput{ + TableName: aws.String(k.conf.Table), + KeyConditionExpression: aws.String("StreamID = :stream_id"), ExpressionAttributeValues: map[string]types.AttributeValue{ ":stream_id": &types.AttributeValueMemberS{ Value: streamID, }, }, - }) - if err != nil { - return nil, err } - for _, i := range scanRes.Items { - var clientID string - if s, ok := i["ClientID"].(*types.AttributeValueMemberS); ok { - clientID = s.Value - } else { - continue - } + paginator := dynamodb.NewQueryPaginator(k.svc, input) - var claim awsKinesisClientClaim - if s, ok := i["ShardID"].(*types.AttributeValueMemberS); ok { - claim.ShardID 
= s.Value - } - if claim.ShardID == "" { - return nil, errors.New("failed to extract shard id from claim") + for paginator.HasMorePages() { + page, err := paginator.NextPage(ctx) + if err != nil { + return nil, fmt.Errorf("failed to query checkpoints: %w", err) } - if s, ok := i["LeaseTimeout"].(*types.AttributeValueMemberS); ok { - if claim.LeaseTimeout, scanErr = time.Parse(time.RFC3339Nano, s.Value); scanErr != nil { - return nil, fmt.Errorf("failed to parse claim lease: %w", scanErr) + for _, item := range page.Items { + // Extract ShardID - required for all checkpoint entries + var shardID string + if s, ok := item["ShardID"].(*types.AttributeValueMemberS); ok { + shardID = s.Value + } + if shardID == "" { + // Skip malformed items without a ShardID rather than failing the + // whole query — this is intentionally lenient compared to the + // single-item getCheckpoint path which returns an error, because + // here we are scanning all checkpoints and a single bad row + // should not block progress for the rest. 
+ continue + } + + // Track all shards with checkpoints + result.ShardsWithCheckpoints[shardID] = true + + // Extract client claim if ClientID exists + var clientID string + if s, ok := item["ClientID"].(*types.AttributeValueMemberS); ok { + clientID = s.Value + } + if clientID == "" { + // No client ID means this is an orphaned checkpoint (from final=true) + continue } - } - if claim.LeaseTimeout.IsZero() { - return nil, errors.New("failed to extract lease timeout from claim") - } - clientClaims[clientID] = append(clientClaims[clientID], claim) + // Extract lease timeout for claims + var claim awsKinesisClientClaim + claim.ShardID = shardID + + if s, ok := item["LeaseTimeout"].(*types.AttributeValueMemberS); ok { + var parseErr error + if claim.LeaseTimeout, parseErr = time.Parse(time.RFC3339Nano, s.Value); parseErr != nil { + return nil, fmt.Errorf("failed to parse claim lease for shard %s: %w", shardID, parseErr) + } + } + if claim.LeaseTimeout.IsZero() { + return nil, fmt.Errorf("failed to extract lease timeout from claim for shard %s", shardID) + } + + result.ClientClaims[clientID] = append(result.ClientClaims[clientID], claim) + } } - return clientClaims, scanErr + return result, nil } // Claim attempts to claim a shard for a particular stream ID. If fromClientID diff --git a/internal/impl/aws/input_kinesis_efo.go b/internal/impl/aws/input_kinesis_efo.go new file mode 100644 index 0000000000..a8a232c418 --- /dev/null +++ b/internal/impl/aws/input_kinesis_efo.go @@ -0,0 +1,651 @@ +package aws + +import ( + "context" + "errors" + "fmt" + "sync" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/cenkalti/backoff/v4" + + "github.com/warpstreamlabs/bento/public/service" +) + +// errBackpressureTimeout is returned when WaitForSpace times out due to sustained backpressure. 
+// This is a retryable error that should trigger a backoff before resubscribing. +var errBackpressureTimeout = errors.New("backpressure timeout waiting for space in pending pool") + +// kinesisEFOAPI is the subset of kinesis.Client methods used by kinesisEFOManager. +type kinesisEFOAPI interface { + RegisterStreamConsumer(ctx context.Context, params *kinesis.RegisterStreamConsumerInput, optFns ...func(*kinesis.Options)) (*kinesis.RegisterStreamConsumerOutput, error) + DescribeStreamConsumer(ctx context.Context, params *kinesis.DescribeStreamConsumerInput, optFns ...func(*kinesis.Options)) (*kinesis.DescribeStreamConsumerOutput, error) +} + +// kinesisEFOManager handles Enhanced Fan Out consumer registration and lifecycle +type kinesisEFOManager struct { + streamARN string + consumerName string + consumerARN string + svc kinesisEFOAPI + log *service.Logger + // pollInterval controls how long waitForActiveConsumer waits between status + // checks. Defaults to 2 seconds; overridable in tests for faster iteration. 
+ pollInterval time.Duration +} + +// newKinesisEFOManager creates a new EFO manager +func newKinesisEFOManager(conf *kiEFOConfig, streamARN, clientID string, svc *kinesis.Client, log *service.Logger) (*kinesisEFOManager, error) { + if conf == nil { + return nil, errors.New("enhanced fan out config is nil") + } + + if conf.ConsumerName != "" && conf.ConsumerARN != "" { + return nil, errors.New("cannot specify both consumer_name and consumer_arn") + } + + consumerName := conf.ConsumerName + if consumerName == "" && conf.ConsumerARN == "" { + consumerName = "bento-" + clientID + } + + return &kinesisEFOManager{ + streamARN: streamARN, + consumerName: consumerName, + consumerARN: conf.ConsumerARN, + svc: svc, + log: log, + }, nil +} + +// ensureConsumerRegistered registers the consumer if needed and returns the consumer ARN +func (m *kinesisEFOManager) ensureConsumerRegistered(ctx context.Context) (string, error) { + if m.consumerARN != "" { + m.log.Debugf("Using provided consumer ARN: %s", m.consumerARN) + return m.consumerARN, nil + } + + m.log.Debugf("Registering Enhanced Fan Out consumer: %s for stream: %s", m.consumerName, m.streamARN) + + registerInput := &kinesis.RegisterStreamConsumerInput{ + StreamARN: aws.String(m.streamARN), + ConsumerName: aws.String(m.consumerName), + } + + output, err := m.svc.RegisterStreamConsumer(ctx, registerInput) + if err != nil { + var resourceInUse *types.ResourceInUseException + if errors.As(err, &resourceInUse) { + m.log.Debugf("Consumer %s already exists, describing to get ARN", m.consumerName) + return m.describeAndWaitForActive(ctx) + } + return "", fmt.Errorf("failed to register consumer: %w", err) + } + + if output.Consumer == nil || output.Consumer.ConsumerARN == nil { + return "", errors.New("RegisterStreamConsumer succeeded but returned no consumer ARN") + } + + m.consumerARN = *output.Consumer.ConsumerARN + m.log.Debugf("Registered consumer with ARN: %s, waiting for ACTIVE status", m.consumerARN) + + if err := 
m.waitForActiveConsumer(ctx); err != nil { + return "", fmt.Errorf("failed waiting for consumer to become active: %w", err) + } + + return m.consumerARN, nil +} + +// describeAndWaitForActive describes an existing consumer and waits for it to be active +func (m *kinesisEFOManager) describeAndWaitForActive(ctx context.Context) (string, error) { + describeInput := &kinesis.DescribeStreamConsumerInput{ + StreamARN: aws.String(m.streamARN), + ConsumerName: aws.String(m.consumerName), + } + + output, err := m.svc.DescribeStreamConsumer(ctx, describeInput) + if err != nil { + return "", fmt.Errorf("failed to describe consumer: %w", err) + } + + if output.ConsumerDescription == nil || output.ConsumerDescription.ConsumerARN == nil { + return "", errors.New("consumer description missing ARN") + } + + m.consumerARN = *output.ConsumerDescription.ConsumerARN + m.log.Debugf("Found existing consumer with ARN: %s", m.consumerARN) + + if output.ConsumerDescription.ConsumerStatus == types.ConsumerStatusActive { + m.log.Debugf("Consumer is already ACTIVE") + return m.consumerARN, nil + } + + if err := m.waitForActiveConsumer(ctx); err != nil { + return "", fmt.Errorf("failed waiting for consumer to become active: %w", err) + } + + return m.consumerARN, nil +} + +// waitForActiveConsumer waits for the consumer to reach ACTIVE status +func (m *kinesisEFOManager) waitForActiveConsumer(ctx context.Context) error { + waiterCtx, cancel := context.WithTimeout(ctx, 2*time.Minute) + defer cancel() + + interval := m.pollInterval + if interval == 0 { + interval = 2 * time.Second + } + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + // Check consumer status immediately before waiting for the next tick + describeInput := &kinesis.DescribeStreamConsumerInput{ + ConsumerARN: aws.String(m.consumerARN), + } + + output, err := m.svc.DescribeStreamConsumer(waiterCtx, describeInput) + if err != nil { + return fmt.Errorf("failed to describe consumer: %w", err) + } + + if 
output.ConsumerDescription != nil { + status := output.ConsumerDescription.ConsumerStatus + m.log.Debugf("Consumer status: %s", status) + + if status == types.ConsumerStatusActive { + m.log.Debugf("Consumer is now ACTIVE") + return nil + } + + if status == types.ConsumerStatusDeleting { + return errors.New("consumer is being deleted") + } + } + + select { + case <-ctx.Done(): + return fmt.Errorf("context cancelled waiting for consumer to become ACTIVE: %w", ctx.Err()) + case <-waiterCtx.Done(): + return fmt.Errorf("timeout waiting for consumer to become ACTIVE: %w", waiterCtx.Err()) + case <-ticker.C: + } + } +} + +// runEFOConsumer consumes from a shard using Enhanced Fan Out +func (k *kinesisReader) runEFOConsumer(wg *sync.WaitGroup, info streamInfo, shardID, startingSequence string) error { + // Create record batcher (same as polling mode) + var recordBatcher *awsKinesisRecordBatcher + var err error + if recordBatcher, err = k.newAWSKinesisRecordBatcher(info, shardID, startingSequence); err != nil { + wg.Done() + if _, checkErr := k.checkpointer.Checkpoint(context.Background(), info.id, shardID, startingSequence, true); checkErr != nil { + k.log.Errorf("Failed to gracefully yield checkpoint: %v\n", checkErr) + } + return err + } + + // Track consumer state + state := awsKinesisConsumerConsuming + var pendingMsg asyncMessage + + // Buffer for pending records from the subscription + var pending []types.Record + + // Channels for subscription control + subscriptionTrigger := make(chan string, 1) // Trigger for initial subscription or resubscription + subscriptionTrigger <- startingSequence // Start with initial sequence + + // Channels for timed batches and message flush + var nextTimedBatchChan <-chan time.Time + var nextFlushChan chan<- asyncMessage + var nextRecordsChan <-chan []types.Record + commitCtx, commitCtxClose := context.WithTimeout(k.ctx, k.commitPeriod) + + go func() { + defer func() { + commitCtxClose() + recordBatcher.Close(context.Background(), 
state == awsKinesisConsumerFinished) + + // Release any remaining pending records back to the global pool + if len(pending) > 0 { + k.globalPendingPool.Release(len(pending)) + } + + reason := "" + switch state { + case awsKinesisConsumerFinished: + reason = " because the shard is closed" + if err := k.checkpointer.Delete(k.ctx, info.id, shardID); err != nil { + k.log.Errorf("Failed to remove checkpoint for finished stream '%v' shard '%v': %v", info.id, shardID, err) + } + case awsKinesisConsumerYielding: + reason = " because the shard has been claimed by another client" + if err := k.checkpointer.Yield(k.ctx, info.id, shardID, recordBatcher.GetSequence()); err != nil { + k.log.Errorf("Failed to yield checkpoint for stolen stream '%v' shard '%v': %v", info.id, shardID, err) + } + case awsKinesisConsumerClosing: + reason = " because the pipeline is shutting down" + if _, err := k.checkpointer.Checkpoint(context.Background(), info.id, shardID, recordBatcher.GetSequence(), true); err != nil { + k.log.Errorf("Failed to store final checkpoint for stream '%v' shard '%v': %v", info.id, shardID, err) + } + } + + wg.Done() + k.log.Debugf("Closing stream '%v' shard '%v' as client '%v'%v", info.id, shardID, k.checkpointer.clientID, reason) + }() + + k.log.Debugf("Consuming stream '%v' shard '%v' with Enhanced Fan Out as client '%v'", info.id, shardID, k.checkpointer.clientID) + + // Start subscription in a separate goroutine + bufferCap := 0 + if k.conf.EnhancedFanOut != nil { + bufferCap = k.conf.EnhancedFanOut.RecordBufferCap + } + recordsChan := make(chan []types.Record, bufferCap) + // errorsChan is used for logging/monitoring only - subscription goroutine + // handles its own retries. Non-blocking sends handle overflow gracefully. 
+ errorsChan := make(chan error, 1) + resubscribeChan := make(chan string, 1) + shardFinishedChan := make(chan struct{}, 1) + + // drainRecordsChan drains any remaining records from recordsChan after + // the subscription goroutine has stopped, releasing their pool capacity. + // This prevents leaking pool capacity when the consumer exits with buffered records. + drainRecordsChan := func() { + for { + select { + case records := <-recordsChan: + k.globalPendingPool.Release(len(records)) + default: + return + } + } + } + + var subscriptionWg sync.WaitGroup + subscriptionWg.Go(func() { + // Subscription goroutine manages its own backoff for retries + subBoff := backoff.NewExponentialBackOff() + subBoff.InitialInterval = 300 * time.Millisecond + subBoff.MaxInterval = 5 * time.Second + subBoff.MaxElapsedTime = 0 // Never stop retrying + + for sequence := range subscriptionTrigger { + currentSeq := sequence + + // Inner retry loop - keeps trying until success or context cancellation + for { + select { + case <-k.ctx.Done(): + return + default: + } + + continuationSeq, shardFinished, err := k.efoSubscribeAndStream(k.ctx, info, shardID, currentSeq, recordsChan) + + if err != nil { + // Check for non-retryable errors - these should stop the subscription + var resourceNotFound *types.ResourceNotFoundException + var invalidArg *types.InvalidArgumentException + if errors.As(err, &resourceNotFound) || errors.As(err, &invalidArg) { + // Send to errorsChan for main loop to handle shutdown. + // Use blocking send (with context) to ensure fatal errors are not dropped. 
+ k.log.Errorf("Non-retryable EFO error for shard %v: %v", shardID, err) + select { + case <-k.ctx.Done(): + case errorsChan <- err: + } + return // Stop retrying + } + + // Log retryable error (non-blocking) + select { + case errorsChan <- err: + default: + // Channel full, just log locally + if errors.Is(err, errBackpressureTimeout) { + k.log.Debugf("EFO backpressure timeout for shard %v, will retry", shardID) + } else { + k.log.Warnf("EFO subscription error for shard %v, will retry: %v", shardID, err) + } + } + + // Update sequence for retry if we got a continuation + if continuationSeq != "" { + currentSeq = continuationSeq + } + + // Backoff before retry + backoffDuration := subBoff.NextBackOff() + select { + case <-k.ctx.Done(): + return + case <-time.After(backoffDuration): + } + continue // Retry the subscription + } + + // Success - reset backoff + subBoff.Reset() + + if shardFinished { + // Shard is closed, signal to main loop + select { + case shardFinishedChan <- struct{}{}: + default: + } + return + } + + // Subscription completed normally, update sequence and notify main loop + if continuationSeq != "" { + currentSeq = continuationSeq + } + select { + case <-k.ctx.Done(): + return + case resubscribeChan <- currentSeq: + } + break // Exit retry loop, wait for next trigger from main loop + } + } + }) + + // Main consumer loop (similar to polling consumer) + for { + if pendingMsg.msg == nil { + // If our consumer is finished and we've run out of pending + // records then we're done. 
+ if len(pending) == 0 && state == awsKinesisConsumerFinished { + if pendingMsg, _ = recordBatcher.FlushMessage(k.ctx); pendingMsg.msg == nil { + close(subscriptionTrigger) + subscriptionWg.Wait() + drainRecordsChan() + return + } + } else if recordBatcher.HasPendingMessage() { + var err error + if pendingMsg, err = recordBatcher.FlushMessage(commitCtx); err != nil { + k.log.Errorf("Failed to dispatch message due to checkpoint error: %v\n", err) + } + } else if len(pending) > 0 { + var i int + var r types.Record + for i, r = range pending { + if recordBatcher.AddRecord(r) { + var err error + if pendingMsg, err = recordBatcher.FlushMessage(commitCtx); err != nil { + k.log.Errorf("Failed to dispatch message due to checkpoint error: %v\n", err) + } + break + } + } + // Release processed records back to the global pool + processedCount := i + 1 + k.globalPendingPool.Release(processedCount) + pending = pending[processedCount:] + } + } + + // Decide whether to flush + if pendingMsg.msg != nil { + nextFlushChan = k.msgChan + } else { + nextFlushChan = nil + } + + // Always listen for records - backpressure is applied in efoSubscribeAndStream + // via globalPendingPool.Acquire() before sending to recordsChan + nextRecordsChan = recordsChan + + if nextTimedBatchChan == nil { + if tNext, exists := recordBatcher.UntilNext(); exists { + nextTimedBatchChan = time.After(tNext) + } + } + + select { + case <-commitCtx.Done(): + if k.ctx.Err() != nil { + state = awsKinesisConsumerClosing + close(subscriptionTrigger) + subscriptionWg.Wait() + drainRecordsChan() + return + } + + commitCtxClose() + commitCtx, commitCtxClose = context.WithTimeout(k.ctx, k.commitPeriod) + + if state == awsKinesisConsumerConsuming { + stillOwned, err := k.checkpointer.Checkpoint(k.ctx, info.id, shardID, recordBatcher.GetSequence(), false) + if err != nil { + k.log.Errorf("Failed to store checkpoint for Kinesis stream '%v' shard '%v': %v", info.id, shardID, err) + } else if !stillOwned { + state = 
awsKinesisConsumerYielding + close(subscriptionTrigger) + subscriptionWg.Wait() + drainRecordsChan() + return + } + } + + case <-nextTimedBatchChan: + nextTimedBatchChan = nil + + case nextFlushChan <- pendingMsg: + pendingMsg = asyncMessage{} + + case records := <-nextRecordsChan: + // Received records from subscription + // Space was already acquired in efoSubscribeAndStream before sending + pending = append(pending, records...) + + case err := <-errorsChan: + // Subscription error received - log it. + // The subscription goroutine handles its own retry logic with backoff, + // so we don't need to trigger resubscription from here. + var resourceNotFound *types.ResourceNotFoundException + var invalidArg *types.InvalidArgumentException + + if errors.As(err, &resourceNotFound) || errors.As(err, &invalidArg) { + // Non-retryable errors are still fatal + k.log.Errorf("Non-retryable EFO error for shard %v: %v", shardID, err) + state = awsKinesisConsumerClosing + close(subscriptionTrigger) + subscriptionWg.Wait() + drainRecordsChan() + return + } + + // Log retryable errors (subscription goroutine handles retry) + if errors.Is(err, errBackpressureTimeout) { + k.log.Debugf("EFO backpressure timeout for shard %v, subscription goroutine will retry", shardID) + } else { + k.log.Warnf("EFO subscription error for shard %v, subscription goroutine will retry: %v", shardID, err) + } + + case sequence := <-resubscribeChan: + // Subscription completed successfully, resubscribe immediately to maintain continuous data flow + select { + case subscriptionTrigger <- sequence: + case <-k.ctx.Done(): + } + + case <-shardFinishedChan: + // Shard is closed, mark as finished so we can drain pending records + state = awsKinesisConsumerFinished + + case <-k.ctx.Done(): + state = awsKinesisConsumerClosing + close(subscriptionTrigger) + subscriptionWg.Wait() + drainRecordsChan() + return + } + } + }() + + return nil +} + +// efoSubscribeAndStream subscribes to a shard and streams records to a 
channel +// Returns: continuationSequence, shardFinished, error +func (k *kinesisReader) efoSubscribeAndStream(ctx context.Context, info streamInfo, shardID, startingSequence string, recordsChan chan<- []types.Record) (string, bool, error) { + if info.efoManager == nil || info.efoManager.consumerARN == "" { + return "", false, errors.New("EFO manager or consumer ARN not initialized") + } + + // Build starting position + var startingPosition *types.StartingPosition + if startingSequence == "" { + // No sequence yet, use TRIM_HORIZON or LATEST based on config + if k.conf.StartFromOldest { + startingPosition = &types.StartingPosition{ + Type: types.ShardIteratorTypeTrimHorizon, + } + } else { + startingPosition = &types.StartingPosition{ + Type: types.ShardIteratorTypeLatest, + } + } + } else { + // Continue from last sequence + startingPosition = &types.StartingPosition{ + Type: types.ShardIteratorTypeAfterSequenceNumber, + SequenceNumber: aws.String(startingSequence), + } + } + + k.log.Debugf("Subscribing to shard %v with sequence %v", shardID, startingSequence) + + input := &kinesis.SubscribeToShardInput{ + ConsumerARN: aws.String(info.efoManager.consumerARN), + ShardId: aws.String(shardID), + StartingPosition: startingPosition, + } + + output, err := k.svc.SubscribeToShard(ctx, input) + if err != nil { + return "", false, fmt.Errorf("failed to subscribe to shard: %w", err) + } + + // Process the event stream + eventStream := output.GetStream() + defer eventStream.Close() + + continuationSeq := "" + lastReceivedSeq := "" + shardFinished := false + eventsChan := eventStream.Events() + + // Timeout for waiting on backpressure - if we wait too long, close the subscription + // cleanly and resubscribe rather than letting AWS forcibly terminate the connection. + // 30 seconds is well under the 5-minute EFO subscription timeout. 
+ const backpressureTimeout = 30 * time.Second + + for { + // Wait for space in the global pool before fetching the next event + // This applies backpressure to Kinesis before data enters memory + switch k.globalPendingPool.WaitForSpace(ctx, backpressureTimeout) { + case WaitForSpaceCancelled: + // Context cancelled + if continuationSeq == "" { + continuationSeq = lastReceivedSeq + } + return continuationSeq, false, ctx.Err() + case WaitForSpaceTimeout: + // Backpressure timeout - close subscription cleanly and return error to trigger backoff + // This prevents AWS from forcibly terminating the connection after extended inactivity + // and ensures we don't immediately resubscribe while backpressure persists + k.log.Debugf("Backpressure timeout for shard %v, closing subscription to resubscribe with backoff", shardID) + if continuationSeq == "" { + continuationSeq = lastReceivedSeq + } + return continuationSeq, false, errBackpressureTimeout + case WaitForSpaceOK: + // Space available, continue + } + + // Now fetch the next event + event, ok := <-eventsChan + if !ok { + break + } + + switch e := event.(type) { + case *types.SubscribeToShardEventStreamMemberSubscribeToShardEvent: + // Got records event + shardEvent := e.Value + + // Send records to channel and track last received sequence + if len(shardEvent.Records) > 0 { + // Acquire the actual space for this batch + if !k.globalPendingPool.Acquire(ctx, len(shardEvent.Records)) { + // Context cancelled, return with current sequence + if continuationSeq == "" { + continuationSeq = lastReceivedSeq + } + return continuationSeq, false, ctx.Err() + } + + // Track the last record's sequence number for fallback + lastRecord := shardEvent.Records[len(shardEvent.Records)-1] + if lastRecord.SequenceNumber != nil { + lastReceivedSeq = *lastRecord.SequenceNumber + } + select { + case recordsChan <- shardEvent.Records: + case <-ctx.Done(): + // Release the acquired space since we couldn't send + 
k.globalPendingPool.Release(len(shardEvent.Records)) + // Use lastReceivedSeq as fallback if continuationSeq not set + if continuationSeq == "" { + continuationSeq = lastReceivedSeq + } + return continuationSeq, false, ctx.Err() + } + } + + // Update continuation sequence for next subscription + if shardEvent.ContinuationSequenceNumber != nil { + continuationSeq = *shardEvent.ContinuationSequenceNumber + } + + // Check if shard is closed (has child shards) + if len(shardEvent.ChildShards) > 0 { + k.log.Debugf("Shard %v is closed, child shards: %v", shardID, len(shardEvent.ChildShards)) + shardFinished = true + } + + if shardEvent.MillisBehindLatest != nil { + k.log.Debugf("Shard %v is %d milliseconds behind latest", shardID, *shardEvent.MillisBehindLatest) + } + + default: + k.log.Warnf("Unknown event type received: %T", event) + } + } + + // Use lastReceivedSeq as fallback if continuationSeq not set + if continuationSeq == "" { + continuationSeq = lastReceivedSeq + } + + // Check for stream errors + if err := eventStream.Err(); err != nil { + // Check if it's just end of stream + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return continuationSeq, shardFinished, nil + } + return continuationSeq, shardFinished, fmt.Errorf("error receiving event: %w", err) + } + + return continuationSeq, shardFinished, nil +} diff --git a/internal/impl/aws/input_kinesis_efo_test.go b/internal/impl/aws/input_kinesis_efo_test.go new file mode 100644 index 0000000000..6d5ba72d65 --- /dev/null +++ b/internal/impl/aws/input_kinesis_efo_test.go @@ -0,0 +1,210 @@ +package aws + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kinesis" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/warpstreamlabs/bento/public/service" +) 
+ +// mockEFOClient implements kinesisEFOAPI for testing. +type mockEFOClient struct { + describeResponses []mockDescribeResponse + describeCallCount int + // repeatLast causes the last response to be repeated indefinitely once all + // responses are exhausted. Useful for testing context cancellation. + repeatLast bool +} + +type mockDescribeResponse struct { + status types.ConsumerStatus + err error +} + +func (m *mockEFOClient) DescribeStreamConsumer(_ context.Context, _ *kinesis.DescribeStreamConsumerInput, _ ...func(*kinesis.Options)) (*kinesis.DescribeStreamConsumerOutput, error) { + idx := m.describeCallCount + m.describeCallCount++ + if idx >= len(m.describeResponses) { + if m.repeatLast && len(m.describeResponses) > 0 { + // Repeat the last response indefinitely (useful for context-cancel tests) + idx = len(m.describeResponses) - 1 + } else { + // Default: return ACTIVE if we run out of responses + return &kinesis.DescribeStreamConsumerOutput{ + ConsumerDescription: &types.ConsumerDescription{ + ConsumerStatus: types.ConsumerStatusActive, + }, + }, nil + } + } + resp := m.describeResponses[idx] + if resp.err != nil { + return nil, resp.err + } + return &kinesis.DescribeStreamConsumerOutput{ + ConsumerDescription: &types.ConsumerDescription{ + ConsumerStatus: resp.status, + }, + }, nil +} + +func (m *mockEFOClient) RegisterStreamConsumer(_ context.Context, _ *kinesis.RegisterStreamConsumerInput, _ ...func(*kinesis.Options)) (*kinesis.RegisterStreamConsumerOutput, error) { + return &kinesis.RegisterStreamConsumerOutput{ + Consumer: &types.Consumer{ + ConsumerARN: aws.String("arn:aws:kinesis:us-east-1:123:stream/test/consumer/bento-test"), + }, + }, nil +} + +func newTestEFOManager(client kinesisEFOAPI) *kinesisEFOManager { + return &kinesisEFOManager{ + streamARN: "arn:aws:kinesis:us-east-1:123:stream/test", + consumerARN: "arn:aws:kinesis:us-east-1:123:stream/test/consumer/bento-test", + svc: client, + log: service.MockResources().Logger(), + pollInterval: 5 
* time.Millisecond, // fast polling for unit tests + } +} + +// TestWaitForActiveConsumer_AlreadyActive checks that if the consumer is already +// ACTIVE on the first describe call, the function returns immediately without +// waiting for a tick. +func TestWaitForActiveConsumer_AlreadyActive(t *testing.T) { + client := &mockEFOClient{ + describeResponses: []mockDescribeResponse{ + {status: types.ConsumerStatusActive}, + }, + } + mgr := newTestEFOManager(client) + + start := time.Now() + err := mgr.waitForActiveConsumer(context.Background()) + elapsed := time.Since(start) + + require.NoError(t, err) + assert.Equal(t, 1, client.describeCallCount, "should describe exactly once") + // Should return well under the poll interval (5ms in tests, 2s in production) + assert.Less(t, elapsed, 50*time.Millisecond, "should return immediately, not wait for tick") +} + +// TestWaitForActiveConsumer_TransitionsToActive checks that the function polls +// until the consumer becomes ACTIVE after a few CREATING responses. +func TestWaitForActiveConsumer_TransitionsToActive(t *testing.T) { + client := &mockEFOClient{ + describeResponses: []mockDescribeResponse{ + {status: types.ConsumerStatusCreating}, + {status: types.ConsumerStatusCreating}, + {status: types.ConsumerStatusActive}, + }, + } + mgr := newTestEFOManager(client) + + err := mgr.waitForActiveConsumer(context.Background()) + + require.NoError(t, err) + assert.Equal(t, 3, client.describeCallCount, "should describe until ACTIVE") +} + +// TestWaitForActiveConsumer_DeletingState checks that a DELETING status +// returns an error immediately. 
+func TestWaitForActiveConsumer_DeletingState(t *testing.T) {
+	client := &mockEFOClient{
+		describeResponses: []mockDescribeResponse{
+			{status: types.ConsumerStatusDeleting},
+		},
+	}
+	mgr := newTestEFOManager(client)
+
+	err := mgr.waitForActiveConsumer(context.Background())
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "consumer is being deleted")
+}
+
+// TestWaitForActiveConsumer_ParentContextCancelled checks that cancelling the
+// parent context returns a "context cancelled" error (not a "timeout" error).
+func TestWaitForActiveConsumer_ParentContextCancelled(t *testing.T) {
+	// Stay in CREATING indefinitely so the function has to wait
+	client := &mockEFOClient{
+		describeResponses: []mockDescribeResponse{
+			{status: types.ConsumerStatusCreating},
+		},
+		repeatLast: true, // keep returning CREATING so we never accidentally succeed
+	}
+	mgr := newTestEFOManager(client)
+
+	ctx, cancel := context.WithCancel(context.Background())
+	// Cancel the context after a short delay so the loop reaches the select
+	go func() {
+		time.Sleep(50 * time.Millisecond)
+		cancel()
+	}()
+
+	err := mgr.waitForActiveConsumer(ctx)
+
+	require.Error(t, err)
+	assert.Contains(t, err.Error(), "context cancelled", "should report cancellation, not timeout")
+	assert.False(t, errors.Is(err, context.DeadlineExceeded), "should not be a deadline error")
+	assert.True(t, errors.Is(err, context.Canceled), "should wrap context.Canceled")
+}
+
+// TestWaitForActiveConsumer_InternalTimeout checks that when the internal
+// 2-minute waiterCtx expires it produces a "timeout" error, distinct from a
+// parent-context cancellation. We use a very short timeout so the test
+// doesn't actually wait 2 minutes.
+func TestWaitForActiveConsumer_InternalTimeout(t *testing.T) { + // Override the ticker to a short interval and ensure status never becomes ACTIVE + client := &mockEFOClient{ + // Many CREATING responses so the loop keeps going + describeResponses: func() []mockDescribeResponse { + resps := make([]mockDescribeResponse, 100) + for i := range resps { + resps[i] = mockDescribeResponse{status: types.ConsumerStatusCreating} + } + return resps + }(), + } + mgr := newTestEFOManager(client) + + // Use a context that times out very quickly to simulate the internal waiterCtx + // expiring before the parent ctx is cancelled. + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + err := mgr.waitForActiveConsumer(ctx) + + require.Error(t, err) + // The error could be either "context cancelled" (from ctx.Done) or "timeout" + // (from waiterCtx.Done) depending on which fires first, but either way it + // must NOT succeed and must contain a meaningful message. + assert.True(t, + errors.Is(err, context.DeadlineExceeded) || errors.Is(err, context.Canceled), + "error should wrap a context error, got: %v", err, + ) +} + +// TestWaitForActiveConsumer_DescribeError checks that an error from +// DescribeStreamConsumer is propagated immediately. 
+func TestWaitForActiveConsumer_DescribeError(t *testing.T) { + describeErr := errors.New("network error") + client := &mockEFOClient{ + describeResponses: []mockDescribeResponse{ + {err: describeErr}, + }, + } + mgr := newTestEFOManager(client) + + err := mgr.waitForActiveConsumer(context.Background()) + + require.Error(t, err) + assert.ErrorContains(t, err, "failed to describe consumer") + assert.ErrorIs(t, err, describeErr) +} diff --git a/internal/impl/aws/input_kinesis_pending_pool.go b/internal/impl/aws/input_kinesis_pending_pool.go new file mode 100644 index 0000000000..4689cadc4a --- /dev/null +++ b/internal/impl/aws/input_kinesis_pending_pool.go @@ -0,0 +1,160 @@ +package aws + +import ( + "context" + "sync" + "time" +) + +// globalPendingPool limits the total number of pending records across all shards. +// Each shard must acquire space from this pool before accepting records from Kinesis, +// ensuring bounded memory usage regardless of shard count. +type globalPendingPool struct { + mu sync.Mutex + cond *sync.Cond + current int + max int +} + +// newGlobalPendingPool creates a new pool with the specified maximum capacity. +func newGlobalPendingPool(maximum int) *globalPendingPool { + p := &globalPendingPool{ + max: maximum, + } + p.cond = sync.NewCond(&p.mu) + return p +} + +// Acquire acquires space for count records, blocking if necessary until space is available. +// Returns false immediately if count > max (impossible to satisfy) or if ctx is cancelled. +func (p *globalPendingPool) Acquire(ctx context.Context, count int) bool { + p.mu.Lock() + defer p.mu.Unlock() + + // If the requested count exceeds the pool's maximum capacity, it can never be satisfied. + // Return false immediately to avoid blocking indefinitely. 
+ if count > p.max { + return false + } + + // Start a goroutine to handle context cancellation by broadcasting to wake up waiters + done := make(chan struct{}) + defer close(done) + go func() { + select { + case <-ctx.Done(): + p.cond.Broadcast() + case <-done: + } + }() + + for p.current+count > p.max { + // Check if context is cancelled before waiting + if ctx.Err() != nil { + return false + } + p.cond.Wait() + } + + // Final check after waking up + if ctx.Err() != nil { + return false + } + + p.current += count + return true +} + +// WaitForSpaceResult indicates the outcome of WaitForSpace. +type WaitForSpaceResult int + +const ( + // WaitForSpaceOK indicates space is available. + WaitForSpaceOK WaitForSpaceResult = iota + // WaitForSpaceCancelled indicates the context was cancelled. + WaitForSpaceCancelled + // WaitForSpaceTimeout indicates the timeout was reached while waiting. + WaitForSpaceTimeout +) + +// WaitForSpace blocks until there is any space available in the pool. +// This is used to apply backpressure before fetching new data from Kinesis. +// Returns WaitForSpaceOK if space is available, WaitForSpaceCancelled if +// context is cancelled, or WaitForSpaceTimeout if the timeout is reached. +// A timeout of 0 means no timeout (wait indefinitely). 
+func (p *globalPendingPool) WaitForSpace(ctx context.Context, timeout time.Duration) WaitForSpaceResult { + p.mu.Lock() + defer p.mu.Unlock() + + // Set up deadline if timeout is specified + var deadline time.Time + var timer *time.Timer + if timeout > 0 { + deadline = time.Now().Add(timeout) + timer = time.NewTimer(timeout) + defer timer.Stop() + } + + // Start a goroutine to handle context cancellation and timeout by broadcasting + done := make(chan struct{}) + defer close(done) + go func() { + if timer != nil { + select { + case <-ctx.Done(): + p.cond.Broadcast() + case <-timer.C: + p.cond.Broadcast() + case <-done: + } + } else { + select { + case <-ctx.Done(): + p.cond.Broadcast() + case <-done: + } + } + }() + + for p.current >= p.max { + // Check if context is cancelled before waiting + if ctx.Err() != nil { + return WaitForSpaceCancelled + } + + // Check if we've exceeded the timeout + if timeout > 0 && time.Now().After(deadline) { + return WaitForSpaceTimeout + } + + p.cond.Wait() + } + + // Final checks after waking up + if ctx.Err() != nil { + return WaitForSpaceCancelled + } + if timeout > 0 && time.Now().After(deadline) { + return WaitForSpaceTimeout + } + + return WaitForSpaceOK +} + +// Release returns count records worth of space back to the pool. +func (p *globalPendingPool) Release(count int) { + p.mu.Lock() + p.current -= count + if p.current < 0 { + p.current = 0 + } + p.cond.Broadcast() + p.mu.Unlock() +} + +// Current returns the current number of records in the pool (for monitoring/debugging). 
+func (p *globalPendingPool) Current() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.current +} diff --git a/internal/impl/aws/input_kinesis_pending_pool_test.go b/internal/impl/aws/input_kinesis_pending_pool_test.go new file mode 100644 index 0000000000..4ef84beacf --- /dev/null +++ b/internal/impl/aws/input_kinesis_pending_pool_test.go @@ -0,0 +1,246 @@ +package aws + +import ( + "context" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGlobalPendingPool_AcquireRelease(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Acquire some space + assert.True(t, pool.Acquire(context.Background(), 50)) + assert.Equal(t, 50, pool.Current()) + + // Acquire more space + assert.True(t, pool.Acquire(context.Background(), 30)) + assert.Equal(t, 80, pool.Current()) + + // Release some space + pool.Release(20) + assert.Equal(t, 60, pool.Current()) + + // Release all + pool.Release(60) + assert.Equal(t, 0, pool.Current()) +} + +func TestGlobalPendingPool_AcquireExceedsMax(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Trying to acquire more than max should fail immediately + assert.False(t, pool.Acquire(context.Background(), 150)) + assert.Equal(t, 0, pool.Current()) +} + +func TestGlobalPendingPool_AcquireBlocksUntilSpace(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Fill the pool + assert.True(t, pool.Acquire(context.Background(), 100)) + + // Start a goroutine that will try to acquire more + acquired := make(chan bool, 1) + go func() { + acquired <- pool.Acquire(context.Background(), 50) + }() + + // Give the goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Release some space + pool.Release(50) + + // The acquire should succeed now + select { + case result := <-acquired: + assert.True(t, result) + case <-time.After(time.Second): + t.Fatal("Acquire should have succeeded after Release") + } + + assert.Equal(t, 100, 
pool.Current()) +} + +func TestGlobalPendingPool_AcquireContextCancellation(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Fill the pool + assert.True(t, pool.Acquire(context.Background(), 100)) + + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) + + // Start a goroutine that will try to acquire more + acquired := make(chan bool, 1) + go func() { + acquired <- pool.Acquire(ctx, 50) + }() + + // Give the goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Cancel the context + cancel() + + // The acquire should fail due to cancellation + select { + case result := <-acquired: + assert.False(t, result) + case <-time.After(time.Second): + t.Fatal("Acquire should have returned after context cancellation") + } + + // Pool should still have the original 100 + assert.Equal(t, 100, pool.Current()) +} + +func TestGlobalPendingPool_WaitForSpace(t *testing.T) { + pool := newGlobalPendingPool(100) + + // With space available, should return immediately + result := pool.WaitForSpace(context.Background(), 0) + assert.Equal(t, WaitForSpaceOK, result) + + // Fill the pool + assert.True(t, pool.Acquire(context.Background(), 100)) + + // Start a goroutine that will wait for space + waitResult := make(chan WaitForSpaceResult, 1) + go func() { + waitResult <- pool.WaitForSpace(context.Background(), 0) + }() + + // Give the goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Release some space + pool.Release(10) + + // The wait should succeed now + select { + case result := <-waitResult: + assert.Equal(t, WaitForSpaceOK, result) + case <-time.After(time.Second): + t.Fatal("WaitForSpace should have returned after Release") + } +} + +func TestGlobalPendingPool_WaitForSpaceTimeout(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Fill the pool + assert.True(t, pool.Acquire(context.Background(), 100)) + + // Wait with a short timeout + start := time.Now() + result := 
pool.WaitForSpace(context.Background(), 100*time.Millisecond) + elapsed := time.Since(start) + + assert.Equal(t, WaitForSpaceTimeout, result) + assert.GreaterOrEqual(t, elapsed, 100*time.Millisecond) + assert.Less(t, elapsed, 200*time.Millisecond) // Should not take too long +} + +func TestGlobalPendingPool_WaitForSpaceContextCancellation(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Fill the pool + assert.True(t, pool.Acquire(context.Background(), 100)) + + // Create a cancellable context + ctx, cancel := context.WithCancel(context.Background()) + + // Start a goroutine that will wait for space + waitResult := make(chan WaitForSpaceResult, 1) + go func() { + waitResult <- pool.WaitForSpace(ctx, 0) + }() + + // Give the goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Cancel the context + cancel() + + // The wait should return cancelled + select { + case result := <-waitResult: + assert.Equal(t, WaitForSpaceCancelled, result) + case <-time.After(time.Second): + t.Fatal("WaitForSpace should have returned after context cancellation") + } +} + +func TestGlobalPendingPool_ReleaseBelowZero(t *testing.T) { + pool := newGlobalPendingPool(100) + + // Release more than current (should clamp to 0) + pool.Release(50) + assert.Equal(t, 0, pool.Current()) + + // Acquire and release more than acquired + assert.True(t, pool.Acquire(context.Background(), 30)) + pool.Release(50) + assert.Equal(t, 0, pool.Current()) +} + +func TestGlobalPendingPool_ConcurrentAccess(t *testing.T) { + pool := newGlobalPendingPool(1000) + + var wg sync.WaitGroup + acquireCount := 100 + acquireSize := 10 + + // Start many goroutines acquiring and releasing + for range acquireCount { + wg.Go(func() { + require.True(t, pool.Acquire(context.Background(), acquireSize)) + time.Sleep(10 * time.Millisecond) + pool.Release(acquireSize) + }) + } + + wg.Wait() + assert.Equal(t, 0, pool.Current()) +} + +func TestGlobalPendingPool_WaitForSpaceVsAcquire(t *testing.T) { + // 
Test that WaitForSpace returns as soon as current < max, + // but Acquire might still need to wait if current + count > max + pool := newGlobalPendingPool(100) + + // Fill to 90 + assert.True(t, pool.Acquire(context.Background(), 90)) + + // WaitForSpace should return OK (there's space) + result := pool.WaitForSpace(context.Background(), 100*time.Millisecond) + assert.Equal(t, WaitForSpaceOK, result) + + // But Acquire of 20 should block until more space is released + acquired := make(chan bool, 1) + go func() { + acquired <- pool.Acquire(context.Background(), 20) + }() + + // Give the goroutine time to start waiting + time.Sleep(50 * time.Millisecond) + + // Release 10 more to make room + pool.Release(10) + + // Now Acquire should succeed + select { + case result := <-acquired: + assert.True(t, result) + case <-time.After(time.Second): + t.Fatal("Acquire should have succeeded after Release") + } + + assert.Equal(t, 100, pool.Current()) +} diff --git a/internal/impl/aws/input_kinesis_test.go b/internal/impl/aws/input_kinesis_test.go index 4f80a7c9e3..a44c8cc283 100644 --- a/internal/impl/aws/input_kinesis_test.go +++ b/internal/impl/aws/input_kinesis_test.go @@ -3,6 +3,8 @@ package aws import ( "testing" + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/kinesis/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -64,3 +66,59 @@ func TestStreamIDParser(t *testing.T) { }) } } + +func TestIsShardFinished(t *testing.T) { + tests := []struct { + name string + shard types.Shard + expected bool + }{ + { + name: "open shard - no ending sequence", + shard: types.Shard{ + ShardId: aws.String("shardId-000000000001"), + SequenceNumberRange: &types.SequenceNumberRange{ + StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"), + }, + }, + expected: false, + }, + { + name: "closed shard - has ending sequence", + shard: types.Shard{ + ShardId: 
aws.String("shardId-000000000001"), + SequenceNumberRange: &types.SequenceNumberRange{ + StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"), + EndingSequenceNumber: aws.String("49671246667589458717803282320587893555896035326582658"), + }, + }, + expected: true, + }, + { + name: "closed shard - ending sequence is null string", + shard: types.Shard{ + ShardId: aws.String("shardId-000000000001"), + SequenceNumberRange: &types.SequenceNumberRange{ + StartingSequenceNumber: aws.String("49671246667567228643283430150187087032206582658"), + EndingSequenceNumber: aws.String("null"), + }, + }, + expected: false, + }, + { + name: "shard with no sequence number range", + shard: types.Shard{ + ShardId: aws.String("shardId-000000000001"), + }, + expected: false, + }, + } + + for _, test := range tests { + test := test + t.Run(test.name, func(t *testing.T) { + result := isShardFinished(test.shard) + assert.Equal(t, test.expected, result) + }) + } +} diff --git a/internal/impl/aws/integration_kinesis_test.go b/internal/impl/aws/integration_kinesis_test.go index 89cec253ad..a8d382fabe 100644 --- a/internal/impl/aws/integration_kinesis_test.go +++ b/internal/impl/aws/integration_kinesis_test.go @@ -69,7 +69,13 @@ func createKinesisShards(ctx context.Context, t testing.TB, awsPort, id string, return shards, nil } -func kinesisIntegrationSuite(t *testing.T, lsPort string) { +func kinesisIntegrationSuite(t *testing.T, lsPort string, efoEnabled bool) { + var efoYaml string + if efoEnabled { + efoYaml = `enhanced_fan_out: + enabled: true` + } + template := ` output: aws_kinesis: @@ -95,6 +101,7 @@ input: create: true start_from_oldest: true region: us-east-1 + ` + efoYaml + ` credentials: id: xxxxx secret: xxxxx diff --git a/internal/impl/aws/integration_test.go b/internal/impl/aws/integration_test.go index 3a36ceaa1f..5a6b8e7d03 100644 --- a/internal/impl/aws/integration_test.go +++ b/internal/impl/aws/integration_test.go @@ -15,7 +15,11 @@ func 
TestIntegration(t *testing.T) { servicePort := GetLocalStack(t, nil) t.Run("kinesis", func(t *testing.T) { - kinesisIntegrationSuite(t, servicePort) + kinesisIntegrationSuite(t, servicePort, false) + }) + + t.Run("kinesis_efo", func(t *testing.T) { + kinesisIntegrationSuite(t, servicePort, true) }) t.Run("s3", func(t *testing.T) { diff --git a/website/docs/components/inputs/aws_kinesis.md b/website/docs/components/inputs/aws_kinesis.md index 7ba76182d3..5a6f64d57e 100644 --- a/website/docs/components/inputs/aws_kinesis.md +++ b/website/docs/components/inputs/aws_kinesis.md @@ -69,6 +69,12 @@ input: rebalance_period: 30s lease_period: 30s start_from_oldest: true + enhanced_fan_out: + enabled: false + consumer_name: "" + consumer_arn: "" + record_buffer_cap: 0 + max_pending_records: 50000 region: "" endpoint: "" credentials: @@ -222,6 +228,53 @@ Whether to consume from the oldest message when a sequence does not yet exist fo Type: `bool` Default: `true` +### `enhanced_fan_out` + +Enhanced Fan Out configuration for push-based streaming. Provides dedicated 2 MB/sec throughput per consumer per shard and lower latency (~70ms). Note: EFO incurs per shard-hour charges. + + +Type: `object` + +### `enhanced_fan_out.enabled` + +Enable Enhanced Fan Out mode for push-based streaming with dedicated throughput. + + +Type: `bool` +Default: `false` + +### `enhanced_fan_out.consumer_name` + +Consumer name for EFO registration. Auto-generated if empty: bento-clientID. + + +Type: `string` +Default: `""` + +### `enhanced_fan_out.consumer_arn` + +Existing consumer ARN to use. If provided, skips registration. + + +Type: `string` +Default: `""` + +### `enhanced_fan_out.record_buffer_cap` + +Buffer capacity for the internal records channel per shard. Lower values reduce memory usage when processing many shards. Set to 0 for unbuffered channel (minimal memory footprint). 
+ + +Type: `int` +Default: `0` + +### `enhanced_fan_out.max_pending_records` + +Maximum total number of records to buffer across all shards before applying backpressure to Kinesis subscriptions. This provides a global memory bound regardless of shard count. Higher values improve throughput by allowing shards to continue receiving data while processing, but increase memory usage. Total memory usage is approximately max_pending_records × average_record_size. + + +Type: `int` +Default: `50000` + ### `region` The AWS region to target.