internal/xds/clients/xdsclient/ads_stream.go (135 changes: 70 additions & 65 deletions)

@@ -22,7 +22,6 @@ import (
 	"context"
 	"fmt"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"google.golang.org/grpc/grpclog"
@@ -103,11 +102,11 @@ type adsStreamImpl struct {
 	requestCh    *buffer.Unbounded  // Subscriptions and unsubscriptions are pushed here.
 	runnerDoneCh chan struct{}      // Notify completion of runner goroutine.
 	cancel       context.CancelFunc // To cancel the context passed to the runner goroutine.
+	fc           *adsFlowControl    // Flow control for ADS stream.
 
 	// Guards access to the below fields (and to the contents of the map).
 	mu                sync.Mutex
 	resourceTypeState map[ResourceType]*resourceTypeState // Map of resource types to their state.
-	fc                *adsFlowControl                     // Flow control for ADS stream.
 	firstRequest      bool                                // False after the first request is sent out.
 }

@@ -135,6 +134,7 @@ func newADSStreamImpl(opts adsStreamOpts) *adsStreamImpl {
 		streamCh:          make(chan clients.Stream, 1),
 		requestCh:         buffer.NewUnbounded(),
 		runnerDoneCh:      make(chan struct{}),
+		fc:                newADSFlowControl(),
 		resourceTypeState: make(map[ResourceType]*resourceTypeState),
 	}

@@ -150,6 +150,7 @@ func newADSStreamImpl(opts adsStreamOpts) *adsStreamImpl {
 // Stop blocks until the stream is closed and all spawned goroutines exit.
 func (s *adsStreamImpl) Stop() {
 	s.cancel()
+	s.fc.stop()
 	s.requestCh.Close()
 	<-s.runnerDoneCh
 	s.logger.Infof("Shutdown ADS stream")
@@ -240,9 +241,6 @@ func (s *adsStreamImpl) runner(ctx context.Context) {
 		}
 
 		s.mu.Lock()
-		// Flow control is a property of the underlying streaming RPC call and
-		// needs to be initialized everytime a new one is created.
-		s.fc = newADSFlowControl(s.logger)
 		s.firstRequest = true
 		s.mu.Unlock()

@@ -256,7 +254,7 @@ func (s *adsStreamImpl) runner(ctx context.Context) {
 
 			// Backoff state is reset upon successful receipt of at least one
 			// message from the server.
-			if s.recv(ctx, stream) {
+			if s.recv(stream) {
 				return backoff.ErrResetBackoff
 			}
 			return nil
@@ -318,11 +316,13 @@ func (s *adsStreamImpl) sendNew(stream clients.Stream, typ ResourceType) error {
 	// This allows us to batch writes for requests which are generated as part
 	// of local processing of a received response.
 	state := s.resourceTypeState[typ]
-	if s.fc.pending.Load() {
+	bufferRequest := func() {
 		select {
 		case state.bufferedRequests <- struct{}{}:
 		default:
 		}
+	}
+	if s.fc.runIfPending(func() { bufferRequest() }) {
 		return nil
 	}

[Review thread on the line `case state.bufferedRequests <- struct{}{}:`]

arjan-bal (Contributor), Sep 30, 2025:
    Unrelated to the changes in this PR: all reads and writes of
    bufferedRequests have a default case. If I understand correctly, this
    means that none of the references are blocking; they just check or
    change whether the channel is empty. Does it make sense to replace the
    channel with an atomic boolean? This can be done in a separate PR to
    keep this PR focused on the bug fix.

PR author (Contributor):
    That is true. Replacing it with an atomic boolean would definitely
    simplify the code. Thanks, I will do it in a follow-up.
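
To make the reviewer's suggestion concrete, here is a minimal, self-contained sketch of what the follow-up might look like. It is not part of this PR, and the variable name hasBufferedRequest is hypothetical:

package main

import (
	"fmt"
	"sync/atomic"
)

func main() {
	// Hypothetical replacement for the one-slot bufferedRequests channel.
	var hasBufferedRequest atomic.Bool

	// Today: select { case state.bufferedRequests <- struct{}{}: default: }
	// With an atomic boolean, the non-blocking "send" becomes:
	hasBufferedRequest.Store(true)

	// Today: select { case <-state.bufferedRequests: /* resend */ default: }
	// The non-blocking "drain" becomes a single swap:
	if hasBufferedRequest.Swap(false) {
		fmt.Println("flushing buffered request")
	}
}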

@@ -477,18 +477,13 @@ func (s *adsStreamImpl) sendMessageLocked(stream clients.Stream, names []string,
 //
 // It returns a boolean indicating whether at least one message was received
 // from the server.
-func (s *adsStreamImpl) recv(ctx context.Context, stream clients.Stream) bool {
+func (s *adsStreamImpl) recv(stream clients.Stream) bool {
 	msgReceived := false
 	for {
 		// Wait for ADS stream level flow control to be available, and send out
 		// a request if anything was buffered while we were waiting for local
 		// processing of the previous response to complete.
-		if !s.fc.wait(ctx) {
-			if s.logger.V(2) {
-				s.logger.Infof("ADS stream context canceled")
-			}
-			return msgReceived
-		}
+		s.fc.wait()
 		s.sendBuffered(stream)
 
 		resources, url, version, nonce, err := s.recvMessage(stream)
@@ -508,8 +503,8 @@ func (s *adsStreamImpl) recv(ctx context.Context, stream clients.Stream) bool {
 		}
 		var resourceNames []string
 		var nackErr error
-		s.fc.setPending()
-		resourceNames, nackErr = s.eventHandler.onResponse(resp, s.fc.onDone)
+		s.fc.setPending(true)
+		resourceNames, nackErr = s.eventHandler.onResponse(resp, sync.OnceFunc(func() { s.fc.setPending(false) }))
 		if xdsresource.ErrType(nackErr) == xdsresource.ErrorTypeResourceTypeUnsupported {
 			// A general guiding principle is that if the server sends
 			// something the client didn't actually subscribe to, then the
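
The sync.OnceFunc wrapper above guarantees that the pending flag is cleared at most once, even if the done callback handed to onResponse ends up being invoked more than once. A minimal standalone illustration of that property:

package main

import (
	"fmt"
	"sync"
)

func main() {
	cleared := 0
	onDone := sync.OnceFunc(func() { cleared++ })

	// Even if several watchers signal completion, the wrapped
	// function runs exactly once.
	onDone()
	onDone()
	onDone()
	fmt.Println(cleared) // 1
}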
@@ -707,69 +702,79 @@ func resourceNames(m map[string]*xdsresource.ResourceWatchState) []string {
 	return ret
 }
 
-// adsFlowControl implements ADS stream level flow control that enables the
-// transport to block the reading of the next message off of the stream until
-// the previous update is consumed by all watchers.
+// adsFlowControl implements ADS stream level flow control that enables the ADS
+// stream to block the reading of the next message until the previous update is
+// consumed by all watchers.
 //
-// The lifetime of the flow control is tied to the lifetime of the stream.
+// The lifetime of the flow control is tied to the lifetime of the stream. When
+// the stream is closed, it is the responsibility of the caller to set the
+// pending state to false. This ensures that any goroutine blocked on the flow
+// control's wait method is unblocked.
 type adsFlowControl struct {
-	logger *igrpclog.PrefixLogger
-
-	// Whether the most recent update is pending consumption by all watchers.
-	pending atomic.Bool
-	// Channel used to notify when all the watchers have consumed the most
-	// recent update. Wait() blocks on reading a value from this channel.
-	readyCh chan struct{}
+	mu      sync.Mutex
+	cond    *sync.Cond // signals when the most recent update has been consumed
+	pending bool       // indicates if the most recent update is pending consumption
+	stopped bool       // indicates if the ADS stream has been stopped
 }
 
 // newADSFlowControl returns a new adsFlowControl.
-func newADSFlowControl(logger *igrpclog.PrefixLogger) *adsFlowControl {
-	return &adsFlowControl{
-		logger:  logger,
-		readyCh: make(chan struct{}, 1),
-	}
+func newADSFlowControl() *adsFlowControl {
+	fc := &adsFlowControl{}
+	fc.cond = sync.NewCond(&fc.mu)
+	return fc
 }
 
-// setPending changes the internal state to indicate that there is an update
-// pending consumption by all watchers.
-func (fc *adsFlowControl) setPending() {
-	fc.pending.Store(true)
+// stop marks the flow control as stopped and signals the condition variable to
+// unblock any goroutine waiting on it.
+func (fc *adsFlowControl) stop() {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	fc.stopped = true
+	fc.cond.Broadcast()
 }
 
-// wait blocks until all the watchers have consumed the most recent update and
-// returns true. If the context expires before that, it returns false.
-func (fc *adsFlowControl) wait(ctx context.Context) bool {
-	// If there is no pending update, there is no need to block.
-	if !fc.pending.Load() {
-		// If all watchers finished processing the most recent update before the
-		// `recv` goroutine made the next call to `Wait()`, there would be an
-		// entry in the readyCh channel that needs to be drained to ensure that
-		// the next call to `Wait()` doesn't unblock before it actually should.
-		select {
-		case <-fc.readyCh:
-		default:
-		}
-		return true
+// setPending changes the internal state to indicate whether there is an update
+// pending consumption by all watchers. If there is no longer a pending update,
+// the condition variable is signaled to allow the recv method to proceed.
+func (fc *adsFlowControl) setPending(pending bool) {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	if fc.stopped {
+		return
 	}
 
-	select {
-	case <-ctx.Done():
-		return false
-	case <-fc.readyCh:
-		return true
+	fc.pending = pending
+	if !pending {
+		fc.cond.Broadcast()
+	}
+}
+
+func (fc *adsFlowControl) runIfPending(f func()) bool {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	if fc.stopped {
+		return false
 	}
+
+	// If there's a pending update, run the function while still holding the
+	// lock. This ensures that the pending state does not change between the
+	// check and the function call.
+	pending := fc.pending
+	if fc.pending {
+		f()
+	}
+	return pending
 }
 
-// onDone indicates that all watchers have consumed the most recent update.
-func (fc *adsFlowControl) onDone() {
-	select {
-	// Writes to the readyCh channel should not block ideally. The default
-	// branch here is to appease the paranoid mind.
-	case fc.readyCh <- struct{}{}:
-	default:
-		if fc.logger.V(2) {
-			fc.logger.Infof("ADS stream flow control readyCh is full")
-		}
+// wait blocks until all the watchers have consumed the most recent update.
+func (fc *adsFlowControl) wait() {
+	fc.mu.Lock()
+	defer fc.mu.Unlock()
+
+	for fc.pending && !fc.stopped {
+		fc.cond.Wait()
 	}
-	fc.pending.Store(false)
 }
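
For readers unfamiliar with the sync.Cond pattern adopted here: a waiter re-checks its predicate in a loop around cond.Wait(), and is released either when the pending flag is cleared or when the flow control is stopped. The following is a self-contained sketch of that pattern, simplified from the code above rather than an exact copy:

package main

import (
	"fmt"
	"sync"
	"time"
)

// flowControl blocks a reader while an update is pending consumption.
type flowControl struct {
	mu      sync.Mutex
	cond    *sync.Cond
	pending bool
	stopped bool
}

func newFlowControl() *flowControl {
	fc := &flowControl{}
	fc.cond = sync.NewCond(&fc.mu)
	return fc
}

func (fc *flowControl) setPending(p bool) {
	fc.mu.Lock()
	defer fc.mu.Unlock()
	if fc.stopped {
		return
	}
	fc.pending = p
	if !p {
		// Wake up any goroutine blocked in wait().
		fc.cond.Broadcast()
	}
}

func (fc *flowControl) wait() {
	fc.mu.Lock()
	defer fc.mu.Unlock()
	// Re-check the predicate after every wakeup, as Wait can return
	// without the state having changed in our favor.
	for fc.pending && !fc.stopped {
		fc.cond.Wait()
	}
}

func main() {
	fc := newFlowControl()
	fc.setPending(true)

	go func() {
		time.Sleep(10 * time.Millisecond)
		fc.setPending(false) // all watchers consumed the update
	}()

	fc.wait() // blocks until the pending flag is cleared
	fmt.Println("update consumed; safe to read the next message")
}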
@@ -20,7 +20,6 @@ package xdsclient_test
 
 import (
 	"context"
-	"errors"
 	"fmt"
 	"slices"
 	"sort"
@@ -125,7 +124,6 @@ func (t *transport) NewStream(ctx context.Context, method string) (clients.Stream, error) {
 	stream := &stream{
 		stream: s,
 		recvCh: make(chan struct{}, 1),
-		doneCh: make(chan struct{}),
 	}
 	t.adsStreamCh <- stream

@@ -138,9 +136,7 @@ func (t *transport) Close() {
 
 type stream struct {
 	stream grpc.ClientStream
-
 	recvCh chan struct{}
-	doneCh <-chan struct{}
 }
 
 func (s *stream) Send(msg []byte) error {
@@ -150,8 +146,8 @@ func (s *stream) Send(msg []byte) error {
 func (s *stream) Recv() ([]byte, error) {
 	select {
 	case s.recvCh <- struct{}{}:
-	case <-s.doneCh:
-		return nil, errors.New("Recv() called after the test has finished")
+	case <-s.stream.Context().Done():
+		// Unblock the recv() once the stream is done.
 	}
 
 	var typedRes []byte
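
The test change above swaps a bespoke doneCh for the stream's own context, so a Recv() blocked in the select is released automatically when the stream is torn down. A minimal sketch of the pattern, assuming nothing beyond the standard library:

package main

import (
	"context"
	"fmt"
)

// recv blocks until a reader is ready or the stream's context is done,
// mirroring the select in the test's Recv() above.
func recv(ctx context.Context, recvCh chan struct{}) error {
	select {
	case recvCh <- struct{}{}:
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel() // simulate the stream being closed
	fmt.Println(recv(ctx, make(chan struct{}))) // context.Canceled
}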
internal/xds/clients/xdsclient/xdsclient.go (2 changes: 1 addition & 1 deletion)

@@ -319,7 +319,7 @@ func (c *XDSClient) releaseChannel(serverConfig *ServerConfig, state *channelState) {
 	c.channelsMu.Lock()
 
 	if c.logger.V(2) {
-		c.logger.Infof("Received request to release a reference to an xdsChannel for server config %q", serverConfig)
+		c.logger.Infof("Received request to release a reference to an xdsChannel for server config %+v", serverConfig)
 	}
 	deInitLocked(state)

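The verb change matters because %q is meant for strings: applied to a struct it quotes string fields and mangles the rest, while %+v prints field names and values. A quick illustration with a hypothetical stand-in for ServerConfig:

package main

import "fmt"

type serverConfig struct {
	URI      string
	Insecure bool
}

func main() {
	sc := serverConfig{URI: "xds.example.org:443", Insecure: true}
	fmt.Printf("%q\n", sc)  // {"xds.example.org:443" %!q(bool=true)}
	fmt.Printf("%+v\n", sc) // {URI:xds.example.org:443 Insecure:true}
}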
internal/xds/xdsclient/tests/dump_test.go (4 changes: 2 additions & 2 deletions)

@@ -63,7 +63,7 @@ func makeGenericXdsConfig(typeURL, name, version string, status v3adminpb.Client
 }
 
 func checkResourceDump(ctx context.Context, want *v3statuspb.ClientStatusResponse, pool *xdsclient.Pool) error {
-	var cmpOpts = cmp.Options{
+	cmpOpts := cmp.Options{
 		protocmp.Transform(),
 		protocmp.IgnoreFields((*v3statuspb.ClientConfig_GenericXdsConfig)(nil), "last_updated"),
 		protocmp.IgnoreFields((*v3adminpb.UpdateFailureState)(nil), "last_update_attempt", "details"),
@@ -89,7 +89,7 @@ func checkResourceDump(ctx context.Context, want *v3statuspb.ClientStatusResponse, pool *xdsclient.Pool) error {
 		if diff == "" {
 			return nil
 		}
-		lastErr = fmt.Errorf("received unexpected resource dump, diff (-got, +want):\n%s, got: %s\n want:%s", diff, pretty.ToJSON(got), pretty.ToJSON(want))
+		lastErr = fmt.Errorf("received unexpected resource dump, diff (-want, +got):\n%s, got: %s\n want:%s", diff, pretty.ToJSON(got), pretty.ToJSON(want))
 	}
 	return fmt.Errorf("timeout when waiting for resource dump to reach expected state: %v", lastErr)
 }
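The legend fix reflects how go-cmp orders its report: cmp.Diff(x, y) describes the transformation from x to y, so with cmp.Diff(want, got) the minus side is want and the plus side is got. A small demonstration:

package main

import (
	"fmt"

	"github.com/google/go-cmp/cmp"
)

func main() {
	want := []string{"a", "b"}
	got := []string{"a", "c"}
	// Lines prefixed with - come from want, + from got,
	// hence the (-want, +got) legend in the error message.
	fmt.Println(cmp.Diff(want, got))
}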