diff --git a/CHANGELOG.md b/CHANGELOG.md index 6b76a3d50cb0d..eee5f91b89415 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Changed - Indexed IP field supports `terms_query` with more than 1025 IP masks [#16391](https://github.com/opensearch-project/OpenSearch/pull/16391) - Make entries for dependencies from server/build.gradle to gradle version catalog ([#16707](https://github.com/opensearch-project/OpenSearch/pull/16707)) +- Add listenable TransportRequestHandler in TransportNodesAction ([#15166](https://github.com/opensearch-project/OpenSearch/pull/15166)) ### Deprecated - Performing update operation with default pipeline or final pipeline is deprecated ([#16712](https://github.com/opensearch-project/OpenSearch/pull/16712)) diff --git a/server/src/main/java/org/opensearch/action/support/nodes/TransportNodesAction.java b/server/src/main/java/org/opensearch/action/support/nodes/TransportNodesAction.java index dccd5059dd52d..5e339c388439c 100644 --- a/server/src/main/java/org/opensearch/action/support/nodes/TransportNodesAction.java +++ b/server/src/main/java/org/opensearch/action/support/nodes/TransportNodesAction.java @@ -116,9 +116,50 @@ protected TransportNodesAction( transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler()); } + /** + * @param actionName action name + * @param threadPool thread-pool + * @param clusterService cluster service + * @param transportService transport service + * @param actionFilters action filters + * @param request node request writer + * @param nodeRequest node request reader + * @param nodeExecutor executor to execute node action on + * @param finalExecutor executor to execute final collection of all responses on + * @param listenableHandler true if the handler should be a listenable handler + * @param nodeResponseClass class of the node responses + */ + protected TransportNodesAction( + String actionName, + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Writeable.Reader request, + Writeable.Reader nodeRequest, + String nodeExecutor, + String finalExecutor, + boolean listenableHandler, + Class nodeResponseClass + ) { + super(actionName, transportService, actionFilters, request); + this.threadPool = threadPool; + this.clusterService = Objects.requireNonNull(clusterService); + this.transportService = Objects.requireNonNull(transportService); + this.nodeResponseClass = Objects.requireNonNull(nodeResponseClass); + + this.transportNodeAction = actionName + "[n]"; + this.finalExecutor = finalExecutor; + if (listenableHandler) { + transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new ListenableNodeTransportHandler()); + } else { + transportService.registerRequestHandler(transportNodeAction, nodeExecutor, nodeRequest, new NodeTransportHandler()); + } + } + /** * Same as {@link #TransportNodesAction(String, ThreadPool, ClusterService, TransportService, ActionFilters, Writeable.Reader, - * Writeable.Reader, String, String, Class)} but executes final response collection on the transport thread except for when the final + * Writeable.Reader, String, String, boolean, Class)} but executes final response collection on the transport thread except for when the final * node response is received from the local node, in which case {@code nodeExecutor} is used. * This constructor should only be used for actions for which the creation of the final response is fast enough to be safely executed * on a transport thread. @@ -144,6 +185,7 @@ protected TransportNodesAction( nodeRequest, nodeExecutor, ThreadPool.Names.SAME, + false, nodeResponseClass ); } @@ -196,6 +238,8 @@ protected NodesResponse newResponse(NodesRequest request, AtomicReferenceArray actionListener) {} + protected NodeResponse nodeOperation(NodeRequest request, Task task) { return nodeOperation(request); } @@ -335,4 +379,14 @@ public void messageReceived(NodeRequest request, TransportChannel channel, Task } } + class ListenableNodeTransportHandler implements TransportRequestHandler { + + @Override + public void messageReceived(NodeRequest request, TransportChannel channel, Task task) { + ActionListener listener = ActionListener.wrap(channel::sendResponse, e -> { + TransportChannel.sendErrorResponse(channel, actionName, request, e); + }); + nodeOperation(request, listener); + } + } } diff --git a/server/src/test/java/org/opensearch/action/support/nodes/TransportNodesActionTests.java b/server/src/test/java/org/opensearch/action/support/nodes/TransportNodesActionTests.java index a338e68276bbc..c755ff2307f13 100644 --- a/server/src/test/java/org/opensearch/action/support/nodes/TransportNodesActionTests.java +++ b/server/src/test/java/org/opensearch/action/support/nodes/TransportNodesActionTests.java @@ -43,16 +43,19 @@ import org.opensearch.cluster.node.DiscoveryNodeRole; import org.opensearch.cluster.node.DiscoveryNodes; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.common.io.stream.Writeable; import org.opensearch.indices.IndicesService; import org.opensearch.node.NodeService; +import org.opensearch.tasks.Task; import org.opensearch.telemetry.tracing.noop.NoopTracer; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.transport.CapturingTransport; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportChannel; import org.opensearch.transport.TransportRequest; import org.opensearch.transport.TransportService; import org.junit.After; @@ -74,9 +77,12 @@ import java.util.function.Supplier; import java.util.stream.Collectors; +import org.mockito.ArgumentCaptor; + import static org.opensearch.test.ClusterServiceUtils.createClusterService; import static org.opensearch.test.ClusterServiceUtils.setState; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; public class TransportNodesActionTests extends OpenSearchTestCase { @@ -198,6 +204,28 @@ public void testTransportNodesActionWithDiscoveryNodesReset() { capturedTransportNodeRequestList.forEach(capturedRequest -> assertNull(capturedRequest.testNodesRequest.concreteNodes())); } + public void testCreateTransportNodesActionWithListenableHandler() { + TransportNodesAction action = getListenableHandlerTestTransportNodesAction(); + assertTrue( + transport.getRequestHandlers() + .getHandler(action.actionName + "[n]") + .getHandler() instanceof TransportNodesAction.ListenableNodeTransportHandler + ); + } + + public void testMessageReceivedInListenableNodeTransportHandler() throws Exception { + TransportNodesAction action = getListenableHandlerTestTransportNodesAction(); + TransportChannel transportChannel = mock(TransportChannel.class); + transport.getRequestHandlers() + .getHandler(action.actionName + "[n]") + .getHandler() + .messageReceived(new TestNodeRequest(), transportChannel, mock(Task.class)); + ArgumentCaptor argCaptor = ArgumentCaptor.forClass(TestNodeResponse.class); + verify(transportChannel).sendResponse(argCaptor.capture()); + TestNodeResponse response = argCaptor.getValue(); + assertNotNull(response); + } + private List mockList(Supplier supplier, int size) { List failures = new ArrayList<>(size); for (int i = 0; i < size; ++i) { @@ -290,6 +318,19 @@ public TestTransportNodesAction getTestTransportNodesAction() { ); } + public TestTransportNodesAction getListenableHandlerTestTransportNodesAction() { + return new TestTransportNodesAction( + THREAD_POOL, + clusterService, + transportService, + new ActionFilters(Collections.emptySet()), + TestNodesRequest::new, + TestNodeRequest::new, + ThreadPool.Names.SAME, + true + ); + } + public DataNodesOnlyTransportNodesAction getDataNodesOnlyTransportNodesAction(TransportService transportService) { return new DataNodesOnlyTransportNodesAction( THREAD_POOL, @@ -335,6 +376,31 @@ private static class TestTransportNodesAction extends TransportNodesAction< ); } + TestTransportNodesAction( + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + Writeable.Reader request, + Writeable.Reader nodeRequest, + String nodeExecutor, + boolean listenableHandler + ) { + super( + "indices:admin/test", + threadPool, + clusterService, + transportService, + actionFilters, + request, + nodeRequest, + nodeExecutor, + nodeExecutor, + listenableHandler, + TestNodeResponse.class + ); + } + @Override protected TestNodesResponse newResponse( TestNodesRequest request, @@ -359,6 +425,11 @@ protected TestNodeResponse nodeOperation(TestNodeRequest request) { return new TestNodeResponse(); } + @Override + protected void nodeOperation(TestNodeRequest request, ActionListener actionListener) { + actionListener.onResponse(new TestNodeResponse()); + } + } private static class DataNodesOnlyTransportNodesAction extends TestTransportNodesAction {