Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Fixed javadoc warning for build failure([#4581](https://github.com/opensearch-project/OpenSearch/pull/4581))
- Added transport actions support for extensions ([#4598](https://github.com/opensearch-project/OpenSearch/pull/4598/))
- Pass REST params and content to extensions ([#4633](https://github.com/opensearch-project/OpenSearch/pull/4633))
- Return consumed params and content from extensions ([#4705](https://github.com/opensearch-project/OpenSearch/pull/4705))

## [2.x]

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.extensions.rest;

import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestStatus;

import java.util.List;

/**
* A subclass of {@link BytesRestResponse} which also tracks consumed parameters and content.
*
* @opensearch.api
*/
public class ExtensionRestResponse extends BytesRestResponse {

private final List<String> consumedParams;
private final boolean contentConsumed;

/**
* Creates a new response based on {@link XContentBuilder}.
*
* @param request the REST request being responded to.
* @param status The REST status.
* @param builder The builder for the response.
*/
public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, XContentBuilder builder) {
super(status, builder);
this.consumedParams = request.consumedParams();
this.contentConsumed = request.isContentConsumed();
}

/**
* Creates a new plain text response.
*
* @param request the REST request being responded to.
* @param status The REST status.
* @param content A plain text response string.
*/
public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String content) {
super(status, content);
this.consumedParams = request.consumedParams();
this.contentConsumed = request.isContentConsumed();
}

/**
* Creates a new plain text response.
*
* @param request the REST request being responded to.
* @param status The REST status.
* @param contentType The content type of the response string.
* @param content A response string.
*/
public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, String content) {
super(status, contentType, content);
this.consumedParams = request.consumedParams();
this.contentConsumed = request.isContentConsumed();
}

/**
* Creates a binary response.
*
* @param request the REST request being responded to.
* @param status The REST status.
* @param contentType The content type of the response bytes.
* @param content Response bytes.
*/
public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, byte[] content) {
super(status, contentType, content);
this.consumedParams = request.consumedParams();
this.contentConsumed = request.isContentConsumed();
}

/**
* Creates a binary response.
*
* @param request the REST request being responded to.
* @param status The REST status.
* @param contentType The content type of the response bytes.
* @param content Response bytes.
*/
public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, BytesReference content) {
super(status, contentType, content);
this.consumedParams = request.consumedParams();
this.contentConsumed = request.isContentConsumed();
}

/**
* Gets the list of consumed parameters. These are needed to consume the parameters of the original request.
*
* @return the list of consumed params.
*/
public List<String> getConsumedParams() {
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Trying to learn, how do regular HTTP servers respond?
Is returning back consumed params a standard?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure how it is outside of OpenSearch. All I know is that OpenSearch spits out an error if you include parameters which aren't included.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Which is pretty interesting, Is there a code pointer in OpenSearch. I would love to see why we put it in place.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, it's part of the RestRequest class that we have to use to forward the relevant bits to the extension (and track the consuming).

Here's where it throws an exception on unconsumed params:

// validate unconsumed params, but we must exclude params used to format the response
// use a sorted set so the unconsumed parameters appear in a reliable sorted order
final SortedSet<String> unconsumedParams = request.unconsumedParams()
.stream()
.filter(p -> !responseParams().contains(p))
.collect(Collectors.toCollection(TreeSet::new));
// validate the non-response params
if (!unconsumedParams.isEmpty()) {
final Set<String> candidateParams = new HashSet<>();
candidateParams.addAll(request.consumedParams());
candidateParams.addAll(responseParams());
throw new IllegalArgumentException(unrecognized(request, unconsumedParams, candidateParams, "parameter"));
}

Copy link
Copy Markdown
Member Author

@dbwiddis dbwiddis Oct 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's the PR from 6 years ago with more info:
9a83ded

Looks like the intent is to provide a "you typed blah, did you mean bleh?" functionality

return consumedParams;
}

/**
* Reports whether content was consumed.
*
* @return true if the content was consumed, false otherwise.
*/
public boolean isContentConsumed() {
return contentConsumed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,11 @@

import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.rest.BytesRestResponse;
import org.opensearch.rest.RestResponse;
import org.opensearch.rest.RestStatus;
import org.opensearch.transport.TransportResponse;

import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.List;
import java.util.Map;

Expand All @@ -32,16 +29,8 @@ public class RestExecuteOnExtensionResponse extends TransportResponse {
private String contentType;
private byte[] content;
private Map<String, List<String>> headers;

/**
* Instantiate this object with a status and response string.
*
* @param status The REST status.
* @param responseString The response content as a String.
*/
public RestExecuteOnExtensionResponse(RestStatus status, String responseString) {
this(status, BytesRestResponse.TEXT_CONTENT_TYPE, responseString.getBytes(StandardCharsets.UTF_8), Collections.emptyMap());
}
private List<String> consumedParams;
private boolean contentConsumed;

/**
* Instantiate this object with the components of a {@link RestResponse}.
Expand All @@ -50,33 +39,49 @@ public RestExecuteOnExtensionResponse(RestStatus status, String responseString)
* @param contentType The type of the content.
* @param content The content.
* @param headers The headers.
* @param consumedParams The consumed params.
* @param contentConsumed Whether content was consumed.
*/
public RestExecuteOnExtensionResponse(RestStatus status, String contentType, byte[] content, Map<String, List<String>> headers) {
public RestExecuteOnExtensionResponse(
RestStatus status,
String contentType,
byte[] content,
Map<String, List<String>> headers,
List<String> consumedParams,
boolean contentConsumed
) {
super();
setStatus(status);
setContentType(contentType);
setContent(content);
setHeaders(headers);
setConsumedParams(consumedParams);
setContentConsumed(contentConsumed);
}

/**
* Instantiate this object from a Transport Stream
* Instantiate this object from a Transport Stream.
*
* @param in The stream input.
* @throws IOException on transport failure.
*/
public RestExecuteOnExtensionResponse(StreamInput in) throws IOException {
setStatus(RestStatus.readFrom(in));
setStatus(in.readEnum(RestStatus.class));
setContentType(in.readString());
setContent(in.readByteArray());
setHeaders(in.readMapOfLists(StreamInput::readString, StreamInput::readString));
setConsumedParams(in.readStringList());
setContentConsumed(in.readBoolean());
}

@Override
public void writeTo(StreamOutput out) throws IOException {
RestStatus.writeTo(out, status);
out.writeEnum(status);
out.writeString(contentType);
out.writeByteArray(content);
out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString);
out.writeStringCollection(consumedParams);
out.writeBoolean(contentConsumed);
}

public RestStatus getStatus() {
Expand Down Expand Up @@ -110,4 +115,20 @@ public Map<String, List<String>> getHeaders() {
public void setHeaders(Map<String, List<String>> headers) {
this.headers = Map.copyOf(headers);
}

public List<String> getConsumedParams() {
return consumedParams;
}

public void setConsumedParams(List<String> consumedParams) {
this.consumedParams = consumedParams;
}

public boolean isContentConsumed() {
return contentConsumed;
}

public void setContentConsumed(boolean contentConsumed) {
this.contentConsumed = contentConsumed;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,9 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;

import static java.util.Collections.emptyList;
import static java.util.Collections.emptyMap;
import static java.util.Collections.unmodifiableList;

Expand All @@ -49,8 +47,6 @@ public class RestSendToExtensionAction extends BaseRestHandler {

private static final String SEND_TO_EXTENSION_ACTION = "send_to_extension_action";
private static final Logger logger = LogManager.getLogger(RestSendToExtensionAction.class);
private static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters";
private static final String CONSUMED_CONTENT_KEY = "extension.consumed.content";
// To replace with user identity see https://github.com/opensearch-project/OpenSearch/pull/4247
private static final Principal DEFAULT_PRINCIPAL = new Principal() {
@Override
Expand Down Expand Up @@ -124,7 +120,9 @@ public RestChannelConsumer prepareRequest(final RestRequest request, final NodeC
RestStatus.INTERNAL_SERVER_ERROR,
BytesRestResponse.TEXT_CONTENT_TYPE,
message.getBytes(StandardCharsets.UTF_8),
emptyMap()
emptyMap(),
emptyList(),
false
);
final CountDownLatch inProgressLatch = new CountDownLatch(1);
final TransportResponseHandler<RestExecuteOnExtensionResponse> restExecuteOnExtensionResponseHandler = new TransportResponseHandler<
Expand All @@ -141,25 +139,12 @@ public void handleResponse(RestExecuteOnExtensionResponse response) {
restExecuteOnExtensionResponse.setStatus(response.getStatus());
restExecuteOnExtensionResponse.setContentType(response.getContentType());
restExecuteOnExtensionResponse.setContent(response.getContent());
// Extract the consumed parameters and content from the header
Map<String, List<String>> headers = response.getHeaders();
List<String> consumedParams = headers.get(CONSUMED_PARAMS_KEY);
if (consumedParams != null) {
// consume each param
consumedParams.stream().forEach(p -> request.param(p));
}
List<String> consumedContent = headers.get(CONSUMED_CONTENT_KEY);
if (consumedContent != null) {
// conditionally consume content
if (consumedParams.stream().filter(c -> Boolean.parseBoolean(c)).count() > 0) {
request.content();
}
restExecuteOnExtensionResponse.setHeaders(response.getHeaders());
// Consume parameters and content
response.getConsumedParams().stream().forEach(p -> request.param(p));
if (response.isContentConsumed()) {
request.content();
}
Map<String, List<String>> headersWithoutConsumedParams = headers.entrySet()
.stream()
.filter(e -> !e.getKey().equals(CONSUMED_PARAMS_KEY))
.collect(Collectors.toMap(e -> e.getKey(), e -> e.getValue()));
restExecuteOnExtensionResponse.setHeaders(headersWithoutConsumedParams);
inProgressLatch.countDown();
}

Expand Down Expand Up @@ -204,11 +189,11 @@ public String executor() {
restExecuteOnExtensionResponse.getContentType(),
restExecuteOnExtensionResponse.getContent()
);
for (Entry<String, List<String>> headerEntry : restExecuteOnExtensionResponse.getHeaders().entrySet()) {
for (String value : headerEntry.getValue()) {
restResponse.addHeader(headerEntry.getKey(), value);
}
}
// No constructor that includes headers so we roll our own
restExecuteOnExtensionResponse.getHeaders()
.entrySet()
.stream()
.forEach(e -> { e.getValue().stream().forEach(v -> restResponse.addHeader(e.getKey(), v)); });

return channel -> channel.sendResponse(restResponse);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,30 +188,51 @@ public void testRestExecuteOnExtensionResponse() throws Exception {
String expectedResponse = "Test response";
byte[] expectedResponseBytes = expectedResponse.getBytes(StandardCharsets.UTF_8);

RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(expectedStatus, expectedResponse);
RestExecuteOnExtensionResponse response = new RestExecuteOnExtensionResponse(
expectedStatus,
expectedContentType,
expectedResponseBytes,
Collections.emptyMap(),
Collections.emptyList(),
false
);

assertEquals(expectedStatus, response.getStatus());
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());
assertEquals(0, response.getHeaders().size());
assertEquals(0, response.getConsumedParams().size());
assertFalse(response.isContentConsumed());

String headerKey = "foo";
List<String> headerValueList = List.of("bar", "baz");
Map<String, List<String>> expectedHeaders = Map.of(headerKey, headerValueList);
List<String> expectedConsumedParams = List.of("foo", "bar");

response = new RestExecuteOnExtensionResponse(expectedStatus, expectedContentType, expectedResponseBytes, expectedHeaders);
response = new RestExecuteOnExtensionResponse(
expectedStatus,
expectedContentType,
expectedResponseBytes,
expectedHeaders,
expectedConsumedParams,
true
);

assertEquals(expectedStatus, response.getStatus());
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());

assertEquals(1, expectedHeaders.keySet().size());
assertTrue(expectedHeaders.containsKey(headerKey));
assertEquals(1, response.getHeaders().keySet().size());
assertTrue(response.getHeaders().containsKey(headerKey));

List<String> fooList = expectedHeaders.get(headerKey);
List<String> fooList = response.getHeaders().get(headerKey);
assertEquals(2, fooList.size());
assertTrue(fooList.containsAll(headerValueList));

assertEquals(2, response.getConsumedParams().size());
assertTrue(response.getConsumedParams().containsAll(expectedConsumedParams));
assertTrue(response.isContentConsumed());

try (BytesStreamOutput out = new BytesStreamOutput()) {
response.writeTo(out);
out.flush();
Expand All @@ -222,12 +243,16 @@ public void testRestExecuteOnExtensionResponse() throws Exception {
assertEquals(expectedContentType, response.getContentType());
assertArrayEquals(expectedResponseBytes, response.getContent());

assertEquals(1, expectedHeaders.keySet().size());
assertTrue(expectedHeaders.containsKey(headerKey));
assertEquals(1, response.getHeaders().keySet().size());
assertTrue(response.getHeaders().containsKey(headerKey));

fooList = expectedHeaders.get(headerKey);
fooList = response.getHeaders().get(headerKey);
assertEquals(2, fooList.size());
assertTrue(fooList.containsAll(headerValueList));

assertEquals(2, response.getConsumedParams().size());
assertTrue(response.getConsumedParams().containsAll(expectedConsumedParams));
assertTrue(response.isContentConsumed());
}
}
}
Expand Down
Loading