diff --git a/CHANGELOG.md b/CHANGELOG.md index 8b936ffbf6486..4b084534ec374 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -173,6 +173,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - Fixed javadoc warning for build failure([#4581](https://github.com/opensearch-project/OpenSearch/pull/4581)) - Added transport actions support for extensions ([#4598](https://github.com/opensearch-project/OpenSearch/pull/4598/)) - Pass REST params and content to extensions ([#4633](https://github.com/opensearch-project/OpenSearch/pull/4633)) + - Return consumed params and content from extensions ([#4705](https://github.com/opensearch-project/OpenSearch/pull/4705)) ## [2.x] diff --git a/server/src/main/java/org/opensearch/extensions/rest/ExtensionRestResponse.java b/server/src/main/java/org/opensearch/extensions/rest/ExtensionRestResponse.java new file mode 100644 index 0000000000000..0eb59823bee93 --- /dev/null +++ b/server/src/main/java/org/opensearch/extensions/rest/ExtensionRestResponse.java @@ -0,0 +1,113 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.rest.BytesRestResponse; +import org.opensearch.rest.RestStatus; + +import java.util.List; + +/** + * A subclass of {@link BytesRestResponse} which also tracks consumed parameters and content. + * + * @opensearch.api + */ +public class ExtensionRestResponse extends BytesRestResponse { + + private final List consumedParams; + private final boolean contentConsumed; + + /** + * Creates a new response based on {@link XContentBuilder}. + * + * @param request the REST request being responded to. + * @param status The REST status. + * @param builder The builder for the response. + */ + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, XContentBuilder builder) { + super(status, builder); + this.consumedParams = request.consumedParams(); + this.contentConsumed = request.isContentConsumed(); + } + + /** + * Creates a new plain text response. + * + * @param request the REST request being responded to. + * @param status The REST status. + * @param content A plain text response string. + */ + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String content) { + super(status, content); + this.consumedParams = request.consumedParams(); + this.contentConsumed = request.isContentConsumed(); + } + + /** + * Creates a new plain text response. + * + * @param request the REST request being responded to. + * @param status The REST status. + * @param contentType The content type of the response string. + * @param content A response string. + */ + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, String content) { + super(status, contentType, content); + this.consumedParams = request.consumedParams(); + this.contentConsumed = request.isContentConsumed(); + } + + /** + * Creates a binary response. + * + * @param request the REST request being responded to. + * @param status The REST status. + * @param contentType The content type of the response bytes. + * @param content Response bytes. + */ + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, byte[] content) { + super(status, contentType, content); + this.consumedParams = request.consumedParams(); + this.contentConsumed = request.isContentConsumed(); + } + + /** + * Creates a binary response. + * + * @param request the REST request being responded to. + * @param status The REST status. + * @param contentType The content type of the response bytes. + * @param content Response bytes. + */ + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, BytesReference content) { + super(status, contentType, content); + this.consumedParams = request.consumedParams(); + this.contentConsumed = request.isContentConsumed(); + } + + /** + * Gets the list of consumed parameters. These are needed to consume the parameters of the original request. + * + * @return the list of consumed params. + */ + public List getConsumedParams() { + return consumedParams; + } + + /** + * Reports whether content was consumed. + * + * @return true if the content was consumed, false otherwise. + */ + public boolean isContentConsumed() { + return contentConsumed; + } +} diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java index 39661bd78d996..e2625105e705c 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestExecuteOnExtensionResponse.java @@ -10,14 +10,11 @@ import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestResponse; import org.opensearch.rest.RestStatus; import org.opensearch.transport.TransportResponse; import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -32,16 +29,8 @@ public class RestExecuteOnExtensionResponse extends TransportResponse { private String contentType; private byte[] content; private Map> headers; - - /** - * Instantiate this object with a status and response string. - * - * @param status The REST status. - * @param responseString The response content as a String. - */ - public RestExecuteOnExtensionResponse(RestStatus status, String responseString) { - this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap()); - } + private List consumedParams; + private boolean contentConsumed; /** * Instantiate this object with the components of a {@link RestResponse}. @@ -50,33 +39,49 @@ public RestExecuteOnExtensionResponse(RestStatus status, String responseString) * @param contentType The type of the content. * @param content The content. * @param headers The headers. + * @param consumedParams The consumed params. + * @param contentConsumed Whether content was consumed. */ - public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map> headers) { + public RestExecuteOnExtensionResponse( + RestStatus status, + String contentType, + byte[] content, + Map> headers, + List consumedParams, + boolean contentConsumed + ) { + super(); setStatus(status); setContentType(contentType); setContent(content); setHeaders(headers); + setConsumedParams(consumedParams); + setContentConsumed(contentConsumed); } /** - * Instantiate this object from a Transport Stream + * Instantiate this object from a Transport Stream. * * @param in The stream input. * @throws IOException on transport failure. */ public RestExecuteOnExtensionResponse(StreamInput in) throws IOException { - setStatus(RestStatus.readFrom(in)); + setStatus(in.readEnum(RestStatus.class)); setContentType(in.readString()); setContent(in.readByteArray()); setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString)); + setConsumedParams(in.readStringList()); + setContentConsumed(in.readBoolean()); } @Override public void writeTo(StreamOutput out) throws IOException { - RestStatus.writeTo(out, status); + out.writeEnum(status); out.writeString(contentType); out.writeByteArray(content); out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString); + out.writeStringCollection(consumedParams); + out.writeBoolean(contentConsumed); } public RestStatus getStatus() { @@ -110,4 +115,20 @@ public Map> getHeaders() { public void setHeaders(Map> headers) { this.headers = Map.copyOf(headers); } + + public List getConsumedParams() { + return consumedParams; + } + + public void setConsumedParams(List consumedParams) { + this.consumedParams = consumedParams; + } + + public boolean isContentConsumed() { + return contentConsumed; + } + + public void setContentConsumed(boolean contentConsumed) { + this.contentConsumed = contentConsumed; + } } diff --git a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java index 8a35638c9d939..45c1a771c9e83 100644 --- a/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java +++ b/server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java @@ -34,11 +34,9 @@ import java.util.ArrayList; import java.util.List; import java.util.Map; -import java.util.Map.Entry; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; - +import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; import static java.util.Collections.unmodifiableList; @@ -49,8 +47,6 @@ public class RestSendToExtensionAction extends BaseRestHandler { private static final String SEND_TO_EXTENSION_ACTION = "send_to_extension_action"; private static final Logger logger = LogManager.getLogger(RestSendToExtensionAction.class); - private static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; - private static final String CONSUMED_CONTENT_KEY = "extension.consumed.content"; // To replace with user identity see https://github.com/opensearch-project/OpenSearch/pull/4247 private static final Principal DEFAULT_PRINCIPAL = new Principal() { @Override @@ -124,7 +120,9 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC RestStatus.INTERNAL_SERVER_ERROR, BytesRestResponse.TEXT_CONTENT_TYPE, message.getBytes(StandardCharsets.UTF_8), - emptyMap() + emptyMap(), + emptyList(), + false ); final CountDownLatch inProgressLatch = new CountDownLatch(1); final TransportResponseHandler restExecuteOnExtensionResponseHandler = new TransportResponseHandler< @@ -141,25 +139,12 @@ public void handleResponse(RestExecuteOnExtensionResponse response) { restExecuteOnExtensionResponse.setStatus(response.getStatus()); restExecuteOnExtensionResponse.setContentType(response.getContentType()); restExecuteOnExtensionResponse.setContent(response.getContent()); - // Extract the consumed parameters and content from the header - Map> headers = response.getHeaders(); - List consumedParams = headers.get(CONSUMED_PARAMS_KEY); - if (consumedParams != null) { - // consume each param - consumedParams.stream().forEach(p -> request.param(p)); - } - List consumedContent = headers.get(CONSUMED_CONTENT_KEY); - if (consumedContent != null) { - // conditionally consume content - if (consumedParams.stream().filter(c -> Boolean.parseBoolean(c)).count() > 0) { - request.content(); - } + restExecuteOnExtensionResponse.setHeaders(response.getHeaders()); + // Consume parameters and content + response.getConsumedParams().stream().forEach(p -> request.param(p)); + if (response.isContentConsumed()) { + request.content(); } - Map> headersWithoutConsumedParams = headers.entrySet() - .stream() - .filter(e -> !e.getKey().equals(CONSUMED_PARAMS_KEY)) - .collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue())); - restExecuteOnExtensionResponse.setHeaders(headersWithoutConsumedParams); inProgressLatch.countDown(); } @@ -204,11 +189,11 @@ public String executor() { restExecuteOnExtensionResponse.getContentType(), restExecuteOnExtensionResponse.getContent() ); - for (Entry> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) { - for (String value : headerEntry.getValue()) { - restResponse.addHeader(headerEntry.getKey(), value); - } - } + // No constructor that includes headers so we roll our own + restExecuteOnExtensionResponse.getHeaders() + .entrySet() + .stream() + .forEach(e -> { e.getValue().stream().forEach(v -> restResponse.addHeader(e.getKey(), v)); }); return channel -> channel.sendResponse(restResponse); } diff --git a/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestRequestTests.java b/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestRequestTests.java index d095783adc228..55d89d08371a8 100644 --- a/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestRequestTests.java +++ b/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestRequestTests.java @@ -188,30 +188,51 @@ public void testRestExecuteOnExtensionResponse() throws Exception { String expectedResponse = "Test response"; byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8); - RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse); + RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse( + expectedStatus, + expectedContentType, + expectedResponseBytes, + Collections.emptyMap(), + Collections.emptyList(), + false + ); assertEquals(expectedStatus, response.getStatus()); assertEquals(expectedContentType, response.getContentType()); assertArrayEquals(expectedResponseBytes, response.getContent()); assertEquals(0, response.getHeaders().size()); + assertEquals(0, response.getConsumedParams().size()); + assertFalse(response.isContentConsumed()); String headerKey = "foo"; List headerValueList = List.of("bar", "baz"); Map> expectedHeaders = Map.of(headerKey, headerValueList); + List expectedConsumedParams = List.of("foo", "bar"); - response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders); + response = new RestExecuteOnExtensionResponse( + expectedStatus, + expectedContentType, + expectedResponseBytes, + expectedHeaders, + expectedConsumedParams, + true + ); assertEquals(expectedStatus, response.getStatus()); assertEquals(expectedContentType, response.getContentType()); assertArrayEquals(expectedResponseBytes, response.getContent()); - assertEquals(1, expectedHeaders.keySet().size()); - assertTrue(expectedHeaders.containsKey(headerKey)); + assertEquals(1, response.getHeaders().keySet().size()); + assertTrue(response.getHeaders().containsKey(headerKey)); - List fooList = expectedHeaders.get(headerKey); + List fooList = response.getHeaders().get(headerKey); assertEquals(2, fooList.size()); assertTrue(fooList.containsAll(headerValueList)); + assertEquals(2, response.getConsumedParams().size()); + assertTrue(response.getConsumedParams().containsAll(expectedConsumedParams)); + assertTrue(response.isContentConsumed()); + try (BytesStreamOutput out = new BytesStreamOutput()) { response.writeTo(out); out.flush(); @@ -222,12 +243,16 @@ public void testRestExecuteOnExtensionResponse() throws Exception { assertEquals(expectedContentType, response.getContentType()); assertArrayEquals(expectedResponseBytes, response.getContent()); - assertEquals(1, expectedHeaders.keySet().size()); - assertTrue(expectedHeaders.containsKey(headerKey)); + assertEquals(1, response.getHeaders().keySet().size()); + assertTrue(response.getHeaders().containsKey(headerKey)); - fooList = expectedHeaders.get(headerKey); + fooList = response.getHeaders().get(headerKey); assertEquals(2, fooList.size()); assertTrue(fooList.containsAll(headerValueList)); + + assertEquals(2, response.getConsumedParams().size()); + assertTrue(response.getConsumedParams().containsAll(expectedConsumedParams)); + assertTrue(response.isContentConsumed()); } } } diff --git a/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestResponseTests.java b/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestResponseTests.java new file mode 100644 index 0000000000000..82ae61b02cb32 --- /dev/null +++ b/server/src/test/java/org/opensearch/extensions/rest/ExtensionRestResponseTests.java @@ -0,0 +1,132 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.extensions.rest; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.Collections; + +import org.opensearch.common.bytes.BytesArray; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.rest.RestRequest.Method; +import org.opensearch.test.OpenSearchTestCase; + +import static org.opensearch.rest.BytesRestResponse.TEXT_CONTENT_TYPE; +import static org.opensearch.rest.RestStatus.ACCEPTED; +import static org.opensearch.rest.RestStatus.OK; + +public class ExtensionRestResponseTests extends OpenSearchTestCase { + + private static final String OCTET_CONTENT_TYPE = "application/octet-stream"; + private static final String JSON_CONTENT_TYPE = "application/json; charset=UTF-8"; + + private String testText; + private byte[] testBytes; + + @Override + public void setUp() throws Exception { + super.setUp(); + testText = "plain text"; + testBytes = new byte[] { 1, 2 }; + } + + private ExtensionRestRequest generateTestRequest() { + ExtensionRestRequest request = new ExtensionRestRequest( + Method.GET, + "/foo", + Collections.emptyMap(), + null, + new BytesArray("Text Content"), + null + ); + // consume params "foo" and "bar" + request.param("foo"); + request.param("bar"); + // consume content + request.content(); + return request; + } + + public void testConstructorWithBuilder() throws IOException { + XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent()); + builder.startObject(); + builder.field("status", ACCEPTED); + builder.endObject(); + ExtensionRestRequest request = generateTestRequest(); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, builder); + + assertEquals(OK, response.status()); + assertEquals(JSON_CONTENT_TYPE, response.contentType()); + assertEquals("{\"status\":\"ACCEPTED\"}", response.content().utf8ToString()); + for (String param : response.getConsumedParams()) { + assertTrue(request.consumedParams().contains(param)); + } + assertTrue(request.isContentConsumed()); + } + + public void testConstructorWithPlainText() { + ExtensionRestRequest request = generateTestRequest(); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, testText); + + assertEquals(OK, response.status()); + assertEquals(TEXT_CONTENT_TYPE, response.contentType()); + assertEquals(testText, response.content().utf8ToString()); + for (String param : response.getConsumedParams()) { + assertTrue(request.consumedParams().contains(param)); + } + assertTrue(request.isContentConsumed()); + } + + public void testConstructorWithText() { + ExtensionRestRequest request = generateTestRequest(); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, TEXT_CONTENT_TYPE, testText); + + assertEquals(OK, response.status()); + assertEquals(TEXT_CONTENT_TYPE, response.contentType()); + assertEquals(testText, response.content().utf8ToString()); + + for (String param : response.getConsumedParams()) { + assertTrue(request.consumedParams().contains(param)); + } + assertTrue(request.isContentConsumed()); + } + + public void testConstructorWithByteArray() { + ExtensionRestRequest request = generateTestRequest(); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, OCTET_CONTENT_TYPE, testBytes); + + assertEquals(OK, response.status()); + assertEquals(OCTET_CONTENT_TYPE, response.contentType()); + assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); + for (String param : response.getConsumedParams()) { + assertTrue(request.consumedParams().contains(param)); + } + assertTrue(request.isContentConsumed()); + } + + public void testConstructorWithBytesReference() { + ExtensionRestRequest request = generateTestRequest(); + ExtensionRestResponse response = new ExtensionRestResponse( + request, + OK, + OCTET_CONTENT_TYPE, + BytesReference.fromByteBuffer(ByteBuffer.wrap(testBytes, 0, 2)) + ); + + assertEquals(OK, response.status()); + assertEquals(OCTET_CONTENT_TYPE, response.contentType()); + assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); + for (String param : response.getConsumedParams()) { + assertTrue(request.consumedParams().contains(param)); + } + assertTrue(request.isContentConsumed()); + } +}