Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- Modified EnvironmentSettingsRequest to pass entire Settings object ([#4731](https://github.com/opensearch-project/OpenSearch/pull/4731))
- Added contentParser method to ExtensionRestRequest ([#4760](https://github.com/opensearch-project/OpenSearch/pull/4760))
- Enforce type safety for RegisterTransportActionsRequest([#4796](https://github.com/opensearch-project/OpenSearch/pull/4796))
- Enforce type safety for NamedWriteableRegistryParseRequest ([#4923](https://github.com/opensearch-project/OpenSearch/pull/4923))

## [2.x]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,15 @@
*/
public class NamedWriteableRegistryParseRequest extends TransportRequest {

private final Class categoryClass;
private final Class<? extends NamedWriteable> categoryClass;
private byte[] context;

/**
* @param categoryClass Class category for this parse request
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws IllegalArgumentException if context bytes could not be read
*/
public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput context) {
public NamedWriteableRegistryParseRequest(Class<? extends NamedWriteable> categoryClass, StreamInput context) {
try {
byte[] streamInputBytes = context.readAllBytes();
this.categoryClass = categoryClass;
Expand All @@ -42,10 +42,11 @@ public NamedWriteableRegistryParseRequest(Class categoryClass, StreamInput conte
* @param in StreamInput from which class fields are read from
* @throws IllegalArgumentException if the fully qualified class name is invalid and the class object cannot be generated at runtime
*/
@SuppressWarnings("unchecked")
public NamedWriteableRegistryParseRequest(StreamInput in) throws IOException {
super(in);
try {
this.categoryClass = Class.forName(in.readString());
this.categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
this.context = in.readByteArray();
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand Down Expand Up @@ -85,7 +86,7 @@ public int hashCode() {
/**
* Returns the class instance of the category class sent over by the SDK
*/
public Class getCategoryClass() {
public Class<? extends NamedWriteable> getCategoryClass() {
return this.categoryClass;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,12 @@
*/
public class NamedWriteableRegistryResponse extends TransportResponse {

private final Map<String, Class> registry;
private final Map<String, Class<? extends NamedWriteable>> registry;

/**
* @param registry Map of writeable names and their associated category class
*/
public NamedWriteableRegistryResponse(Map<String, Class> registry) {
public NamedWriteableRegistryResponse(Map<String, Class<? extends NamedWriteable>> registry) {
this.registry = new HashMap<>(registry);
}

Expand All @@ -38,12 +38,13 @@ public NamedWriteableRegistryResponse(Map<String, Class> registry) {
public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
super(in);
// Stream output for registry map begins with a variable integer that tells us the number of entries being sent across the wire
Map<String, Class> registry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> registry = new HashMap<>();
int registryEntryCount = in.readVInt();
for (int i = 0; i < registryEntryCount; i++) {
try {
String name = in.readString();
Class categoryClass = Class.forName(in.readString());
@SuppressWarnings("unchecked")
Class<? extends NamedWriteable> categoryClass = (Class<? extends NamedWriteable>) Class.forName(in.readString());
registry.put(name, categoryClass);
} catch (ClassNotFoundException e) {
throw new IllegalArgumentException("Category class definition not found", e);
Expand All @@ -57,7 +58,7 @@ public NamedWriteableRegistryResponse(StreamInput in) throws IOException {
public void writeTo(StreamOutput out) throws IOException {
// Stream out registry size prior to streaming out registry entries
out.writeVInt(this.registry.size());
for (Map.Entry<String, Class> entry : registry.entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : registry.entrySet()) {
out.writeString(entry.getKey()); // Unique named writeable name
out.writeString(entry.getValue().getName()); // Fully qualified category class name
}
Expand All @@ -84,7 +85,7 @@ public int hashCode() {
/**
* Returns a map of writeable names and their associated category class
*/
public Map<String, Class> getRegistry() {
public Map<String, Class<? extends NamedWriteable>> getRegistry() {
return Collections.unmodifiableMap(this.registry);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.extensions.ExtensionsOrchestrator.OpenSearchRequestType;
import org.opensearch.transport.TransportService;

Expand All @@ -29,7 +30,7 @@ public class ExtensionNamedWriteableRegistry {

private static final Logger logger = LogManager.getLogger(ExtensionNamedWriteableRegistry.class);

private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionNamedWriteableRegistry;
private List<DiscoveryExtension> extensionsInitializedList;
private TransportService transportService;

Expand All @@ -54,7 +55,8 @@ public void getNamedWriteables() {
// Retrieve named writeable registry entries from each extension
for (DiscoveryNode extensionNode : extensionsInitializedList) {
try {
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry = getNamedWriteables(extensionNode);
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry =
getNamedWriteables(extensionNode);
if (extensionRegistry.isEmpty() == false) {
this.extensionNamedWriteableRegistry.putAll(extensionRegistry);
}
Expand All @@ -74,8 +76,9 @@ public void getNamedWriteables() {
* @throws UnknownHostException if connection to the extension node failed
* @return A map of category classes and their associated names and readers for this discovery node
*/
private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWriteables(DiscoveryNode extensionNode)
throws UnknownHostException {
private Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getNamedWriteables(
DiscoveryNode extensionNode
) throws UnknownHostException {
NamedWriteableRegistryResponseHandler namedWriteableRegistryResponseHandler = new NamedWriteableRegistryResponseHandler(
extensionNode,
transportService,
Expand Down Expand Up @@ -104,7 +107,7 @@ private Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getNamedWri
* @throws IllegalArgumentException if there is no reader associated with the given category class and name
* @return A map of the discovery node and its associated extension reader
*/
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClass, String name) {
public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class<? extends NamedWriteable> categoryClass, String name) {

ExtensionReader reader = null;
DiscoveryNode extension = null;
Expand Down Expand Up @@ -133,9 +136,11 @@ public Map<DiscoveryNode, ExtensionReader> getExtensionReader(Class categoryClas
* @param name Unique name identifying the Writeable object
* @return The extension reader
*/
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class categoryClass, String name) {
private ExtensionReader getExtensionReader(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, String name) {
ExtensionReader reader = null;
Map<Class, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = this.extensionNamedWriteableRegistry.get(
extensionNode
);
if (categoryMap != null) {
Map<String, ExtensionReader> readerMap = categoryMap.get(categoryClass);
if (readerMap != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.common.io.stream.NamedWriteable;
import org.opensearch.common.io.stream.NamedWriteableRegistryParseRequest;
import org.opensearch.common.io.stream.NamedWriteableRegistryResponse;
import org.opensearch.common.io.stream.StreamInput;
Expand All @@ -34,7 +35,7 @@
public class NamedWriteableRegistryResponseHandler implements TransportResponseHandler<NamedWriteableRegistryResponse> {
private static final Logger logger = LogManager.getLogger(NamedWriteableRegistryResponseHandler.class);

private final Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionRegistry;
private final Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionRegistry;
private final DiscoveryNode extensionNode;
private final TransportService transportService;
private final String requestType;
Expand All @@ -56,7 +57,7 @@ public NamedWriteableRegistryResponseHandler(DiscoveryNode extensionNode, Transp
/**
* @return A map of the given DiscoveryNode and its inner named writeable registry map
*/
public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtensionRegistry() {
public Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> getExtensionRegistry() {
return Collections.unmodifiableMap(this.extensionRegistry);
}

Expand All @@ -68,7 +69,8 @@ public Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> getExtension
* @param context StreamInput object to convert into a byte array and transport to the extension
* @throws UnknownHostException if connection to the extension node failed
*/
public void parseNamedWriteable(DiscoveryNode extensionNode, Class categoryClass, StreamInput context) throws UnknownHostException {
public void parseNamedWriteable(DiscoveryNode extensionNode, Class<? extends NamedWriteable> categoryClass, StreamInput context)
throws UnknownHostException {
NamedWriteableRegistryParseResponseHandler namedWriteableRegistryParseResponseHandler =
new NamedWriteableRegistryParseResponseHandler();
try {
Expand Down Expand Up @@ -98,16 +100,16 @@ public void handleResponse(NamedWriteableRegistryResponse response) {
if (response.getRegistry().isEmpty() == false) {

// Extension has sent over entries to register, initialize inner category map
Map<Class, Map<String, ExtensionReader>> categoryMap = new HashMap<>();
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = new HashMap<>();

// Reader map associated with this current category
Map<String, ExtensionReader> readers = null;
Class currentCategory = null;
Class<? extends NamedWriteable> currentCategory = null;

for (Map.Entry<String, Class> entry : response.getRegistry().entrySet()) {
for (Map.Entry<String, Class<? extends NamedWriteable>> entry : response.getRegistry().entrySet()) {

String name = entry.getKey();
Class categoryClass = entry.getValue();
Class<? extends NamedWriteable> categoryClass = entry.getValue();
if (currentCategory != categoryClass) {
// After first pass, readers and current category are set
if (currentCategory != null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -749,7 +749,7 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_NAMED_WRITEABLE_REGISTRY;

// Create response to pass to response handler
Map<String, Class> responseRegistry = new HashMap<>();
Map<String, Class<? extends NamedWriteable>> responseRegistry = new HashMap<>();
responseRegistry.put(Example.NAME, Example.class);
NamedWriteableRegistryResponse response = new NamedWriteableRegistryResponse(responseRegistry);

Expand All @@ -761,10 +761,11 @@ public void testNamedWriteableRegistryResponseHandler() throws Exception {
responseHandler.handleResponse(response);

// Ensure that response entries have been processed correctly into their respective maps
Map<DiscoveryNode, Map<Class, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler.getExtensionRegistry();
Map<DiscoveryNode, Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>>> extensionsRegistry = responseHandler
.getExtensionRegistry();
assertEquals(extensionsRegistry.size(), 1);

Map<Class, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
Map<Class<? extends NamedWriteable>, Map<String, ExtensionReader>> categoryMap = extensionsRegistry.get(extensionNode);
assertEquals(categoryMap.size(), 1);

Map<String, ExtensionReader> readerMap = categoryMap.get(Example.class);
Expand Down Expand Up @@ -798,7 +799,7 @@ public void testParseNamedWriteables() throws Exception {
String requestType = ExtensionsOrchestrator.REQUEST_OPENSEARCH_PARSE_NAMED_WRITEABLE;
List<DiscoveryExtension> extensionsList = new ArrayList<>(extensionsOrchestrator.extensionIdMap.values());
DiscoveryNode extensionNode = extensionsList.get(0);
Class categoryClass = Example.class;
Class<? extends NamedWriteable> categoryClass = Example.class;

// convert context into an input stream then stream input for mock
byte[] context = new byte[0];
Expand Down