Skip to content

Commit 1d66af3

Browse files
authored
Fix: register mulitple extensions. (opensearch-project#10256)
* Fix: register mulitple extensions. Signed-off-by: dblock <dblock@amazon.com> * Updated CHANGELOG. Signed-off-by: dblock <dblock@amazon.com> * Added tests. Signed-off-by: dblock <dblock@amazon.com> --------- Signed-off-by: dblock <dblock@amazon.com>
1 parent 9d0db5e commit 1d66af3

8 files changed

Lines changed: 176 additions & 24 deletions

File tree

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
131131
- Fix concurrent search NPE when track_total_hits, terminate_after and size=0 are used ([#10082](https://github.com/opensearch-project/OpenSearch/pull/10082))
132132
- Fix remove ingest processor handing ignore_missing parameter not correctly ([10089](https://github.com/opensearch-project/OpenSearch/pull/10089))
133133
- Fix circular dependency in Settings initialization ([10194](https://github.com/opensearch-project/OpenSearch/pull/10194))
134+
- Fix registration and initialization of multiple extensions ([10256](https://github.com/opensearch-project/OpenSearch/pull/10256))
134135

135136
### Security
136137

server/src/main/java/org/opensearch/extensions/ExtensionsManager.java

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ private void registerRequestHandler(DynamicActionRegistry dynamicActionRegistry)
300300
* Loads a single extension
301301
* @param extension The extension to be loaded
302302
*/
303-
public void loadExtension(Extension extension) throws IOException {
303+
public DiscoveryExtensionNode loadExtension(Extension extension) throws IOException {
304304
validateExtension(extension);
305305
DiscoveryExtensionNode discoveryExtensionNode = new DiscoveryExtensionNode(
306306
extension.getName(),
@@ -314,6 +314,12 @@ public void loadExtension(Extension extension) throws IOException {
314314
extensionIdMap.put(extension.getUniqueId(), discoveryExtensionNode);
315315
extensionSettingsMap.put(extension.getUniqueId(), extension);
316316
logger.info("Loaded extension with uniqueId " + extension.getUniqueId() + ": " + extension);
317+
return discoveryExtensionNode;
318+
}
319+
320+
public void initializeExtension(Extension extension) throws IOException {
321+
DiscoveryExtensionNode node = loadExtension(extension);
322+
initializeExtensionNode(node);
317323
}
318324

319325
private void validateField(String fieldName, String value) throws IOException {
@@ -340,11 +346,11 @@ private void validateExtension(Extension extension) throws IOException {
340346
*/
341347
public void initialize() {
342348
for (DiscoveryExtensionNode extension : extensionIdMap.values()) {
343-
initializeExtension(extension);
349+
initializeExtensionNode(extension);
344350
}
345351
}
346352

347-
private void initializeExtension(DiscoveryExtensionNode extension) {
353+
public void initializeExtensionNode(DiscoveryExtensionNode extensionNode) {
348354

349355
final CompletableFuture<InitializeExtensionResponse> inProgressFuture = new CompletableFuture<>();
350356
final TransportResponseHandler<InitializeExtensionResponse> initializeExtensionResponseHandler = new TransportResponseHandler<
@@ -384,7 +390,8 @@ public String executor() {
384390
transportService.getThreadPool().generic().execute(new AbstractRunnable() {
385391
@Override
386392
public void onFailure(Exception e) {
387-
extensionIdMap.remove(extension.getId());
393+
logger.warn("Error registering extension: " + extensionNode.getId(), e);
394+
extensionIdMap.remove(extensionNode.getId());
388395
if (e.getCause() instanceof ConnectTransportException) {
389396
logger.info("No response from extension to request.", e);
390397
throw (ConnectTransportException) e.getCause();
@@ -399,11 +406,11 @@ public void onFailure(Exception e) {
399406

400407
@Override
401408
protected void doRun() throws Exception {
402-
transportService.connectToExtensionNode(extension);
409+
transportService.connectToExtensionNode(extensionNode);
403410
transportService.sendRequest(
404-
extension,
411+
extensionNode,
405412
REQUEST_EXTENSION_ACTION_NAME,
406-
new InitializeExtensionRequest(transportService.getLocalNode(), extension, issueServiceAccount(extension)),
413+
new InitializeExtensionRequest(transportService.getLocalNode(), extensionNode, issueServiceAccount(extensionNode)),
407414
initializeExtensionResponseHandler
408415
);
409416
}

server/src/main/java/org/opensearch/extensions/rest/RestActionsRequestHandler.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@ public TransportResponse handleRegisterRestActionsRequest(
6262
DynamicActionRegistry dynamicActionRegistry
6363
) throws Exception {
6464
DiscoveryExtensionNode discoveryExtensionNode = extensionIdMap.get(restActionsRequest.getUniqueId());
65+
if (discoveryExtensionNode == null) {
66+
throw new IllegalStateException("Missing extension node for " + restActionsRequest.getUniqueId());
67+
}
6568
RestHandler handler = new RestSendToExtensionAction(
6669
restActionsRequest,
6770
discoveryExtensionNode,

server/src/main/java/org/opensearch/extensions/rest/RestInitializeExtensionAction.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,7 @@ public RestChannelConsumer prepareRequest(RestRequest request, NodeClient client
159159
extAdditionalSettings
160160
);
161161
try {
162-
extensionsManager.loadExtension(extension);
163-
extensionsManager.initialize();
162+
extensionsManager.initializeExtension(extension);
164163
} catch (CompletionException e) {
165164
Throwable cause = e.getCause();
166165
if (cause instanceof TimeoutException) {

server/src/main/java/org/opensearch/extensions/rest/RestSendToExtensionAction.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ public RestSendToExtensionAction(
150150

151151
@Override
152152
public String getName() {
153-
return SEND_TO_EXTENSION_ACTION;
153+
return this.discoveryExtensionNode.getId() + ":" + SEND_TO_EXTENSION_ACTION;
154154
}
155155

156156
@Override

server/src/test/java/org/opensearch/extensions/ExtensionsManagerTests.java

Lines changed: 143 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
import org.opensearch.core.common.transport.TransportAddress;
3737
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
3838
import org.opensearch.core.transport.TransportResponse;
39+
import org.opensearch.discovery.InitializeExtensionRequest;
3940
import org.opensearch.env.Environment;
4041
import org.opensearch.env.EnvironmentSettingsResponse;
4142
import org.opensearch.extensions.ExtensionsSettings.Extension;
@@ -77,6 +78,7 @@
7778
import static org.mockito.ArgumentMatchers.any;
7879
import static org.mockito.ArgumentMatchers.anyBoolean;
7980
import static org.mockito.ArgumentMatchers.anyString;
81+
import static org.mockito.Mockito.doNothing;
8082
import static org.mockito.Mockito.mock;
8183
import static org.mockito.Mockito.spy;
8284
import static org.mockito.Mockito.times;
@@ -409,19 +411,94 @@ public void testInitialize() throws Exception {
409411
)
410412
);
411413

412-
// Test needs to be changed to mock the connection between the local node and an extension. Assert statment is commented out for
413-
// now.
414+
// Test needs to be changed to mock the connection between the local node and an extension.
414415
// Link to issue: https://github.com/opensearch-project/OpenSearch/issues/4045
415416
// mockLogAppender.assertAllExpectationsMatched();
416417
}
417418
}
418419

420+
public void testInitializeExtension() throws Exception {
421+
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
422+
423+
TransportService mockTransportService = spy(
424+
new TransportService(
425+
Settings.EMPTY,
426+
mock(Transport.class),
427+
threadPool,
428+
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
429+
x -> null,
430+
null,
431+
Collections.emptySet(),
432+
NoopTracer.INSTANCE
433+
)
434+
);
435+
436+
doNothing().when(mockTransportService).connectToExtensionNode(any(DiscoveryExtensionNode.class));
437+
438+
doNothing().when(mockTransportService)
439+
.sendRequest(any(DiscoveryExtensionNode.class), anyString(), any(InitializeExtensionRequest.class), any());
440+
441+
extensionsManager.initializeServicesAndRestHandler(
442+
actionModule,
443+
settingsModule,
444+
mockTransportService,
445+
clusterService,
446+
settings,
447+
client,
448+
identityService
449+
);
450+
451+
Extension firstExtension = new Extension(
452+
"firstExtension",
453+
"uniqueid1",
454+
"127.0.0.0",
455+
"9301",
456+
"0.0.7",
457+
"2.0.0",
458+
"2.0.0",
459+
List.of(),
460+
null
461+
);
462+
463+
extensionsManager.initializeExtension(firstExtension);
464+
465+
Extension secondExtension = new Extension(
466+
"secondExtension",
467+
"uniqueid2",
468+
"127.0.0.0",
469+
"9301",
470+
"0.0.7",
471+
"2.0.0",
472+
"2.0.0",
473+
List.of(),
474+
null
475+
);
476+
477+
extensionsManager.initializeExtension(secondExtension);
478+
479+
ThreadPool.terminate(threadPool, 3, TimeUnit.SECONDS);
480+
481+
verify(mockTransportService, times(2)).connectToExtensionNode(any(DiscoveryExtensionNode.class));
482+
483+
verify(mockTransportService, times(2)).sendRequest(
484+
any(DiscoveryExtensionNode.class),
485+
anyString(),
486+
any(InitializeExtensionRequest.class),
487+
any()
488+
);
489+
}
490+
419491
public void testHandleRegisterRestActionsRequest() throws Exception {
420492

421493
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
422494
initialize(extensionsManager);
423495

424496
String uniqueIdStr = "uniqueid1";
497+
498+
extensionsManager.loadExtension(
499+
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
500+
);
501+
425502
List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
426503
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
427504
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
@@ -431,6 +508,58 @@ public void testHandleRegisterRestActionsRequest() throws Exception {
431508
assertTrue(((AcknowledgedResponse) response).getStatus());
432509
}
433510

511+
public void testHandleRegisterRestActionsRequestRequiresDiscoveryNode() throws Exception {
512+
513+
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
514+
initialize(extensionsManager);
515+
516+
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest("uniqueId1", List.of(), List.of());
517+
518+
expectThrows(
519+
IllegalStateException.class,
520+
() -> extensionsManager.getRestActionsRequestHandler()
521+
.handleRegisterRestActionsRequest(registerActionsRequest, actionModule.getDynamicActionRegistry())
522+
);
523+
}
524+
525+
public void testHandleRegisterRestActionsRequestMultiple() throws Exception {
526+
527+
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
528+
initialize(extensionsManager);
529+
530+
List<String> actionsList = List.of("GET /foo foo", "PUT /bar bar", "POST /baz baz");
531+
List<String> deprecatedActionsList = List.of("GET /deprecated/foo foo_deprecated", "It's deprecated!");
532+
for (int i = 0; i < 2; i++) {
533+
String uniqueIdStr = "uniqueid-%d" + i;
534+
535+
Set<Setting<?>> additionalSettings = extAwarePlugin.getExtensionSettings().stream().collect(Collectors.toSet());
536+
ExtensionScopedSettings extensionScopedSettings = new ExtensionScopedSettings(additionalSettings);
537+
Extension firstExtension = new Extension(
538+
"Extension %s" + i,
539+
uniqueIdStr,
540+
"127.0.0.0",
541+
"9300",
542+
"0.0.7",
543+
"3.0.0",
544+
"3.0.0",
545+
List.of(),
546+
extensionScopedSettings
547+
);
548+
549+
extensionsManager.loadExtension(firstExtension);
550+
551+
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(
552+
uniqueIdStr,
553+
actionsList,
554+
deprecatedActionsList
555+
);
556+
TransportResponse response = extensionsManager.getRestActionsRequestHandler()
557+
.handleRegisterRestActionsRequest(registerActionsRequest, actionModule.getDynamicActionRegistry());
558+
assertEquals(AcknowledgedResponse.class, response.getClass());
559+
assertTrue(((AcknowledgedResponse) response).getStatus());
560+
}
561+
}
562+
434563
public void testHandleRegisterSettingsRequest() throws Exception {
435564
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
436565
initialize(extensionsManager);
@@ -452,6 +581,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidMethod() throws Excep
452581
initialize(extensionsManager);
453582

454583
String uniqueIdStr = "uniqueid1";
584+
extensionsManager.loadExtension(
585+
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
586+
);
455587
List<String> actionsList = List.of("FOO /foo", "PUT /bar", "POST /baz");
456588
List<String> deprecatedActionsList = List.of("GET /deprecated/foo", "It's deprecated!");
457589
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
@@ -467,6 +599,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidDeprecatedMethod() th
467599
initialize(extensionsManager);
468600

469601
String uniqueIdStr = "uniqueid1";
602+
extensionsManager.loadExtension(
603+
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
604+
);
470605
List<String> actionsList = List.of("GET /foo", "PUT /bar", "POST /baz");
471606
List<String> deprecatedActionsList = List.of("FOO /deprecated/foo", "It's deprecated!");
472607
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
@@ -481,6 +616,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidUri() throws Exceptio
481616
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
482617
initialize(extensionsManager);
483618
String uniqueIdStr = "uniqueid1";
619+
extensionsManager.loadExtension(
620+
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
621+
);
484622
List<String> actionsList = List.of("GET", "PUT /bar", "POST /baz");
485623
List<String> deprecatedActionsList = List.of("GET /deprecated/foo", "It's deprecated!");
486624
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);
@@ -495,6 +633,9 @@ public void testHandleRegisterRestActionsRequestWithInvalidDeprecatedUri() throw
495633
ExtensionsManager extensionsManager = new ExtensionsManager(Set.of(), identityService);
496634
initialize(extensionsManager);
497635
String uniqueIdStr = "uniqueid1";
636+
extensionsManager.loadExtension(
637+
new Extension("firstExtension", uniqueIdStr, "127.0.0.0", "9300", "0.0.7", "3.0.0", "3.0.0", List.of(), null)
638+
);
498639
List<String> actionsList = List.of("GET /foo", "PUT /bar", "POST /baz");
499640
List<String> deprecatedActionsList = List.of("GET", "It's deprecated!");
500641
RegisterRestActionsRequest registerActionsRequest = new RegisterRestActionsRequest(uniqueIdStr, actionsList, deprecatedActionsList);

server/src/test/java/org/opensearch/extensions/rest/RestInitializeExtensionActionTests.java

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,9 @@
1919
import org.opensearch.core.indices.breaker.NoneCircuitBreakerService;
2020
import org.opensearch.core.rest.RestStatus;
2121
import org.opensearch.core.xcontent.MediaTypeRegistry;
22+
import org.opensearch.extensions.DiscoveryExtensionNode;
2223
import org.opensearch.extensions.ExtensionsManager;
23-
import org.opensearch.extensions.ExtensionsSettings;
24+
import org.opensearch.extensions.ExtensionsSettings.Extension;
2425
import org.opensearch.identity.IdentityService;
2526
import org.opensearch.rest.RestRequest;
2627
import org.opensearch.telemetry.tracing.noop.NoopTracer;
@@ -160,8 +161,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th
160161

161162
// optionally, you can stub out some methods:
162163
when(spy.getAdditionalSettings()).thenCallRealMethod();
163-
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
164-
Mockito.doNothing().when(spy).initialize();
164+
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
165+
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
165166
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
166167
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
167168
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
@@ -177,10 +178,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettings() th
177178
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
178179
restInitializeExtensionAction.handleRequest(request, channel, null);
179180

180-
assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
181+
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
181182
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));
182183

183-
Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
184+
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
184185
assertTrue(extension.isPresent());
185186
assertEquals(true, extension.get().getAdditionalSettings().get(boolSetting));
186187
assertEquals("customSetting", extension.get().getAdditionalSettings().get(stringSetting));
@@ -210,8 +211,8 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing
210211

211212
// optionally, you can stub out some methods:
212213
when(spy.getAdditionalSettings()).thenCallRealMethod();
213-
Mockito.doCallRealMethod().when(spy).loadExtension(any(ExtensionsSettings.Extension.class));
214-
Mockito.doNothing().when(spy).initialize();
214+
Mockito.doCallRealMethod().when(spy).loadExtension(any(Extension.class));
215+
Mockito.doNothing().when(spy).initializeExtensionNode(any(DiscoveryExtensionNode.class));
215216
RestInitializeExtensionAction restInitializeExtensionAction = new RestInitializeExtensionAction(spy);
216217
final String content = "{\"name\":\"ad-extension\",\"uniqueId\":\"ad-extension\",\"hostAddress\":\"127.0.0.1\","
217218
+ "\"port\":\"4532\",\"version\":\"1.0\",\"opensearchVersion\":\""
@@ -227,10 +228,10 @@ public void testRestInitializeExtensionActionResponseWithAdditionalSettingsUsing
227228
FakeRestChannel channel = new FakeRestChannel(request, false, 0);
228229
restInitializeExtensionAction.handleRequest(request, channel, null);
229230

230-
assertEquals(channel.capturedResponse().status(), RestStatus.ACCEPTED);
231+
assertEquals(RestStatus.ACCEPTED, channel.capturedResponse().status());
231232
assertTrue(channel.capturedResponse().content().utf8ToString().contains("A request to initialize an extension has been sent."));
232233

233-
Optional<ExtensionsSettings.Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
234+
Optional<Extension> extension = spy.lookupExtensionSettingsById("ad-extension");
234235
assertTrue(extension.isPresent());
235236
assertEquals(false, extension.get().getAdditionalSettings().get(boolSetting));
236237
assertEquals("default", extension.get().getAdditionalSettings().get(stringSetting));

0 commit comments

Comments
 (0)