Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

package org.elasticsearch.plugins;

import org.apache.lucene.util.Constants;
import org.elasticsearch.common.Strings;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.test.compiler.InMemoryJavaCompiler;
import org.elasticsearch.test.jar.JarUtils;

import java.io.IOException;
import java.lang.module.Configuration;
import java.lang.module.ModuleDescriptor;
import java.lang.module.ModuleFinder;
import java.net.MalformedURLException;
import java.net.URL;
Expand All @@ -40,6 +40,8 @@
@ESTestCase.WithoutSecurityManager
public class UberModuleClassLoaderTests extends ESTestCase {

private static Set<URLClassLoader> loaders = new HashSet<>();

/**
* Test the loadClass method, which is the real entrypoint for users of the classloader
*/
Expand Down Expand Up @@ -466,51 +468,116 @@ public static String demo() {
JarUtils.createJarWithEntries(jar, jarEntries);
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/91609")
public void testServiceLoadingWithOptionalDependencies() throws Exception {
assumeFalse("Tests frequently fail on Windows", Constants.WINDOWS);
try (UberModuleClassLoader loader = getServiceTestLoader(true)) {

// check module descriptor
ModuleDescriptor synthetic = loader.getLayer().findModule("synthetic").orElseThrow().getDescriptor();

assertThat(
synthetic.uses(),
equalTo(
Set.of("p.required.LetterService", "p.optional.AnimalService", "q.jar.one.NumberService", "q.jar.two.FooBarService")
)
);
// the descriptor model uses a list ordering that we don't guarantee, so we convert the provider list to maps and sets
Map<String, Set<String>> serviceProviders = synthetic.provides()
.stream()
.collect(Collectors.toMap(ModuleDescriptor.Provides::service, provides -> new HashSet<>(provides.providers())));
assertThat(
serviceProviders,
equalTo(
Map.of(
"p.required.LetterService",
Set.of("q.jar.one.JarOneProvider", "q.jar.two.JarTwoProvider"),
// optional dependencies found and added
"p.optional.AnimalService",
Set.of("q.jar.one.JarOneOptionalProvider", "q.jar.two.JarTwoOptionalProvider"),
"q.jar.one.NumberService",
Set.of("q.jar.one.JarOneProvider", "q.jar.two.JarTwoProvider"),
"q.jar.two.FooBarService",
Set.of("q.jar.two.JarTwoProvider")
)
)
);

// Now let's make sure the module system lets us load available services
Class<?> serviceCallerClass = loader.loadClass("q.caller.ServiceCaller");
Object instance = serviceCallerClass.getConstructor().newInstance();

var requiredParent = serviceCallerClass.getMethod("callServiceFromRequiredParent");
assertThat(requiredParent.invoke(instance), equalTo("AB"));
var optionalParent = serviceCallerClass.getMethod("callServiceFromOptionalParent");
assertThat(optionalParent.invoke(instance), equalTo("catdog"));
assertThat(optionalParent.invoke(instance), equalTo("catdog")); // our service provider worked
var modular = serviceCallerClass.getMethod("callServiceFromModularJar");
assertThat(modular.invoke(instance), equalTo("12"));
var nonModular = serviceCallerClass.getMethod("callServiceFromNonModularJar");
assertThat(nonModular.invoke(instance), equalTo("foo"));
} finally {
for (URLClassLoader loader : loaders) {
loader.close();
}
loaders = new HashSet<>();
}
}

@AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/91609")
public void testServiceLoadingWithoutOptionalDependencies() throws Exception {
assumeFalse("Tests frequently fail on Windows", Constants.WINDOWS);
try (UberModuleClassLoader loader = getServiceTestLoader(false)) {

// check module descriptor
ModuleDescriptor synthetic = loader.getLayer().findModule("synthetic").orElseThrow().getDescriptor();
assertThat(synthetic.uses(), equalTo(Set.of("p.required.LetterService", "q.jar.one.NumberService", "q.jar.two.FooBarService")));
// the descriptor model uses a list ordering that we don't guarantee, so we convert the provider list to maps and sets
Map<String, Set<String>> serviceProviders = synthetic.provides()
.stream()
.collect(Collectors.toMap(ModuleDescriptor.Provides::service, provides -> new HashSet<>(provides.providers())));
assertThat(
serviceProviders,
equalTo(
Map.of(
"p.required.LetterService",
Set.of("q.jar.one.JarOneProvider", "q.jar.two.JarTwoProvider"),
"q.jar.one.NumberService",
Set.of("q.jar.one.JarOneProvider", "q.jar.two.JarTwoProvider"),
"q.jar.two.FooBarService",
Set.of("q.jar.two.JarTwoProvider")
)
)
);

// Now let's make sure the module system lets us load available services
Class<?> serviceCallerClass = loader.loadClass("q.caller.ServiceCaller");
Object instance = serviceCallerClass.getConstructor().newInstance();

var requiredParent = serviceCallerClass.getMethod("callServiceFromRequiredParent");
assertThat(requiredParent.invoke(instance), equalTo("AB"));
var optionalParent = serviceCallerClass.getMethod("callServiceFromOptionalParent");
// service not found at runtime, so we don't try to load the provider
assertThat(optionalParent.invoke(instance), equalTo("Optional AnimalService dependency not present at runtime."));
var modular = serviceCallerClass.getMethod("callServiceFromModularJar");
assertThat(modular.invoke(instance), equalTo("12"));
var nonModular = serviceCallerClass.getMethod("callServiceFromNonModularJar");
assertThat(nonModular.invoke(instance), equalTo("foo"));
} finally {
for (URLClassLoader loader : loaders) {
loader.close();
}
loaders = new HashSet<>();
}
}

/**
* We need to create a test scenario that covers four service loading situations:
/*
* A class in our ubermodule may use SPI to load a service. Our test scenario needs to work out the following four
* conditions:
*
* 1. Service defined in package exported in parent layer.
* 2. Service defined in a compile-time dependency, optionally present at runtime.
* 3. Service defined in modular jar in uberjar
* 4. Service defined in non-modular jar in uberjar
*
* In all these cases, our ubermodule should declare that it uses each service *available at runtime*, and that
* it provides these services with the correct providers.
*
* We create a jar for each scenario, plus "service caller" jar with a demo class, then
* create an UberModuleClassLoader for the relevant jars.
*/
Expand All @@ -525,11 +592,18 @@ private static UberModuleClassLoader getServiceTestLoader(boolean includeOptiona
.configuration()
.resolve(parentModuleFinder, ModuleFinder.of(), moduleNames);

ModuleLayer parentLayer = ModuleLayer.defineModulesWithOneLoader(
parentLayerConfiguration,
List.of(ModuleLayer.boot()),
UberModuleClassLoaderTests.class.getClassLoader()
).layer();
URLClassLoader parentLoader = new URLClassLoader(new URL[] { pathToUrlUnchecked(parentJar) });
loaders.add(parentLoader);
URLClassLoader optionalLoader = new URLClassLoader(new URL[] { pathToUrlUnchecked(optionalJar) }, parentLoader);
loaders.add(optionalLoader);
ModuleLayer parentLayer = ModuleLayer.defineModules(parentLayerConfiguration, List.of(ModuleLayer.boot()), (String moduleName) -> {
if (moduleName.equals("p.required")) {
return parentLoader;
} else if (includeOptionalDeps && moduleName.equals("p.optional")) {
return optionalLoader;
}
return null;
}).layer();

// jars for the ubermodule
Path modularJar = createModularizedJarForBundle(libDir);
Expand All @@ -538,7 +612,7 @@ private static UberModuleClassLoader getServiceTestLoader(boolean includeOptiona

Set<Path> jarPaths = new HashSet<>(Set.of(modularJar, nonModularJar, serviceCallerJar));
return UberModuleClassLoader.getInstance(
parentLayer.findLoader("p.required"),
parentLayer.findLoader(includeOptionalDeps ? "p.optional" : "p.required"),
parentLayer,
"synthetic",
jarPaths.stream().map(UberModuleClassLoaderTests::pathToUrlUnchecked).collect(Collectors.toSet()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ private static class FileManagerWrapper extends ForwardingJavaFileManager<JavaFi
this.files = List.of(file);
}

public List<InMemoryJavaFileObject> getFiles() {
return this.files;
}

@Override
public JavaFileObject getJavaFileForOutput(Location location, String className, Kind kind, FileObject sibling) throws IOException {
return files.stream()
Expand All @@ -129,11 +133,15 @@ static Supplier<IOException> newIOException(String className, List<InMemoryJavaF
*/
public static Map<String, byte[]> compile(Map<String, CharSequence> sources, String... options) {
var files = sources.entrySet().stream().map(e -> new InMemoryJavaFileObject(e.getKey(), e.getValue())).toList();
CompilationTask task = getCompilationTask(files, options);

boolean result = PrivilegedOperations.compilationTaskCall(task);
if (result == false) {
throw new RuntimeException("Could not compile " + sources.entrySet().stream().toList());
try (FileManagerWrapper wrapper = new FileManagerWrapper(files)) {
CompilationTask task = getCompilationTask(wrapper, options);

boolean result = PrivilegedOperations.compilationTaskCall(task);
if (result == false) {
throw new RuntimeException("Could not compile " + sources.entrySet().stream().toList());
}
} catch (IOException e) {
throw new RuntimeException("Could not close file manager for " + sources.entrySet().stream().toList());
}

return files.stream().collect(Collectors.toMap(InMemoryJavaFileObject::getClassName, InMemoryJavaFileObject::getByteCode));
Expand All @@ -150,25 +158,24 @@ public static Map<String, byte[]> compile(Map<String, CharSequence> sources, Str
*/
public static byte[] compile(String className, CharSequence sourceCode, String... options) {
InMemoryJavaFileObject file = new InMemoryJavaFileObject(className, sourceCode);
CompilationTask task = getCompilationTask(file, options);

boolean result = PrivilegedOperations.compilationTaskCall(task);
if (result == false) {
throw new RuntimeException("Could not compile " + className + " with source code " + sourceCode);
try (FileManagerWrapper wrapper = new FileManagerWrapper(file)) {
CompilationTask task = getCompilationTask(wrapper, options);

boolean result = PrivilegedOperations.compilationTaskCall(task);
if (result == false) {
throw new RuntimeException("Could not compile " + className + " with source code " + sourceCode);
}
} catch (IOException e) {
throw new RuntimeException("Could not close file handler for class " + className + " with source code " + sourceCode);
}

return file.getByteCode();
}

private static JavaCompiler getCompiler() {
return ToolProvider.getSystemJavaCompiler();
}

private static CompilationTask getCompilationTask(List<InMemoryJavaFileObject> files, String... options) {
return getCompiler().getTask(null, new FileManagerWrapper(files), null, List.of(options), null, files);
}

private static CompilationTask getCompilationTask(InMemoryJavaFileObject file, String... options) {
return getCompiler().getTask(null, new FileManagerWrapper(file), null, List.of(options), null, List.of(file));
private static CompilationTask getCompilationTask(FileManagerWrapper wrapper, String... options) {
return getCompiler().getTask(null, wrapper, null, List.of(options), null, wrapper.getFiles());
}
}