Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix Netty deprecation warnings in transport-reactor-netty4 module ([#20429](https://github.com/opensearch-project/OpenSearch/pull/20429))
- Fix stats aggregation returning zero results with `size:0`. ([#20427](https://github.com/opensearch-project/OpenSearch/pull/20427))
- Remove child level directory on refresh for CompositeIndexWriter ([#20326](https://github.com/opensearch-project/OpenSearch/pull/20326))
- Fixes and refactoring in stream transport to make it more robust ([#20359](https://github.com/opensearch-project/OpenSearch/pull/20359))

### Dependencies
- Bump `com.google.auth:google-auth-library-oauth2-http` from 1.38.0 to 1.41.0 ([#20183](https://github.com/opensearch-project/OpenSearch/pull/20183))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,8 @@ public static List<Setting<?>> getSettings() {
ARROW_ENABLE_DEBUG_ALLOCATOR,
ARROW_ENABLE_UNSAFE_MEMORY_ACCESS,
ARROW_SSL_ENABLE,
FLIGHT_EVENT_LOOP_THREADS
FLIGHT_EVENT_LOOP_THREADS,
FLIGHT_THREAD_POOL_MIN_SIZE
)
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import org.opensearch.arrow.flight.stats.FlightCallTracker;
import org.opensearch.arrow.flight.stats.FlightStatsCollector;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
Expand Down Expand Up @@ -45,6 +44,7 @@
*/
class FlightClientChannel implements TcpChannel {
private static final Logger logger = LogManager.getLogger(FlightClientChannel.class);
private static final AtomicLong GLOBAL_CHANNEL_COUNTER = new AtomicLong();
private final AtomicLong correlationIdGenerator = new AtomicLong();
private final FlightClient client;
private final DiscoveryNode node;
Expand Down Expand Up @@ -112,6 +112,11 @@ public FlightClientChannel(
this.closeListeners = new CopyOnWriteArrayList<>();
this.stats = new ChannelStats();
this.isClosed = false;
// Initialize with timestamp + global counter to ensure uniqueness with multiple channels
// Upper bits: timestamp, lower 20 bits: channel ID
long channelId = GLOBAL_CHANNEL_COUNTER.incrementAndGet() & 0xFFFFF; // 20 bits for channel ID
long initialValue = (System.currentTimeMillis() << 20) | channelId;
this.correlationIdGenerator.set(initialValue);
if (statsCollector != null) {
statsCollector.incrementClientChannelsActive();
}
Expand Down Expand Up @@ -229,7 +234,8 @@ public void sendMessage(long requestId, BytesReference reference, ActionListener
config
);

processStreamResponse(streamResponse);
// Open stream and prefetch first batch, invoke handler when ready
openStreamAndInvokeHandler(streamResponse);
listener.onResponse(null);
} catch (Exception e) {
if (callTracker != null) {
Expand All @@ -244,39 +250,44 @@ public void sendMessage(BytesReference reference, ActionListener<Void> listener)
throw new IllegalStateException("sendMessage must be accompanied with requestId for FlightClientChannel, use the right variant.");
}

private void processStreamResponse(FlightTransportResponse<?> streamResponse) {
try {
executeWithThreadContext(streamResponse);
} catch (Exception e) {
handleStreamException(streamResponse, e);
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private void executeWithThreadContext(FlightTransportResponse<?> streamResponse) {
final ThreadContext threadContext = threadPool.getThreadContext();
final String executor = streamResponse.getHandler().executor();
private void openStreamAndInvokeHandler(FlightTransportResponse<?> streamResponse) {
TransportResponseHandler handler = streamResponse.getHandler();
String executor = handler.executor();

if (ThreadPool.Names.SAME.equals(executor)) {
executeHandler(threadContext, streamResponse);
} else {
threadPool.executor(executor).execute(() -> executeHandler(threadContext, streamResponse));
logger.debug("Stream transport handler using SAME executor, which may cause blocking behavior");
}
}

@SuppressWarnings({ "unchecked", "rawtypes" })
private void executeHandler(ThreadContext threadContext, FlightTransportResponse<?> streamResponse) {
try (ThreadContext.StoredContext ignored = threadContext.stashContext()) {
Header header = streamResponse.getHeader();
if (header == null) {
throw new StreamException(StreamErrorCode.INTERNAL, "Header is null");
var threadContext = threadPool.getThreadContext();
CompletableFuture<Header> future = new CompletableFuture<>();
streamResponse.openAndPrefetchAsync(future);

future.whenComplete((header, error) -> {
if (error != null) {
handleStreamException(streamResponse, error instanceof Exception ? (Exception) error : new Exception(error));
return;
}
TransportResponseHandler handler = streamResponse.getHandler();
threadContext.setHeaders(header.getHeaders());
handler.handleStreamResponse(streamResponse);
} catch (Exception e) {
cleanupStreamResponse(streamResponse);
throw e;
}

// Deferred handler invocation; the dispatch that follows runs it inline or on the handler's executor.
Runnable task = () -> {
    // Invoke the handler inside a stashed thread context that is closed when this block exits.
    try (var ignored = threadContext.stashContext()) {
        if (header == null) {
            handleStreamException(streamResponse, new StreamException(StreamErrorCode.INTERNAL, "Header is null"));
            // Stop here, as the error branch of the enclosing callback does: falling through would
            // dereference the null header below and raise a NullPointerException after the failure
            // has already been reported.
            return;
        }
        // Expose the response headers to the handler through the thread context.
        threadContext.setHeaders(header.getHeaders());
        handler.handleStreamResponse(streamResponse);
    } catch (Exception e) {
        // Run stream-response cleanup before propagating the failure to the caller/executor.
        cleanupStreamResponse(streamResponse);
        throw e;
    }
};

if (ThreadPool.Names.SAME.equals(executor)) {
task.run();
} else {
threadPool.executor(executor).execute(task);
}
});
}

private void cleanupStreamResponse(StreamTransportResponse<?> streamResponse) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ private void processCompleteTask(BatchTask task) {
}

try {
flightChannel.completeStream();
flightChannel.completeStream(getHeaderBuffer(task.requestId(), task.nodeVersion(), task.features()));
messageListener.onResponseSent(task.requestId(), task.action(), TransportResponse.Empty.INSTANCE);
} catch (Exception e) {
messageListener.onResponseSent(task.requestId(), task.action(), e);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.OpenSearchException;
import org.opensearch.arrow.flight.stats.FlightCallTracker;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
Expand All @@ -28,9 +29,9 @@
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.arrow.flight.transport.FlightErrorMapper.mapFromCallStatus;

Expand All @@ -49,10 +50,12 @@ class FlightServerChannel implements TcpChannel {
private final InetSocketAddress remoteAddress;
private final List<ActionListener<Void>> closeListeners = Collections.synchronizedList(new ArrayList<>());
private final ServerHeaderMiddleware middleware;
private volatile Optional<VectorSchemaRoot> root = Optional.empty();
private volatile VectorSchemaRoot root = null;
private final FlightCallTracker callTracker;
private volatile boolean cancelled = false;
private final ExecutorService executor;
private final long correlationId;
private final AtomicInteger batchNumber = new AtomicInteger(0);

public FlightServerChannel(
ServerStreamListener serverStreamListener,
Expand All @@ -61,15 +64,14 @@ public FlightServerChannel(
FlightCallTracker callTracker,
ExecutorService executor
) {
this.correlationId = Long.parseLong(middleware.getCorrelationId());
logger.debug("Creating FlightServerChannel for correlation ID: {}", correlationId);
this.serverStreamListener = serverStreamListener;
this.serverStreamListener.setUseZeroCopy(true);
this.serverStreamListener.setOnCancelHandler(new Runnable() {
@Override
public void run() {
cancelled = true;
callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name());
close();
}
this.serverStreamListener.setOnCancelHandler(() -> {
cancelled = true;
callTracker.recordCallEnd(StreamErrorCode.CANCELLED.name());
close();
});
this.allocator = allocator;
this.middleware = middleware;
Expand All @@ -83,7 +85,7 @@ public BufferAllocator getAllocator() {
return allocator;
}

Optional<VectorSchemaRoot> getRoot() {
VectorSchemaRoot getRoot() {
return root;
}

Expand All @@ -106,35 +108,53 @@ public void sendBatch(ByteBuffer header, VectorStreamOutput output) {
if (!open.get()) {
throw new IllegalStateException("FlightServerChannel already closed.");
}
batchNumber.incrementAndGet();
long batchStartTime = System.nanoTime();
// Only set for the first batch
if (root.isEmpty()) {
if (root == null) {
middleware.setHeader(header);
root = Optional.of(output.getRoot());
serverStreamListener.start(root.get());
root = output.getRoot();
serverStreamListener.start(root);
} else {
root = Optional.of(output.getRoot());
root = output.getRoot();
// placeholder to clear and fill the root with data for the next batch
}

logger.debug("Sending batch #{} for correlation ID: {}", batchNumber, correlationId);
// Do not close the root right after the putNext() call: we cannot tell whether it has already been
// transmitted by the transport, so all roots are closed when the stream completes. TODO: optimize this behaviour
serverStreamListener.putNext();
long putNextTime = (System.nanoTime() - batchStartTime) / 1_000_000;
if (callTracker != null) {
long rootSize = FlightUtils.calculateVectorSchemaRootSize(root.get());
long rootSize = FlightUtils.calculateVectorSchemaRootSize(root);
callTracker.recordBatchSent(rootSize, System.nanoTime() - batchStartTime);
logger.debug(
"Batch #{} sent for correlation ID: {}, size: {} bytes, putNext: {}ms",
batchNumber,
correlationId,
rootSize,
putNextTime
);
} else {
logger.debug("Batch #{} sent for correlation ID: {}, putNext: {}ms", batchNumber, correlationId, putNextTime);
}
}

/**
* Completes the streaming response and closes all pending roots.
*
*/
public void completeStream() {
public void completeStream(ByteBuffer header) {
try {
if (!open.get()) {
throw new IllegalStateException("FlightServerChannel already closed.");
}
if (root == null) {
// Set header if no batches were sent
middleware.setHeader(header);
logger.debug("Completing empty stream for correlation ID: {}", correlationId);
} else {
logger.debug("Completing stream for correlation ID: {} after {} batches", correlationId, batchNumber);
}
serverStreamListener.completed();
} finally {
callTracker.recordCallEnd(StreamErrorCode.OK.name());
Expand All @@ -160,8 +180,13 @@ public void sendError(ByteBuffer header, Exception error) {
.toRuntimeException();
}
middleware.setHeader(header);
if (error instanceof OpenSearchException) {
logger.debug("Error in Flight stream: {}", error.getMessage());
} else {
logger.error("Unexpected error in Flight stream", error);
}
logger.debug("Sending error for correlation ID: {} after {} batches: {}", correlationId, batchNumber, error.getMessage());
serverStreamListener.error(flightExc);
logger.debug(error);
} finally {
StreamErrorCode errorCode = flightExc != null ? mapFromCallStatus(flightExc) : StreamErrorCode.UNKNOWN;
callTracker.recordCallEnd(errorCode.name());
Expand Down Expand Up @@ -210,7 +235,9 @@ public void close() {
return;
}
open.set(false);
root.ifPresent(VectorSchemaRoot::close);
if (root != null) {
root.close();
}
notifyCloseListeners();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -364,7 +364,8 @@ public List<Setting<?>> getSettings() {
ServerComponents.SETTING_FLIGHT_PORTS,
ServerComponents.SETTING_FLIGHT_HOST,
ServerComponents.SETTING_FLIGHT_BIND_HOST,
ServerComponents.SETTING_FLIGHT_PUBLISH_HOST
ServerComponents.SETTING_FLIGHT_PUBLISH_HOST,
ServerComponents.SETTING_FLIGHT_PUBLISH_PORT
)
) {
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,15 @@
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;

import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.MultiThreadIoEventLoopGroup;
import io.netty.channel.nio.NioIoHandler;

import static org.opensearch.arrow.flight.bootstrap.ServerComponents.SETTING_FLIGHT_BIND_HOST;
import static org.opensearch.arrow.flight.bootstrap.ServerComponents.SETTING_FLIGHT_PORTS;
Expand Down Expand Up @@ -132,7 +134,7 @@ public FlightTransport(
this.sslContextProvider = sslContextProvider;
this.statsCollector = statsCollector;
this.bossEventLoopGroup = createEventLoopGroup("os-grpc-boss-ELG", 1);
this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors() * 2);
this.workerEventLoopGroup = createEventLoopGroup("os-grpc-worker-ELG", Runtime.getRuntime().availableProcessors());
this.serverExecutor = threadPool.executor(ServerConfig.GRPC_EXECUTOR_THREAD_POOL_NAME);
this.clientExecutor = threadPool.executor(ServerConfig.FLIGHT_CLIENT_THREAD_POOL_NAME);
this.threadPool = threadPool;
Expand Down Expand Up @@ -409,7 +411,9 @@ protected InboundHandler createInboundHandler(
}

private EventLoopGroup createEventLoopGroup(String name, int threads) {
return new NioEventLoopGroup(threads);
AtomicInteger threadCounter = new AtomicInteger(0);
ThreadFactory threadFactory = r -> new Thread(r, name + "-" + threadCounter.incrementAndGet());
return new MultiThreadIoEventLoopGroup(threads, threadFactory, NioIoHandler.newFactory());
}

private void gracefullyShutdownELG(EventLoopGroup group, String name) {
Expand Down
Loading
Loading