diff --git a/distribution/src/config/jvm.options b/distribution/src/config/jvm.options index 54222d07634fc..e083f07edabc8 100644 --- a/distribution/src/config/jvm.options +++ b/distribution/src/config/jvm.options @@ -85,3 +85,4 @@ ${error.file} 23:-XX:CompileCommand=dontinline,java/lang/invoke/MethodHandle.asTypeUncached 21-:-javaagent:agent/opensearch-agent.jar +21-:--add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED diff --git a/gradle/run.gradle b/gradle/run.gradle index 34651f1d94964..ac58d74acd6b0 100644 --- a/gradle/run.gradle +++ b/gradle/run.gradle @@ -43,9 +43,17 @@ testClusters { installedPlugins = Eval.me(installedPlugins) for (String p : installedPlugins) { plugin('plugins:'.concat(p)) + if (p.equals("arrow-flight-rpc")) { + // Add system properties for Netty configuration + systemProperty 'io.netty.allocator.numDirectArenas', '1' + systemProperty 'io.netty.noUnsafe', 'false' + systemProperty 'io.netty.tryUnsafe', 'true' + systemProperty 'io.netty.tryReflectionSetAccessible', 'true' + } } } } + } tasks.register("run", RunTask) { diff --git a/libs/arrow-spi/build.gradle b/libs/arrow-spi/build.gradle deleted file mode 100644 index 90a4c162e428b..0000000000000 --- a/libs/arrow-spi/build.gradle +++ /dev/null @@ -1,20 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - * - * Modifications Copyright OpenSearch Contributors. See - * GitHub history for details. 
- */ - -testingConventions.enabled = false - -dependencies { - api project(':libs:opensearch-core') -} - -tasks.named('forbiddenApisMain').configure { - replaceSignatureFiles 'jdk-signatures' -} diff --git a/plugins/arrow-flight-rpc/build.gradle b/plugins/arrow-flight-rpc/build.gradle index f3a166bc39ae7..1d05464d0ee87 100644 --- a/plugins/arrow-flight-rpc/build.gradle +++ b/plugins/arrow-flight-rpc/build.gradle @@ -12,25 +12,34 @@ apply plugin: 'opensearch.internal-cluster-test' opensearchplugin { - description = 'Arrow flight based Stream implementation' + description = 'Arrow flight based transport and stream implementation. It also provides Arrow vector and memory dependencies as' + + 'an extended-plugin at runtime; consumers should take a compile time dependency and not runtime on this project.\'\n' classname = 'org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin' } dependencies { - implementation project(':libs:opensearch-arrow-spi') - compileOnly 'org.checkerframework:checker-qual:3.44.0' + // all transitive dependencies exported to use arrow-vector and arrow-memory-core + api "org.apache.arrow:arrow-memory-netty:${versions.arrow}" + api "org.apache.arrow:arrow-memory-core:${versions.arrow}" + api "org.apache.arrow:arrow-memory-netty-buffer-patch:${versions.arrow}" + api "io.netty:netty-buffer:${versions.netty}" + api "io.netty:netty-common:${versions.netty}" + api "org.apache.arrow:arrow-vector:${versions.arrow}" + api "org.apache.arrow:arrow-format:${versions.arrow}" - implementation "org.apache.arrow:arrow-vector:${versions.arrow}" - implementation "org.apache.arrow:arrow-format:${versions.arrow}" + compileOnly 'org.checkerframework:checker-qual:3.44.0' + api "com.google.flatbuffers:flatbuffers-java:${versions.flatbuffers}" + api "org.slf4j:slf4j-api:${versions.slf4j}" + api "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" + api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" + api 
"com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" + api "commons-codec:commons-codec:${versions.commonscodec}" + + // arrow flight dependencies. implementation "org.apache.arrow:flight-core:${versions.arrow}" - implementation "org.apache.arrow:arrow-memory-core:${versions.arrow}" - - runtimeOnly "org.apache.arrow:arrow-memory-netty:${versions.arrow}" - runtimeOnly "org.apache.arrow:arrow-memory-netty-buffer-patch:${versions.arrow}" - - implementation "io.netty:netty-buffer:${versions.netty}" - implementation "io.netty:netty-common:${versions.netty}" + // since netty-common will be added by opensearch-arrow-core at runtime, so declaring them as compileOnly + // compileOnly "io.netty:netty-common:${versions.netty}" implementation "io.netty:netty-codec:${versions.netty}" implementation "io.netty:netty-codec-http:${versions.netty}" implementation "io.netty:netty-codec-http2:${versions.netty}" @@ -41,28 +50,21 @@ dependencies { implementation "io.netty:netty-transport-classes-epoll:${versions.netty}" implementation "io.netty:netty-tcnative-classes:2.0.66.Final" - implementation "org.slf4j:slf4j-api:${versions.slf4j}" - runtimeOnly "com.google.flatbuffers:flatbuffers-java:${versions.flatbuffers}" - runtimeOnly "commons-codec:commons-codec:${versions.commonscodec}" - implementation "io.grpc:grpc-api:${versions.grpc}" runtimeOnly "io.grpc:grpc-core:${versions.grpc}" implementation "io.grpc:grpc-stub:${versions.grpc}" implementation "io.grpc:grpc-netty:${versions.grpc}" + implementation "com.google.errorprone:error_prone_annotations:2.31.0" runtimeOnly group: 'com.google.code.findbugs', name: 'jsr305', version: '3.0.2' - compileOnly 'org.immutables:value:2.10.1' annotationProcessor 'org.immutables:value:2.10.1' runtimeOnly 'io.perfmark:perfmark-api:0.27.0' runtimeOnly 'org.apache.parquet:parquet-arrow:1.13.1' runtimeOnly "io.grpc:grpc-protobuf-lite:${versions.grpc}" runtimeOnly "io.grpc:grpc-protobuf:${versions.grpc}" - implementation 
"com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" - implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" runtimeOnly "com.google.guava:failureaccess:1.0.1" - compileOnly "com.google.errorprone:error_prone_annotations:2.31.0" runtimeOnly('com.google.guava:guava:33.3.1-jre') { attributes { attribute(Attribute.of('org.gradle.jvm.environment', String), 'standard-jvm') @@ -88,6 +90,7 @@ internalClusterTest { systemProperty 'io.netty.noUnsafe', 'false' systemProperty 'io.netty.tryUnsafe', 'true' systemProperty 'io.netty.tryReflectionSetAccessible', 'true' + jvmArgs += ["--add-opens", "java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED"] } spotless { @@ -120,6 +123,19 @@ tasks.named('forbiddenApisMain').configure { tasks.named('thirdPartyAudit').configure { ignoreMissingClasses( + 'org.apache.commons.logging.Log', + 'org.apache.commons.logging.LogFactory', + + 'org.slf4j.impl.StaticLoggerBinder', + 'org.slf4j.impl.StaticMDCBinder', + 'org.slf4j.impl.StaticMarkerBinder', + + // from Log4j (deliberate, Netty will fallback to Log4j 2) + 'org.apache.log4j.Level', + 'org.apache.log4j.Logger', + + 'reactor.blockhound.BlockHound$Builder', + 'reactor.blockhound.integration.BlockHoundIntegration', 'com.google.gson.stream.JsonReader', 'com.google.gson.stream.JsonToken', 'org.apache.parquet.schema.GroupType', @@ -158,18 +174,6 @@ tasks.named('thirdPartyAudit').configure { 'com.aayushatharva.brotli4j.encoder.Encoder$Parameters', // classes are missing - // from io.netty.logging.CommonsLoggerFactory (netty) - 'org.apache.commons.logging.Log', - 'org.apache.commons.logging.LogFactory', - - 'org.slf4j.impl.StaticLoggerBinder', - 'org.slf4j.impl.StaticMDCBinder', - 'org.slf4j.impl.StaticMarkerBinder', - - // from Log4j (deliberate, Netty will fallback to Log4j 2) - 'org.apache.log4j.Level', - 'org.apache.log4j.Logger', - // from io.netty.handler.ssl.util.BouncyCastleSelfSignedCertGenerator (netty) 
'org.bouncycastle.cert.X509v3CertificateBuilder', 'org.bouncycastle.cert.jcajce.JcaX509CertificateConverter', @@ -224,9 +228,6 @@ tasks.named('thirdPartyAudit').configure { 'org.conscrypt.Conscrypt', 'org.conscrypt.HandshakeListener', - 'reactor.blockhound.BlockHound$Builder', - 'reactor.blockhound.integration.BlockHoundIntegration', - 'com.google.protobuf.util.Timestamps' ) ignoreViolations( @@ -288,7 +289,7 @@ tasks.named('thirdPartyAudit').configure { 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueConsumerIndexField', 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerIndexField', 'io.netty.util.internal.shaded.org.jctools.queues.MpscArrayQueueProducerLimitField', - 'io.netty.util.internal.shaded.org.jctools.queues.unpadded.MpscUnpaddedArrayQueueConsumerIndexField', + 'io.netty.util.internal.shaded.org.jctools.queues.unpadded.MpscUnpaddedArrayQueueConsumerIndexField', 'io.netty.util.internal.shaded.org.jctools.queues.unpadded.MpscUnpaddedArrayQueueProducerIndexField', 'io.netty.util.internal.shaded.org.jctools.queues.unpadded.MpscUnpaddedArrayQueueProducerLimitField', 'io.netty.util.internal.shaded.org.jctools.util.UnsafeAccess', @@ -296,6 +297,5 @@ tasks.named('thirdPartyAudit').configure { 'io.netty.util.internal.shaded.org.jctools.util.UnsafeRefArrayAccess', 'org.apache.arrow.memory.util.MemoryUtil', 'org.apache.arrow.memory.util.MemoryUtil$1' - ) } diff --git a/plugins/arrow-flight-rpc/licenses/error_prone_annotations-2.31.0.jar.sha1 b/plugins/arrow-flight-rpc/licenses/error_prone_annotations-2.31.0.jar.sha1 new file mode 100644 index 0000000000000..4872d644799f5 --- /dev/null +++ b/plugins/arrow-flight-rpc/licenses/error_prone_annotations-2.31.0.jar.sha1 @@ -0,0 +1 @@ +c3ba307b915d6d506e98ffbb49e6d2d12edad65b \ No newline at end of file diff --git a/plugins/arrow-flight-rpc/licenses/error_prone_annotations-LICENSE.txt b/plugins/arrow-flight-rpc/licenses/error_prone_annotations-LICENSE.txt new file mode 100644 index 
0000000000000..d645695673349 --- /dev/null +++ b/plugins/arrow-flight-rpc/licenses/error_prone_annotations-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). 
+ + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. 
Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative 
Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/plugins/arrow-flight-rpc/licenses/error_prone_annotations-NOTICE.txt b/plugins/arrow-flight-rpc/licenses/error_prone_annotations-NOTICE.txt new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/plugins/arrow-flight-rpc/licenses/netty-LICENSE.txt b/plugins/arrow-flight-rpc/licenses/netty-LICENSE.txt index 62589edd12a37..d645695673349 100644 --- a/plugins/arrow-flight-rpc/licenses/netty-LICENSE.txt +++ b/plugins/arrow-flight-rpc/licenses/netty-LICENSE.txt @@ -1,7 +1,7 @@ Apache License Version 2.0, January 2004 - https://www.apache.org/licenses/ + http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION @@ -193,7 +193,7 @@ you may not use this file except in compliance with the License. 
You may obtain a copy of the License at - https://www.apache.org/licenses/LICENSE-2.0 + http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, diff --git a/plugins/arrow-flight-rpc/licenses/netty-NOTICE.txt b/plugins/arrow-flight-rpc/licenses/netty-NOTICE.txt index 971865b7c1c23..5bbf91a14de23 100644 --- a/plugins/arrow-flight-rpc/licenses/netty-NOTICE.txt +++ b/plugins/arrow-flight-rpc/licenses/netty-NOTICE.txt @@ -4,15 +4,15 @@ Please visit the Netty web site for more information: - * https://netty.io/ + * http://netty.io/ -Copyright 2014 The Netty Project +Copyright 2011 The Netty Project The Netty Project licenses this file to you under the Apache License, version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at: - https://www.apache.org/licenses/LICENSE-2.0 +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT @@ -42,112 +42,29 @@ Base64 Encoder and Decoder, which can be obtained at: * HOMEPAGE: * http://iharder.sourceforge.net/current/java/base64/ -This product contains a modified portion of 'Webbit', an event based -WebSocket and HTTP server, which can be obtained at: +This product contains a modified version of 'JZlib', a re-implementation of +zlib in pure Java, which can be obtained at: * LICENSE: - * license/LICENSE.webbit.txt (BSD License) - * HOMEPAGE: - * https://github.com/joewalnes/webbit - -This product contains a modified portion of 'SLF4J', a simple logging -facade for Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.slf4j.txt (MIT License) - * HOMEPAGE: - * https://www.slf4j.org/ - -This product contains a modified portion of 'Apache Harmony', an open source -Java SE, which can be obtained at: - 
- * NOTICE: - * license/NOTICE.harmony.txt - * LICENSE: - * license/LICENSE.harmony.txt (Apache License 2.0) - * HOMEPAGE: - * https://archive.apache.org/dist/harmony/ - -This product contains a modified portion of 'jbzip2', a Java bzip2 compression -and decompression library written by Matthew J. Francis. It can be obtained at: - - * LICENSE: - * license/LICENSE.jbzip2.txt (MIT License) - * HOMEPAGE: - * https://code.google.com/p/jbzip2/ - -This product contains a modified portion of 'libdivsufsort', a C API library to construct -the suffix array and the Burrows-Wheeler transformed string for any input string of -a constant-size alphabet written by Yuta Mori. It can be obtained at: - - * LICENSE: - * license/LICENSE.libdivsufsort.txt (MIT License) - * HOMEPAGE: - * https://github.com/y-256/libdivsufsort - -This product contains a modified portion of Nitsan Wakart's 'JCTools', Java Concurrency Tools for the JVM, - which can be obtained at: - - * LICENSE: - * license/LICENSE.jctools.txt (ASL2 License) - * HOMEPAGE: - * https://github.com/JCTools/JCTools - -This product optionally depends on 'JZlib', a re-implementation of zlib in -pure Java, which can be obtained at: - - * LICENSE: - * license/LICENSE.jzlib.txt (BSD style License) + * license/LICENSE.jzlib.txt (BSD Style License) * HOMEPAGE: * http://www.jcraft.com/jzlib/ -This product optionally depends on 'Compress-LZF', a Java library for encoding and -decoding data in LZF format, written by Tatu Saloranta. It can be obtained at: - - * LICENSE: - * license/LICENSE.compress-lzf.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/ning/compress - -This product optionally depends on 'lz4', a LZ4 Java compression -and decompression library written by Adrien Grand. 
It can be obtained at: - - * LICENSE: - * license/LICENSE.lz4.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/jpountz/lz4-java - -This product optionally depends on 'lzma-java', a LZMA Java compression -and decompression library, which can be obtained at: - - * LICENSE: - * license/LICENSE.lzma-java.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/jponge/lzma-java - -This product optionally depends on 'zstd-jni', a zstd-jni Java compression -and decompression library, which can be obtained at: +This product contains a modified version of 'Webbit', a Java event based +WebSocket and HTTP server: * LICENSE: - * license/LICENSE.zstd-jni.txt (BSD) - * HOMEPAGE: - * https://github.com/luben/zstd-jni - -This product contains a modified portion of 'jfastlz', a Java port of FastLZ compression -and decompression library written by William Kinney. It can be obtained at: - - * LICENSE: - * license/LICENSE.jfastlz.txt (MIT License) + * license/LICENSE.webbit.txt (BSD License) * HOMEPAGE: - * https://code.google.com/p/jfastlz/ + * https://github.com/joewalnes/webbit -This product contains a modified portion of and optionally depends on 'Protocol Buffers', Google's data +This product optionally depends on 'Protocol Buffers', Google's data interchange format, which can be obtained at: * LICENSE: * license/LICENSE.protobuf.txt (New BSD License) * HOMEPAGE: - * https://github.com/google/protobuf + * http://code.google.com/p/protobuf/ This product optionally depends on 'Bouncy Castle Crypto APIs' to generate a temporary self-signed X.509 certificate when the JVM does not provide the @@ -156,31 +73,15 @@ equivalent functionality. 
It can be obtained at: * LICENSE: * license/LICENSE.bouncycastle.txt (MIT License) * HOMEPAGE: - * https://www.bouncycastle.org/ - -This product optionally depends on 'Snappy', a compression library produced -by Google Inc, which can be obtained at: - - * LICENSE: - * license/LICENSE.snappy.txt (New BSD License) - * HOMEPAGE: - * https://github.com/google/snappy + * http://www.bouncycastle.org/ -This product optionally depends on 'JBoss Marshalling', an alternative Java -serialization API, which can be obtained at: +This product optionally depends on 'SLF4J', a simple logging facade for Java, +which can be obtained at: * LICENSE: - * license/LICENSE.jboss-marshalling.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/jboss-remoting/jboss-marshalling - -This product optionally depends on 'Caliper', Google's micro- -benchmarking framework, which can be obtained at: - - * LICENSE: - * license/LICENSE.caliper.txt (Apache License 2.0) + * license/LICENSE.slf4j.txt (MIT License) * HOMEPAGE: - * https://github.com/google/caliper + * http://www.slf4j.org/ This product optionally depends on 'Apache Commons Logging', a logging framework, which can be obtained at: @@ -188,77 +89,28 @@ framework, which can be obtained at: * LICENSE: * license/LICENSE.commons-logging.txt (Apache License 2.0) * HOMEPAGE: - * https://commons.apache.org/logging/ + * http://commons.apache.org/logging/ -This product optionally depends on 'Apache Log4J', a logging framework, which -can be obtained at: +This product optionally depends on 'Apache Log4J', a logging framework, +which can be obtained at: * LICENSE: * license/LICENSE.log4j.txt (Apache License 2.0) * HOMEPAGE: - * https://logging.apache.org/log4j/ - -This product optionally depends on 'Aalto XML', an ultra-high performance -non-blocking XML processor, which can be obtained at: - - * LICENSE: - * license/LICENSE.aalto-xml.txt (Apache License 2.0) - * HOMEPAGE: - * https://wiki.fasterxml.com/AaltoHome + * 
http://logging.apache.org/log4j/ -This product contains a modified version of 'HPACK', a Java implementation of -the HTTP/2 HPACK algorithm written by Twitter. It can be obtained at: +This product optionally depends on 'JBoss Logging', a logging framework, +which can be obtained at: * LICENSE: - * license/LICENSE.hpack.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/twitter/hpack - -This product contains a modified version of 'HPACK', a Java implementation of -the HTTP/2 HPACK algorithm written by Cory Benfield. It can be obtained at: - - * LICENSE: - * license/LICENSE.hyper-hpack.txt (MIT License) - * HOMEPAGE: - * https://github.com/python-hyper/hpack/ - -This product contains a modified version of 'HPACK', a Java implementation of -the HTTP/2 HPACK algorithm written by Tatsuhiro Tsujikawa. It can be obtained at: - - * LICENSE: - * license/LICENSE.nghttp2-hpack.txt (MIT License) - * HOMEPAGE: - * https://github.com/nghttp2/nghttp2/ - -This product contains a modified portion of 'Apache Commons Lang', a Java library -provides utilities for the java.lang API, which can be obtained at: - - * LICENSE: - * license/LICENSE.commons-lang.txt (Apache License 2.0) - * HOMEPAGE: - * https://commons.apache.org/proper/commons-lang/ - - -This product contains the Maven wrapper scripts from 'Maven Wrapper', that provides an easy way to ensure a user has everything necessary to run the Maven build. - - * LICENSE: - * license/LICENSE.mvn-wrapper.txt (Apache License 2.0) - * HOMEPAGE: - * https://github.com/takari/maven-wrapper - -This product contains the dnsinfo.h header file, that provides a way to retrieve the system DNS configuration on MacOS. -This private header is also used by Apple's open source - mDNSResponder (https://opensource.apple.com/tarballs/mDNSResponder/). 
- - * LICENSE: - * license/LICENSE.dnsinfo.txt (Apple Public Source License 2.0) + * license/LICENSE.jboss-logging.txt (GNU LGPL 2.1) * HOMEPAGE: - * https://www.opensource.apple.com/source/configd/configd-453.19/dnsinfo/dnsinfo.h + * http://anonsvn.jboss.org/repos/common/common-logging-spi/ -This product optionally depends on 'Brotli4j', Brotli compression and -decompression for Java., which can be obtained at: +This product optionally depends on 'Apache Felix', an open source OSGi +framework implementation, which can be obtained at: * LICENSE: - * license/LICENSE.brotli4j.txt (Apache License 2.0) + * license/LICENSE.felix.txt (Apache License 2.0) * HOMEPAGE: - * https://github.com/hyperxpro/Brotli4j + * http://felix.apache.org/ diff --git a/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java b/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java index 54b47329dab7f..46f72bea3e4c7 100644 --- a/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java +++ b/plugins/arrow-flight-rpc/src/internalClusterTest/java/org/opensearch/arrow/flight/ArrowFlightServerIT.java @@ -10,24 +10,37 @@ import org.apache.arrow.flight.CallOptions; import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.VectorSchemaRoot; import org.opensearch.arrow.flight.bootstrap.FlightClientManager; import org.opensearch.arrow.flight.bootstrap.FlightService; import org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin; +import org.opensearch.arrow.spi.StreamManager; +import org.opensearch.arrow.spi.StreamProducer; +import 
org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.unit.TimeValue; import org.opensearch.plugins.Plugin; import org.opensearch.test.OpenSearchIntegTestCase; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS; @OpenSearchIntegTestCase.ClusterScope(scope = OpenSearchIntegTestCase.Scope.SUITE, numDataNodes = 5) public class ArrowFlightServerIT extends OpenSearchIntegTestCase { - private FlightClientManager flightClientManager; - @Override protected Collection> nodePlugins() { return Collections.singleton(FlightStreamPlugin.class); @@ -37,18 +50,269 @@ protected Collection> nodePlugins() { public void setUp() throws Exception { super.setUp(); ensureGreen(); - Thread.sleep(1000); - FlightService flightService = internalCluster().getInstance(FlightService.class); - flightClientManager = flightService.getFlightClientManager(); + for (DiscoveryNode node : getClusterState().nodes()) { + FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName()); + FlightClientManager flightClientManager = flightService.getFlightClientManager(); + assertBusy(() -> { + assertTrue( + "Flight client should be created successfully before running tests", + flightClientManager.getFlightClient(node.getId()).isPresent() + ); + }, 3, TimeUnit.SECONDS); + } } @LockFeatureFlag(ARROW_STREAMS) public void testArrowFlightEndpoint() throws Exception { for (DiscoveryNode node : getClusterState().nodes()) { - try (FlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get()) { - assertNotNull(flightClient); - flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + FlightService 
flightService = internalCluster().getInstance(FlightService.class, node.getName()); + FlightClientManager flightClientManager = flightService.getFlightClientManager(); + FlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get(); + assertNotNull(flightClient); + flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + flightClient.handshake(CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + } + } + + @LockFeatureFlag(ARROW_STREAMS) + public void testFlightStreamReader() throws Exception { + for (DiscoveryNode node : getClusterState().nodes()) { + StreamManager streamManagerRandomNode = getStreamManagerRandomNode(); + StreamTicket ticket = streamManagerRandomNode.registerStream(getStreamProducer(), null); + StreamManager streamManagerCurrentNode = getStreamManager(node.getName()); + // reader should be accessible from any node in the cluster due to the use ProxyStreamProducer + try (StreamReader reader = streamManagerCurrentNode.getStreamReader(ticket)) { + int totalBatches = 0; + assertNotNull(reader.getRoot().getVector("docID")); + while (reader.next()) { + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertEquals(10, docIDVector.getValueCount()); + for (int i = 0; i < 10; i++) { + assertEquals(docIDVector.toString(), i + (totalBatches * 10L), docIDVector.get(i)); + } + totalBatches++; + } + assertEquals(10, totalBatches); + } + } + } + + @LockFeatureFlag(ARROW_STREAMS) + public void testEarlyCancel() throws Exception { + DiscoveryNode previousNode = null; + for (DiscoveryNode node : getClusterState().nodes()) { + if (previousNode == null) { + previousNode = node; + continue; + } + StreamManager streamManagerServer = getStreamManager(node.getName()); + TestStreamProducer streamProducer = getStreamProducer(); + StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null); + StreamManager streamManagerClient = getStreamManager(previousNode.getName()); + + CountDownLatch 
readerComplete = new CountDownLatch(1); + AtomicReference readerException = new AtomicReference<>(); + AtomicReference> readerRef = new AtomicReference<>(); + + // Start reader thread + Thread readerThread = new Thread(() -> { + try (StreamReader reader = streamManagerClient.getStreamReader(ticket)) { + readerRef.set(reader); + assertNotNull(reader.getRoot()); + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertNotNull(docIDVector); + + // Read first batch + reader.next(); + assertEquals(10, docIDVector.getValueCount()); + for (int i = 0; i < 10; i++) { + assertEquals(docIDVector.toString(), i, docIDVector.get(i)); + } + } catch (Exception e) { + readerException.set(e); + } finally { + readerComplete.countDown(); + } + }, "flight-reader-thread"); + + readerThread.start(); + assertTrue("Reader thread did not complete in time", readerComplete.await(5, TimeUnit.SECONDS)); + + if (readerException.get() != null) { + throw readerException.get(); + } + + StreamReader reader = readerRef.get(); + + try { + reader.next(); + fail("Expected FlightRuntimeException"); + } catch (FlightRuntimeException e) { + assertEquals("CANCELLED", e.status().code().name()); + assertEquals("Stream closed before end", e.getMessage()); + reader.close(); + } + + // Wait for close to complete + // Due to https://github.com/grpc/grpc-java/issues/5882, there is a logic in FlightStream.java + // where it exhausts the stream on the server side before it is actually cancelled. 
+ assertTrue( + "Timeout waiting for stream cancellation on server [" + node.getName() + "]", + streamProducer.waitForClose(2, TimeUnit.SECONDS) + ); + previousNode = node; + } + } + + @LockFeatureFlag(ARROW_STREAMS) + public void testFlightStreamServerError() throws Exception { + DiscoveryNode previousNode = null; + for (DiscoveryNode node : getClusterState().nodes()) { + if (previousNode == null) { + previousNode = node; + continue; + } + StreamManager streamManagerServer = getStreamManager(node.getName()); + TestStreamProducer streamProducer = getStreamProducer(); + streamProducer.setProduceError(true); + StreamTicket ticket = streamManagerServer.registerStream(streamProducer, null); + StreamManager streamManagerClient = getStreamManager(previousNode.getName()); + try (StreamReader reader = streamManagerClient.getStreamReader(ticket)) { + int totalBatches = 0; + assertNotNull(reader.getRoot().getVector("docID")); + try { + while (reader.next()) { + IntVector docIDVector = (IntVector) reader.getRoot().getVector("docID"); + assertEquals(10, docIDVector.getValueCount()); + totalBatches++; + } + fail("Expected FlightRuntimeException"); + } catch (FlightRuntimeException e) { + assertEquals("INTERNAL", e.status().code().name()); + assertEquals("Unexpected server error", e.getMessage()); + } + assertEquals(1, totalBatches); + } + previousNode = node; + } + } + + @LockFeatureFlag(ARROW_STREAMS) + public void testFlightGetInfo() throws Exception { + StreamTicket ticket = null; + for (DiscoveryNode node : getClusterState().nodes()) { + FlightService flightService = internalCluster().getInstance(FlightService.class, node.getName()); + StreamManager streamManager = flightService.getStreamManager(); + if (ticket == null) { + ticket = streamManager.registerStream(getStreamProducer(), null); + } + FlightClientManager flightClientManager = flightService.getFlightClientManager(); + FlightClient flightClient = flightClientManager.getFlightClient(node.getId()).get(); + 
assertNotNull(flightClient); + FlightDescriptor flightDescriptor = FlightDescriptor.command(ticket.toBytes()); + FlightInfo flightInfo = flightClient.getInfo(flightDescriptor, CallOptions.timeout(5000L, TimeUnit.MILLISECONDS)); + assertNotNull(flightInfo); + assertEquals(100, flightInfo.getRecords()); + } + } + + private StreamManager getStreamManager(String nodeName) { + FlightService flightService = internalCluster().getInstance(FlightService.class, nodeName); + return flightService.getStreamManager(); + } + + private StreamManager getStreamManagerRandomNode() { + FlightService flightService = internalCluster().getInstance(FlightService.class); + return flightService.getStreamManager(); + } + + private TestStreamProducer getStreamProducer() { + return new TestStreamProducer(); + } + + private static class TestStreamProducer implements StreamProducer { + volatile boolean isClosed = false; + private final CountDownLatch closeLatch = new CountDownLatch(1); + TimeValue deadline = TimeValue.timeValueSeconds(5); + private boolean produceError = false; + + public void setProduceError(boolean produceError) { + this.produceError = produceError; + } + + TestStreamProducer() {} + + VectorSchemaRoot root; + + @Override + public VectorSchemaRoot createRoot(BufferAllocator allocator) { + IntVector docIDVector = new IntVector("docID", allocator); + FieldVector[] vectors = new FieldVector[] { docIDVector }; + root = new VectorSchemaRoot(Arrays.asList(vectors)); + return root; + } + + @Override + public BatchedJob createJob(BufferAllocator allocator) { + return new BatchedJob<>() { + @Override + public void run(VectorSchemaRoot root, FlushSignal flushSignal) { + IntVector docIDVector = (IntVector) root.getVector("docID"); + root.setRowCount(10); + for (int i = 0; i < 100; i++) { + docIDVector.setSafe(i % 10, i); + if ((i + 1) % 10 == 0) { + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000)); + docIDVector.clear(); + root.setRowCount(10); + if (produceError) { + throw 
new RuntimeException("Server error while producing batch"); + } + } + } + } + + @Override + public void onCancel() { + if (!isClosed && root != null) { + root.close(); + } + isClosed = true; + } + + @Override + public boolean isCancelled() { + return isClosed; + } + }; + } + + @Override + public TimeValue getJobDeadline() { + return deadline; + } + + @Override + public int estimatedRowCount() { + return 100; + } + + @Override + public String getAction() { + return ""; + } + + @Override + public void close() { + if (!isClosed && root != null) { + root.close(); } + closeLatch.countDown(); + isClosed = true; + } + + public boolean waitForClose(long timeout, TimeUnit unit) throws InterruptedException { + return closeLatch.await(timeout, unit); } } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java similarity index 97% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java index 529bee72c708d..c988090081266 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/FlightServerInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoAction.java @@ -5,7 +5,7 @@ * this file be licensed under the Apache-2.0 license or a * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestRequest; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java index e804b0c518523..23163bfac8c2e 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodeFlightInfo.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfo.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java similarity index 93% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java index 3148c58a1509d..3c3a9965459cb 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoAction.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.ActionType; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java similarity index 97% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java index 1b707f461819c..43bf38a096b57 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequest.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequest.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.core.common.io.stream.StreamInput; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java index 721cd631924bd..805aa188ce37a 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponse.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponse.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java similarity index 98% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java index d4722e20d1f84..51f4cc05b8001 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoAction.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoAction.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java similarity index 83% rename from plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java rename to plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java index d89ec87f9a51e..19dde32f32e8f 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/package-info.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/api/flightinfo/package-info.java @@ -9,4 +9,4 @@ /** * Action to retrieve flight info from nodes */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java index a81033f580a03..c81f4d3c270e7 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightClientManager.java @@ -15,10 +15,10 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.Version; -import org.opensearch.arrow.flight.api.NodeFlightInfo; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoRequest; -import org.opensearch.arrow.flight.api.NodesFlightInfoResponse; +import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo; +import 
org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterStateListener; @@ -59,8 +59,9 @@ public class FlightClientManager implements ClusterStateListener, AutoCloseable static final int LOCATION_TIMEOUT_MS = 1000; private final ExecutorService grpcExecutor; private final ClientConfiguration clientConfig; - private final Map flightClients = new ConcurrentHashMap<>(); + private final Map flightClients = new ConcurrentHashMap<>(); private final Client client; + private volatile boolean closed = false; /** * Creates a new FlightClientManager instance. @@ -99,7 +100,19 @@ public FlightClientManager( * @return An OpenSearchFlightClient instance for the specified node */ public Optional getFlightClient(String nodeId) { - return Optional.ofNullable(flightClients.get(nodeId)); + ClientHolder clientHolder = flightClients.get(nodeId); + return clientHolder == null ? Optional.empty() : Optional.of(clientHolder.flightClient); + } + + /** + * Returns the location of a Flight client for a given node ID. + * + * @param nodeId The ID of the node for which to retrieve the location + * @return The Location of the Flight client for the specified node + */ + public Optional getFlightClientLocation(String nodeId) { + ClientHolder clientHolder = flightClients.get(nodeId); + return clientHolder == null ? 
Optional.empty() : Optional.of(clientHolder.location); } /** @@ -128,13 +141,15 @@ private void buildClientAndAddToPool(Location location, DiscoveryNode node) { ); return; } - flightClients.computeIfAbsent(node.getId(), key -> buildClient(location)); + if (closed) { + return; + } + flightClients.computeIfAbsent(node.getId(), nodeId -> new ClientHolder(location, buildClient(location))); } private void requestNodeLocation(String nodeId, CompletableFuture future) { NodesFlightInfoRequest request = new NodesFlightInfoRequest(nodeId); try { - client.execute(NodesFlightInfoAction.INSTANCE, request, new ActionListener<>() { @Override public void onResponse(NodesFlightInfoResponse response) { @@ -184,13 +199,21 @@ private DiscoveryNode getNodeFromClusterState(String nodeId) { */ @Override public void close() throws Exception { - for (FlightClient flightClient : flightClients.values()) { - flightClient.close(); + if (closed) { + return; + } + closed = true; + for (ClientHolder clientHolder : flightClients.values()) { + clientHolder.flightClient.close(); } flightClients.clear(); grpcExecutor.shutdown(); - grpcExecutor.awaitTermination(5, TimeUnit.SECONDS); - clientConfig.clusterService.removeListener(this); + if (grpcExecutor.awaitTermination(5, TimeUnit.SECONDS) == false) { + grpcExecutor.shutdownNow(); + } + } + + private record ClientHolder(Location location, FlightClient flightClient) { } /** @@ -229,7 +252,7 @@ private Set getCurrentClusterNodes() { } @VisibleForTesting - Map getFlightClients() { + Map getFlightClients() { return flightClients; } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java index 7735fc3df73e0..890163d81145e 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java +++ 
b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightService.java @@ -8,7 +8,6 @@ package org.opensearch.arrow.flight.bootstrap; -import org.apache.arrow.flight.NoOpFlightProducer; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.apache.arrow.util.AutoCloseables; @@ -17,6 +16,8 @@ import org.apache.logging.log4j.Logger; import org.opensearch.arrow.flight.bootstrap.tls.DefaultSslContextProvider; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; +import org.opensearch.arrow.flight.impl.BaseFlightProducer; +import org.opensearch.arrow.flight.impl.FlightStreamManager; import org.opensearch.arrow.spi.StreamManager; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.network.NetworkService; @@ -39,7 +40,7 @@ public class FlightService extends NetworkPlugin.AuxTransport { private static final Logger logger = LogManager.getLogger(FlightService.class); private final ServerComponents serverComponents; - private StreamManager streamManager; + private FlightStreamManager streamManager; private Client client; private FlightClientManager clientManager; private SecureTransportSettingsProvider secureTransportSettingsProvider; @@ -58,6 +59,7 @@ public FlightService(Settings settings) { throw new RuntimeException("Failed to initialize Arrow Flight server", e); } this.serverComponents = new ServerComponents(settings); + this.streamManager = new FlightStreamManager(); } void setClusterService(ClusterService clusterService) { @@ -104,7 +106,7 @@ protected void doStart() { client ); initializeStreamManager(clientManager); - serverComponents.setFlightProducer(new NoOpFlightProducer()); + serverComponents.setFlightProducer(new BaseFlightProducer(clientManager, streamManager, allocator)); serverComponents.start(); } catch (Exception e) { @@ -165,6 +167,7 @@ protected void doClose() { } private void initializeStreamManager(FlightClientManager 
clientManager) { - streamManager = null; + streamManager.setAllocatorSupplier(() -> allocator); + streamManager.setClientManager(clientManager); } } diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java index bb7edf491cf02..e2e7ef289eaf6 100644 --- a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/bootstrap/FlightStreamPlugin.java @@ -8,9 +8,9 @@ package org.opensearch.arrow.flight.bootstrap; -import org.opensearch.arrow.flight.api.FlightServerInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.TransportNodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.TransportNodesFlightInfoAction; import org.opensearch.arrow.spi.StreamManager; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNode; @@ -31,6 +31,7 @@ import org.opensearch.env.NodeEnvironment; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.ClusterPlugin; +import org.opensearch.plugins.ExtensiblePlugin; import org.opensearch.plugins.NetworkPlugin; import org.opensearch.plugins.Plugin; import org.opensearch.plugins.SecureTransportSettingsProvider; @@ -52,12 +53,19 @@ import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Supplier; /** * FlightStreamPlugin class extends BaseFlightStreamPlugin and provides implementation for FlightStream plugin. 
*/ -public class FlightStreamPlugin extends Plugin implements StreamManagerPlugin, NetworkPlugin, ActionPlugin, ClusterPlugin { +public class FlightStreamPlugin extends Plugin + implements + StreamManagerPlugin, + NetworkPlugin, + ActionPlugin, + ClusterPlugin, + ExtensiblePlugin { private final FlightService flightService; private final boolean isArrowStreamsEnabled; @@ -221,11 +229,8 @@ public void onNodeStarted(DiscoveryNode localNode) { * Gets the StreamManager instance for managing flight streams. */ @Override - public Supplier getStreamManager() { - if (!isArrowStreamsEnabled) { - return null; - } - return flightService::getStreamManager; + public Optional getStreamManager() { + return isArrowStreamsEnabled ? Optional.ofNullable(flightService.getStreamManager()) : Optional.empty(); } /** diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java new file mode 100644 index 0000000000000..08f20d5448511 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/BaseFlightProducer.java @@ -0,0 +1,259 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.BackpressureStrategy; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamTicket; + +import java.io.IOException; +import java.util.Collections; +import java.util.Optional; + +/** + * BaseFlightProducer extends NoOpFlightProducer to provide stream management functionality + * for Arrow Flight in OpenSearch. This class handles data streaming based on tickets, + * manages backpressure, and coordinates between stream providers and server stream listeners. + * It runs on the gRPC transport thread. + *

+ * Error handling strategy: + * 1. Add all errors to listener. + * 2. All FlightRuntimeException which are not INTERNAL should not be logged. + * 3. All FlightRuntimeException which are INTERNAL should be logged with error or warn (depending on severity). + */ +public class BaseFlightProducer extends NoOpFlightProducer { + private static final Logger logger = LogManager.getLogger(BaseFlightProducer.class); + private final FlightClientManager flightClientManager; + private final FlightStreamManager streamManager; + private final BufferAllocator allocator; + + /** + * Constructs a new BaseFlightProducer. + * + * @param flightClientManager The manager for handling client connections + * @param streamManager The manager for stream operations + * @param allocator The buffer allocator for Arrow memory management + */ + public BaseFlightProducer(FlightClientManager flightClientManager, FlightStreamManager streamManager, BufferAllocator allocator) { + this.flightClientManager = flightClientManager; + this.streamManager = streamManager; + this.allocator = allocator; + } + + /** + * Handles data streaming for a given Arrow Flight Ticket. This method runs on the gRPC transport thread + * and manages the entire streaming process, including backpressure and error handling. 
+ * @param context The call context (unused in this implementation) + * @param ticket The Arrow Flight Ticket containing stream information + * @param listener The server stream listener for data flow + */ + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) { + try { + StreamTicket streamTicket = parseTicket(ticket); + FlightStreamManager.StreamProducerHolder producerHolder = acquireStreamProducer(streamTicket, ticket).orElseThrow(() -> { + FlightRuntimeException ex = CallStatus.NOT_FOUND.withDescription("Stream not found").toRuntimeException(); + listener.error(ex); + return ex; + }); + processStreamWithProducer(context, producerHolder, listener); + } catch (FlightRuntimeException ex) { + listener.error(ex); + throw ex; + } catch (Exception ex) { + logger.error("Unexpected error during stream processing", ex); + FlightRuntimeException fre = CallStatus.INTERNAL.withCause(ex).withDescription("Unexpected server error").toRuntimeException(); + listener.error(fre); + throw fre; + } + } + + /** + * Retrieves FlightInfo for a given descriptor, handling both local and remote cases. + * The descriptor's command is expected to contain a serialized StreamTicket. + * + * @param context The call context + * @param descriptor The flight descriptor containing stream information + * @return FlightInfo for the requested stream + * @throws RuntimeException if the requested info cannot be retrieved + */ + @Override + public FlightInfo getFlightInfo(CallContext context, FlightDescriptor descriptor) { + StreamTicket streamTicket = parseDescriptor(descriptor); + return streamTicket.getNodeId().equals(flightClientManager.getLocalNodeId()) + ? 
getLocalFlightInfo(streamTicket, descriptor) + : getRemoteFlightInfo(streamTicket, descriptor); + } + + private StreamTicket parseTicket(Ticket ticket) { + try { + return streamManager.getStreamTicketFactory().fromBytes(ticket.getBytes()); + } catch (Exception e) { + logger.debug("Failed to parse Arrow Flight Ticket", e); + throw CallStatus.INVALID_ARGUMENT.withCause(e).withDescription("Invalid ticket format: " + e.getMessage()).toRuntimeException(); + } + } + + private StreamTicket parseDescriptor(FlightDescriptor descriptor) { + try { + return streamManager.getStreamTicketFactory().fromBytes(descriptor.getCommand()); + } catch (Exception e) { + logger.debug("Failed to parse flight descriptor command", e); + throw CallStatus.INVALID_ARGUMENT.withCause(e) + .withDescription("Invalid descriptor format: " + e.getMessage()) + .toRuntimeException(); + } + } + + private Optional acquireStreamProducer(StreamTicket streamTicket, Ticket ticket) { + if (streamTicket.getNodeId().equals(flightClientManager.getLocalNodeId())) { + return streamManager.removeStreamProducer(streamTicket); + } + return flightClientManager.getFlightClient(streamTicket.getNodeId()) + .map(client -> createProxyProducer(client, ticket)) + .filter(Optional::isPresent) + .orElse(Optional.empty()); + } + + private Optional createProxyProducer(FlightClient remoteClient, Ticket ticket) { + try (FlightStream flightStream = remoteClient.getStream(ticket)) { + return Optional.ofNullable(flightStream) + .map(fs -> new ProxyStreamProducer(new FlightStreamReader(fs))) + .map(proxy -> FlightStreamManager.StreamProducerHolder.create(proxy, allocator)) + .or(() -> { + logger.warn("Remote client returned null flight stream for ticket"); + return Optional.empty(); + }); + } catch (Exception e) { + logger.warn("Failed to create proxy producer for remote stream", e); + throw CallStatus.INTERNAL.withCause(e).withDescription("Unable to create proxy stream: " + e.getMessage()).toRuntimeException(); + } + } + + private 
void processStreamWithProducer( + CallContext context, + FlightStreamManager.StreamProducerHolder producerHolder, + ServerStreamListener listener + ) throws IOException { + try (StreamProducer producer = producerHolder.producer()) { + StreamProducer.BatchedJob batchedJob = producer.createJob(allocator); + if (context.isCancelled()) { + handleCancellation(batchedJob, listener); + return; + } + processStream(producerHolder, batchedJob, listener); + } + } + + private void processStream( + FlightStreamManager.StreamProducerHolder producerHolder, + StreamProducer.BatchedJob batchedJob, + ServerStreamListener listener + ) { + BackpressureStrategy backpressureStrategy = new CustomCallbackBackpressureStrategy(null, batchedJob::onCancel); + backpressureStrategy.register(listener); + StreamProducer.FlushSignal flushSignal = createFlushSignal(batchedJob, listener, backpressureStrategy); + + try (VectorSchemaRoot root = producerHolder.getRoot()) { + listener.start(root); + batchedJob.run(root, flushSignal); + listener.completed(); + } + } + + private StreamProducer.FlushSignal createFlushSignal( + StreamProducer.BatchedJob batchedJob, + ServerStreamListener listener, + BackpressureStrategy backpressureStrategy + ) { + return timeout -> { + BackpressureStrategy.WaitResult result = backpressureStrategy.waitForListener(timeout.millis()); + switch (result) { + case READY: + listener.putNext(); + break; + case TIMEOUT: + batchedJob.onCancel(); + throw CallStatus.TIMED_OUT.withDescription("Stream deadline exceeded").toRuntimeException(); + case CANCELLED: + batchedJob.onCancel(); + throw CallStatus.CANCELLED.withDescription("Stream cancelled by client").toRuntimeException(); + default: + batchedJob.onCancel(); + logger.error("Unexpected backpressure result: {}", result); + throw CallStatus.INTERNAL.withDescription("Unexpected backpressure error: " + result).toRuntimeException(); + } + }; + } + + private void handleCancellation(StreamProducer.BatchedJob batchedJob, 
ServerStreamListener listener) { + try { + batchedJob.onCancel(); + throw CallStatus.CANCELLED.withDescription("Stream cancelled before processing").toRuntimeException(); + } catch (Exception e) { + logger.error("Unexpected error during cancellation", e); + throw CallStatus.INTERNAL.withCause(e).withDescription("Error during cancellation: " + e.getMessage()).toRuntimeException(); + } + } + + private FlightInfo getLocalFlightInfo(StreamTicket streamTicket, FlightDescriptor descriptor) { + FlightStreamManager.StreamProducerHolder producerHolder = streamManager.getStreamProducer(streamTicket).orElseThrow(() -> { + logger.debug("FlightInfo not found for ticket: {}", streamTicket); + return CallStatus.NOT_FOUND.withDescription("FlightInfo not found").toRuntimeException(); + }); + + Location location = flightClientManager.getFlightClientLocation(streamTicket.getNodeId()).orElseThrow(() -> { + logger.warn("Failed to determine location for node: {}", streamTicket.getNodeId()); + return CallStatus.INTERNAL.withDescription("Internal error determining location").toRuntimeException(); + }); + + try { + Ticket ticket = new Ticket(descriptor.getCommand()); + var schema = producerHolder.getRoot().getSchema(); + FlightEndpoint endpoint = new FlightEndpoint(ticket, location); + return FlightInfo.builder(schema, descriptor, Collections.singletonList(endpoint)) + .setRecords(producerHolder.producer().estimatedRowCount()) + .build(); + } catch (Exception e) { + logger.error("Failed to build FlightInfo", e); + throw CallStatus.INTERNAL.withCause(e).withDescription("Error creating FlightInfo: " + e.getMessage()).toRuntimeException(); + } + } + + private FlightInfo getRemoteFlightInfo(StreamTicket streamTicket, FlightDescriptor descriptor) { + FlightClient remoteClient = flightClientManager.getFlightClient(streamTicket.getNodeId()).orElseThrow(() -> { + logger.warn("No remote client available for node: {}", streamTicket.getNodeId()); + return CallStatus.INTERNAL.withDescription("Client 
doesn't support Stream").toRuntimeException(); + }); + + try { + return remoteClient.getInfo(descriptor); + } catch (Exception e) { + logger.error("Failed to get remote FlightInfo", e); + throw CallStatus.INTERNAL.withCause(e) + .withDescription("Error retrieving remote FlightInfo: " + e.getMessage()) + .toRuntimeException(); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/CustomCallbackBackpressureStrategy.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/CustomCallbackBackpressureStrategy.java new file mode 100644 index 0000000000000..0c49ddd78ce30 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/CustomCallbackBackpressureStrategy.java @@ -0,0 +1,44 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.BackpressureStrategy; + +/** + * Base class for backpressure strategy. + */ +public class CustomCallbackBackpressureStrategy extends BackpressureStrategy.CallbackBackpressureStrategy { + private final Runnable readyCallback; + private final Runnable cancelCallback; + + /** + * Constructor for BaseBackpressureStrategy. + * + * @param readyCallback Callback to execute when the listener is ready. + * @param cancelCallback Callback to execute when the listener is cancelled. + */ + CustomCallbackBackpressureStrategy(Runnable readyCallback, Runnable cancelCallback) { + this.readyCallback = readyCallback; + this.cancelCallback = cancelCallback; + } + + /** Callback to execute when the listener is ready. */ + protected void readyCallback() { + if (readyCallback != null) { + readyCallback.run(); + } + } + + /** Callback to execute when the listener is cancelled. 
*/ + protected void cancelCallback() { + if (cancelCallback != null) { + cancelCallback.run(); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java new file mode 100644 index 0000000000000..1130d59227aab --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamManager.java @@ -0,0 +1,212 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; +import org.opensearch.common.cache.Cache; +import org.opensearch.common.cache.CacheBuilder; +import org.opensearch.common.cache.RemovalReason; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.tasks.TaskId; + +import java.io.IOException; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; + +/** + * FlightStreamManager is a concrete implementation of StreamManager that provides + * an abstraction layer for managing Arrow Flight streams 
in OpenSearch. + * It encapsulates the details of Flight client operations, allowing consumers to + * work with streams without direct exposure to Flight internals. + */ +public class FlightStreamManager implements StreamManager { + private static final Logger logger = LogManager.getLogger(FlightStreamManager.class); + + private FlightStreamTicketFactory ticketFactory; + private FlightClientManager clientManager; + private Supplier allocatorSupplier; + private final Cache streamProducers; + + // Default cache settings (TODO: Make configurable via settings) + private static final TimeValue DEFAULT_CACHE_EXPIRE = TimeValue.timeValueMinutes(10); + private static final int MAX_WEIGHT = 1000; + + /** + * Holds a StreamProducer along with its metadata and resources + */ + record StreamProducerHolder(StreamProducer producer, BufferAllocator allocator, long creationTime, + AtomicReference root) { + public StreamProducerHolder { + Objects.requireNonNull(producer, "StreamProducer cannot be null"); + Objects.requireNonNull(allocator, "BufferAllocator cannot be null"); + } + + static StreamProducerHolder create(StreamProducer producer, BufferAllocator allocator) { + return new StreamProducerHolder(producer, allocator, System.nanoTime(), new AtomicReference<>(null)); + } + + boolean isExpired() { + return System.nanoTime() - creationTime > producer.getJobDeadline().getNanos(); + } + + /** + * Gets the VectorSchemaRoot associated with the StreamProducer. + * If the root is not set, it creates a new one using the provided BufferAllocator. + */ + public VectorSchemaRoot getRoot() { + return root.updateAndGet(current -> current != null ? current : producer.createRoot(allocator)); + } + } + + /** + * Constructs a new FlightStreamManager. 
+ */ + public FlightStreamManager() { + this.streamProducers = CacheBuilder.builder() + .setExpireAfterWrite(DEFAULT_CACHE_EXPIRE) + .setMaximumWeight(MAX_WEIGHT) + .removalListener(n -> { + if (n.getRemovalReason() != RemovalReason.EXPLICIT) { + try (var unused = n.getValue().producer()) {} catch (IOException e) { + logger.error("Error closing stream producer, this may cause memory leaks.", e); + } + } + }) + .build(); + } + + /** + * Sets the allocator supplier for this FlightStreamManager. + * @param allocatorSupplier The supplier for BufferAllocator instances used for memory management. + * This parameter is required to be non-null. + */ + public void setAllocatorSupplier(Supplier allocatorSupplier) { + this.allocatorSupplier = Objects.requireNonNull(allocatorSupplier, "Allocator supplier cannot be null"); + } + + /** + * Sets the FlightClientManager for managing Flight clients. + * + * @param clientManager The FlightClientManager instance (must be non-null). + */ + public void setClientManager(FlightClientManager clientManager) { + this.clientManager = Objects.requireNonNull(clientManager, "FlightClientManager cannot be null"); + this.ticketFactory = new FlightStreamTicketFactory(clientManager::getLocalNodeId); + } + + /** + * Registers a new stream producer with the StreamManager. + * @param provider The StreamProducer instance to register. + * @param parentTaskId The parent task ID associated with the stream. + * @return A StreamTicket representing the registered stream. 
+ */ + @Override + @SuppressWarnings("unchecked") + public StreamTicket registerStream(StreamProducer provider, TaskId parentTaskId) { + StreamTicket ticket = ticketFactory.newTicket(); + try { + streamProducers.computeIfAbsent( + ticket.getTicketId(), + ticketId -> StreamProducerHolder.create( + (StreamProducer) provider, + allocatorSupplier.get() + ) + ); + } catch (ExecutionException e) { + throw new RuntimeException(e); + } + return ticket; + } + + /** + * Retrieves a StreamReader for the given StreamTicket. + * @param ticket The StreamTicket representing the stream to retrieve. + * @return A StreamReader instance for the specified stream. + */ + @Override + @SuppressWarnings("unchecked") + public StreamReader getStreamReader(StreamTicket ticket) { + FlightClient flightClient = clientManager.getFlightClient(ticket.getNodeId()) + .orElseThrow(() -> new RuntimeException("Flight client not found for node [" + ticket.getNodeId() + "].")); + FlightStream stream = flightClient.getStream(new Ticket(ticket.toBytes())); + return (StreamReader) new FlightStreamReader(stream); + } + + /** + * Retrieves the StreamTicketFactory used by this StreamManager. + * @return The StreamTicketFactory instance associated with this StreamManager. + */ + @Override + public StreamTicketFactory getStreamTicketFactory() { + return ticketFactory; + } + + /** + * Gets the StreamProducer associated with a ticket if it hasn't expired based on its deadline. 
+ * + * @param ticket The StreamTicket identifying the stream + * @return Optional of StreamProducerHolder containing the producer if found and not expired + */ + Optional getStreamProducer(StreamTicket ticket) { + String ticketId = ticket.getTicketId(); + StreamProducerHolder holder = streamProducers.get(ticketId); + if (holder == null) { + logger.debug("No stream producer found for ticket [{}]", ticketId); + return Optional.empty(); + } + + if (holder.isExpired()) { + logger.debug("Stream producer for ticket [{}] has expired", ticketId); + streamProducers.remove(ticketId); + return Optional.empty(); + } + return Optional.of(holder); + } + + /** + * Gets and removes the StreamProducer associated with a ticket. + * Ensure that close is called on the StreamProducer after use. + * @param ticket The StreamTicket identifying the stream + * @return Optional of StreamProducerHolder containing the producer if found + */ + public Optional removeStreamProducer(StreamTicket ticket) { + String ticketId = ticket.getTicketId(); + StreamProducerHolder holder = streamProducers.get(ticketId); + if (holder == null) { + return Optional.empty(); + } + streamProducers.remove(ticketId); + return Optional.of(holder); + } + + /** + * Closes the StreamManager and cancels all associated streams. + * This method should be called when the StreamManager is no longer needed to clean up resources. + * It is recommended to implement this method to cancel all threads and clear the streamManager queue. 
+ */ + @Override + public void close() throws Exception { + streamProducers.invalidateAll(); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java new file mode 100644 index 0000000000000..d9e366dca30e2 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamReader.java @@ -0,0 +1,61 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.ExceptionsHelper; +import org.opensearch.arrow.spi.StreamReader; + +/** + * FlightStreamReader is a wrapper class that adapts the FlightStream interface + * to the StreamReader interface. + */ +public class FlightStreamReader implements StreamReader { + + private final FlightStream flightStream; + + /** + * Constructs a FlightStreamReader with the given FlightStream. + * + * @param flightStream The FlightStream to be adapted. + */ + public FlightStreamReader(FlightStream flightStream) { + this.flightStream = flightStream; + } + + /** + * Moves the flightStream to the next batch of data. + * @return true if there is a next batch of data, false otherwise. + * @throws FlightRuntimeException if an error occurs while advancing to the next batch like early termination of stream + */ + @Override + public boolean next() throws FlightRuntimeException { + return flightStream.next(); + } + + /** + * Returns the VectorSchemaRoot containing the current batch of data. + * @return The VectorSchemaRoot containing the current batch of data. 
+ * @throws FlightRuntimeException if an error occurs while retrieving the root like early termination of stream + */ + @Override + public VectorSchemaRoot getRoot() throws FlightRuntimeException { + return flightStream.getRoot(); + } + + /** + * Closes the flightStream. + */ + @Override + public void close() { + ExceptionsHelper.catchAsRuntimeException(flightStream::close); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java new file mode 100644 index 0000000000000..baa9e79fec6a1 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicket.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.opensearch.arrow.spi.StreamTicket; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Base64; +import java.util.Objects; + +class FlightStreamTicket implements StreamTicket { + private static final int MAX_TOTAL_SIZE = 4096; + private static final int MAX_ID_LENGTH = 256; + + private final String ticketID; + private final String nodeID; + + public FlightStreamTicket(String ticketID, String nodeID) { + this.ticketID = ticketID; + this.nodeID = nodeID; + } + + @Override + public String getTicketId() { + return ticketID; + } + + @Override + public String getNodeId() { + return nodeID; + } + + @Override + public byte[] toBytes() { + byte[] ticketIDBytes = ticketID.getBytes(StandardCharsets.UTF_8); + byte[] nodeIDBytes = nodeID.getBytes(StandardCharsets.UTF_8); + + if (ticketIDBytes.length > Short.MAX_VALUE || nodeIDBytes.length > Short.MAX_VALUE) { + throw new IllegalArgumentException("Field 
lengths exceed the maximum allowed size."); + } + ByteBuffer buffer = ByteBuffer.allocate(2 + ticketIDBytes.length + 2 + nodeIDBytes.length); + buffer.putShort((short) ticketIDBytes.length); + buffer.putShort((short) nodeIDBytes.length); + buffer.put(ticketIDBytes); + buffer.put(nodeIDBytes); + return Base64.getEncoder().encode(buffer.array()); + } + + static StreamTicket fromBytes(byte[] bytes) { + if (bytes == null || bytes.length < 4) { + throw new IllegalArgumentException("Invalid byte array input."); + } + + if (bytes.length > MAX_TOTAL_SIZE) { + throw new IllegalArgumentException("Input exceeds maximum allowed size"); + } + + ByteBuffer buffer = ByteBuffer.wrap(Base64.getDecoder().decode(bytes)); + + short ticketIDLength = buffer.getShort(); + if (ticketIDLength < 0 || ticketIDLength > MAX_ID_LENGTH) { + throw new IllegalArgumentException("Invalid ticketID length: " + ticketIDLength); + } + + short nodeIDLength = buffer.getShort(); + if (nodeIDLength < 0 || nodeIDLength > MAX_ID_LENGTH) { + throw new IllegalArgumentException("Invalid nodeID length: " + nodeIDLength); + } + + byte[] ticketIDBytes = new byte[ticketIDLength]; + if (buffer.remaining() < ticketIDLength) { + throw new IllegalArgumentException("Malformed byte array. Not enough data for TicketId."); + } + buffer.get(ticketIDBytes); + + byte[] nodeIDBytes = new byte[nodeIDLength]; + if (buffer.remaining() < nodeIDLength) { + throw new IllegalArgumentException("Malformed byte array. 
Not enough data for NodeId."); + } + buffer.get(nodeIDBytes); + + String ticketID = new String(ticketIDBytes, StandardCharsets.UTF_8); + String nodeID = new String(nodeIDBytes, StandardCharsets.UTF_8); + return new FlightStreamTicket(ticketID, nodeID); + } + + @Override + public int hashCode() { + return Objects.hash(ticketID, nodeID); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + FlightStreamTicket that = (FlightStreamTicket) obj; + return Objects.equals(ticketID, that.ticketID) && Objects.equals(nodeID, that.nodeID); + } + + @Override + public String toString() { + return "FlightStreamTicket{ticketID='" + ticketID + "', nodeID='" + nodeID + "'}"; + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java new file mode 100644 index 0000000000000..473eb92cf2db3 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/FlightStreamTicketFactory.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.arrow.spi.StreamTicketFactory; +import org.opensearch.common.annotation.ExperimentalApi; + +import java.util.UUID; +import java.util.function.Supplier; + +/** + * Default implementation of StreamTicketFactory + */ +@ExperimentalApi +public class FlightStreamTicketFactory implements StreamTicketFactory { + + private final Supplier nodeId; + + /** + * Constructs a new DefaultStreamTicketFactory instance. 
+ * + * @param nodeId A Supplier that provides the node ID for the StreamTicket + */ + public FlightStreamTicketFactory(Supplier nodeId) { + this.nodeId = nodeId; + } + + /** + * Creates a new StreamTicket with a unique ticket ID. + * + * @return A new StreamTicket instance + */ + @Override + public StreamTicket newTicket() { + return new FlightStreamTicket(generateUniqueTicket(), nodeId.get()); + } + + /** + * Deserializes a StreamTicket from its byte representation. + * + * @param bytes The byte array containing the serialized ticket data + * @return A StreamTicket instance reconstructed from the byte array + * @throws IllegalArgumentException if bytes is null or invalid + */ + @Override + public StreamTicket fromBytes(byte[] bytes) { + return FlightStreamTicket.fromBytes(bytes); + } + + private String generateUniqueTicket() { + return UUID.randomUUID().toString(); + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java new file mode 100644 index 0000000000000..75a8d07266e07 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/ProxyStreamProducer.java @@ -0,0 +1,122 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.ExceptionsHelper; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.common.unit.TimeValue; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * ProxyStreamProvider acts as forward proxy for FlightStream. 
+ * It creates a BatchedJob to handle the streaming of data from the remote FlightStream. + * This is useful when stream is not present locally and needs to be fetched from a node + * retrieved using {@link StreamTicket#getNodeId()} where it is present. + */ +public class ProxyStreamProducer implements StreamProducer { + + private final StreamReader remoteStream; + + /** + * Constructs a new ProxyStreamProducer instance. + * + * @param remoteStream The remote FlightStream to be proxied. + */ + public ProxyStreamProducer(StreamReader remoteStream) { + this.remoteStream = remoteStream; + } + + /** + * Creates a VectorSchemaRoot for the remote FlightStream. + * @param allocator The allocator to use for creating vectors + * @return A VectorSchemaRoot representing the schema of the remote FlightStream + */ + @Override + public VectorSchemaRoot createRoot(BufferAllocator allocator) { + return remoteStream.getRoot(); + } + + /** + * Creates a BatchedJob + * @param allocator The allocator to use for any additional memory allocations + */ + @Override + public BatchedJob createJob(BufferAllocator allocator) { + return new ProxyBatchedJob(remoteStream); + } + + /** + * Returns the deadline for the remote FlightStream. + * Since the stream is not present locally, the deadline is set to -1. It piggybacks on remote stream expiration + * @return The deadline for the remote FlightStream + */ + @Override + public TimeValue getJobDeadline() { + return TimeValue.MINUS_ONE; + } + + /** + * Provides an estimate of the total number of rows that will be produced. + */ + @Override + public int estimatedRowCount() { + return remoteStream.getRoot().getRowCount(); + } + + /** + * Task action name + */ + @Override + public String getAction() { + // TODO get it from remote flight stream + throw new UnsupportedOperationException("Not implemented yet"); + } + + /** + * Closes the remote FlightStream. 
+ */ + @Override + public void close() { + ExceptionsHelper.catchAsRuntimeException(remoteStream::close); + } + + static class ProxyBatchedJob implements BatchedJob { + + private final StreamReader remoteStream; + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + + ProxyBatchedJob(StreamReader remoteStream) { + this.remoteStream = remoteStream; + } + + @Override + public void run(VectorSchemaRoot root, FlushSignal flushSignal) { + while (!isCancelled.get() && remoteStream.next()) { + flushSignal.awaitConsumption(TimeValue.timeValueMillis(1000)); + } + } + + @Override + public void onCancel() { + isCancelled.set(true); + } + + @Override + public boolean isCancelled() { + // Proxy stream don't have any business logic to set this flag, + // they piggyback on remote stream getting cancelled. + return isCancelled.get(); + } + } +} diff --git a/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java new file mode 100644 index 0000000000000..90ca54b44a55d --- /dev/null +++ b/plugins/arrow-flight-rpc/src/main/java/org/opensearch/arrow/flight/impl/package-info.java @@ -0,0 +1,12 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/** + * Core components and implementations for OpenSearch Flight service, including base producers and consumers. 
+ */ +package org.opensearch.arrow.flight.impl; diff --git a/plugins/arrow-flight-rpc/src/main/plugin-metadata/plugin-security.policy b/plugins/arrow-flight-rpc/src/main/plugin-metadata/plugin-security.policy index 803350a578009..40d584198fd48 100644 --- a/plugins/arrow-flight-rpc/src/main/plugin-metadata/plugin-security.policy +++ b/plugins/arrow-flight-rpc/src/main/plugin-metadata/plugin-security.policy @@ -6,17 +6,10 @@ * compatible open source license. */ -grant codeBase "${codebase.netty-common}" { - permission java.net.SocketPermission "*", "accept,connect,listen,resolve"; - permission java.lang.RuntimePermission "*", "setContextClassLoader"; -}; - -grant codeBase "${codebase.grpc-core}" { - permission java.net.SocketPermission "*", "accept,connect,listen,resolve"; - permission java.lang.RuntimePermission "*", "setContextClassLoader"; -}; - grant { + // Memory access + permission java.lang.RuntimePermission "accessClassInPackage.sun.misc"; + // arrow flight service permissions permission java.util.PropertyPermission "arrow.allocation.manager.type", "write"; permission java.util.PropertyPermission "arrow.enable_null_check_for_get", "write"; @@ -29,7 +22,6 @@ grant { permission java.util.PropertyPermission "io.netty.tryUnsafe", "write"; // Needed for netty based arrow flight server for netty configs related to buffer allocator - permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; permission java.util.PropertyPermission "arrow.allocation.manager.type", "write"; permission java.lang.RuntimePermission "modifyThreadGroup"; @@ -39,7 +31,11 @@ grant { // Reflection access needed by Arrow permission java.lang.RuntimePermission "accessDeclaredMembers"; permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + permission java.lang.RuntimePermission "getClassLoader"; // Memory access permission java.lang.RuntimePermission "accessClassInPackage.sun.misc"; + + // needed by netty-common + permission java.lang.RuntimePermission "*", 
"setContextClassLoader"; }; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java index 2573f0032f45b..e1d7d7d95d4a9 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/FlightStreamPluginTests.java @@ -8,8 +8,8 @@ package org.opensearch.arrow.flight; -import org.opensearch.arrow.flight.api.FlightServerInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.FlightServerInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; import org.opensearch.arrow.flight.bootstrap.FlightService; import org.opensearch.arrow.flight.bootstrap.FlightStreamPlugin; import org.opensearch.arrow.spi.StreamManager; @@ -27,7 +27,7 @@ import java.io.IOException; import java.util.Collection; import java.util.List; -import java.util.function.Supplier; +import java.util.Optional; import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS; import static org.opensearch.plugins.NetworkPlugin.AuxTransport.AUX_TRANSPORT_TYPES_KEY; @@ -76,8 +76,8 @@ public void testPluginEnabled() throws IOException { assertFalse(executorBuilders.isEmpty()); assertEquals(2, executorBuilders.size()); - Supplier streamManager = plugin.getStreamManager(); - assertNotNull(streamManager); + Optional streamManager = plugin.getStreamManager(); + assertTrue(streamManager.isPresent()); List> settings = plugin.getSettings(); assertNotNull(settings); diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java similarity index 98% rename from 
plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java index 6cb75d4a93dbe..d3115fc745475 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/FlightServerInfoActionTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/FlightServerInfoActionTests.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.cluster.ClusterName; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java similarity index 99% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java index 2f8d7deb06f3f..59e695313c16e 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodeFlightInfoTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodeFlightInfoTests.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java similarity index 96% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java index 756177423fe6f..ef8f88b78c3ee 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoRequestTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoRequestTests.java @@ -6,7 +6,7 @@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java similarity index 99% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java index 49a6cc6bacf40..707a222fe381f 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/NodesFlightInfoResponseTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/NodesFlightInfoResponseTests.java @@ -6,7 +6,7 
@@ * compatible open source license. */ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java similarity index 99% rename from plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java rename to plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java index d9d8af5920d61..6bd70eec4ad3a 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/TransportNodesFlightInfoActionTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/api/flightinfo/TransportNodesFlightInfoActionTests.java @@ -6,7 +6,7 @@ * compatible open source license. 
*/ -package org.opensearch.arrow.flight.api; +package org.opensearch.arrow.flight.api.flightinfo; import org.opensearch.Version; import org.opensearch.action.FailedNodeException; diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java index ce2f0df7f5f55..e077acc8e390a 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightClientManagerTests.java @@ -11,10 +11,10 @@ import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.RootAllocator; import org.opensearch.Version; -import org.opensearch.arrow.flight.api.NodeFlightInfo; -import org.opensearch.arrow.flight.api.NodesFlightInfoAction; -import org.opensearch.arrow.flight.api.NodesFlightInfoRequest; -import org.opensearch.arrow.flight.api.NodesFlightInfoResponse; +import org.opensearch.arrow.flight.api.flightinfo.NodeFlightInfo; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoAction; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoRequest; +import org.opensearch.arrow.flight.api.flightinfo.NodesFlightInfoResponse; import org.opensearch.arrow.flight.bootstrap.tls.SslContextProvider; import org.opensearch.cluster.ClusterChangedEvent; import org.opensearch.cluster.ClusterName; @@ -42,6 +42,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -107,7 +108,9 @@ public void setUp() throws Exception { clientManager.clusterChanged(event); assertBusy(() -> { assertEquals("Flight client isn't built in time limit", 2, clientManager.getFlightClients().size()); + assertTrue("local_node should 
exist", clientManager.getFlightClient("local_node").isPresent()); assertNotNull("local_node should exist", clientManager.getFlightClient("local_node").get()); + assertTrue("remote_node should exist", clientManager.getFlightClient("remote_node").isPresent()); assertNotNull("remote_node should exist", clientManager.getFlightClient("remote_node").get()); }, 2, TimeUnit.SECONDS); } @@ -375,8 +378,9 @@ public void testFailedClusterUpdateButSuccessfulDirectRequest() throws Exception private void validateNodes() { for (DiscoveryNode node : state.nodes()) { - FlightClient client = clientManager.getFlightClient(node.getId()).get(); - assertNotNull("Flight client should be created for existing node", client); + Optional client = clientManager.getFlightClient(node.getId()); + assertTrue("Flight client should be created for node [" + node.getId() + "].", client.isPresent()); + assertNotNull("Flight client should be created for node [" + node.getId() + "].", client.get()); } } diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java index d8f5d5ba6b45b..a7274eb756458 100644 --- a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/bootstrap/FlightServiceTests.java @@ -94,7 +94,7 @@ public void testStartAndStop() throws Exception { testService.start(); testService.stop(); testService.start(); - assertNull(testService.getStreamManager()); + assertNotNull(testService.getStreamManager()); } } diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java new file mode 100644 index 0000000000000..65caae55e9e40 --- /dev/null +++ 
b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/BaseFlightProducerTests.java @@ -0,0 +1,732 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.ArrowBuf; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.dictionary.DictionaryProvider; +import org.apache.arrow.vector.ipc.message.IpcOption; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class BaseFlightProducerTests extends OpenSearchTestCase { + + private BaseFlightProducer baseFlightProducer; + private FlightStreamManager streamManager; + private StreamProducer streamProducer; + private StreamProducer.BatchedJob batchedJob; + private 
static final String LOCAL_NODE_ID = "localNodeId"; + private static final FlightClientManager flightClientManager = mock(FlightClientManager.class); + private final Ticket ticket = new Ticket((new FlightStreamTicket("test-ticket", LOCAL_NODE_ID)).toBytes()); + private BufferAllocator allocator; + + @LockFeatureFlag(ARROW_STREAMS) + @Override + @SuppressWarnings("unchecked") + public void setUp() throws Exception { + super.setUp(); + streamManager = mock(FlightStreamManager.class); + when(streamManager.getStreamTicketFactory()).thenReturn(new FlightStreamTicketFactory(() -> LOCAL_NODE_ID)); + when(flightClientManager.getLocalNodeId()).thenReturn(LOCAL_NODE_ID); + allocator = mock(BufferAllocator.class); + streamProducer = mock(StreamProducer.class); + batchedJob = mock(StreamProducer.BatchedJob.class); + baseFlightProducer = new BaseFlightProducer(flightClientManager, streamManager, allocator); + } + + private static class TestServerStreamListener implements FlightProducer.ServerStreamListener { + private final CountDownLatch completionLatch = new CountDownLatch(1); + private final AtomicInteger putNextCount = new AtomicInteger(0); + private final AtomicBoolean isCancelled = new AtomicBoolean(false); + private Throwable error; + private final AtomicBoolean dataConsumed = new AtomicBoolean(false); + private final AtomicBoolean ready = new AtomicBoolean(false); + private Runnable onReadyHandler; + private Runnable onCancelHandler; + + @Override + public void putNext() { + assertFalse(dataConsumed.get()); + putNextCount.incrementAndGet(); + dataConsumed.set(true); + } + + @Override + public boolean isReady() { + return ready.get(); + } + + public void setReady(boolean val) { + ready.set(val); + if (this.onReadyHandler != null) { + this.onReadyHandler.run(); + } + } + + @Override + public void start(VectorSchemaRoot root) { + // No-op for this test + } + + @Override + public void start(VectorSchemaRoot root, DictionaryProvider dictionaries, IpcOption option) {} + + 
@Override + public void putNext(ArrowBuf metadata) { + putNext(); + } + + @Override + public void putMetadata(ArrowBuf metadata) {} + + @Override + public void completed() { + completionLatch.countDown(); + } + + @Override + public void error(Throwable t) { + error = t; + completionLatch.countDown(); + } + + @Override + public boolean isCancelled() { + return isCancelled.get(); + } + + @Override + public void setOnReadyHandler(Runnable handler) { + this.onReadyHandler = handler; + } + + @Override + public void setOnCancelHandler(Runnable handler) { + this.onCancelHandler = handler; + } + + public void resetConsumptionLatch() { + dataConsumed.set(false); + } + + public boolean getDataConsumed() { + return dataConsumed.get(); + } + + public int getPutNextCount() { + return putNextCount.get(); + } + + public Throwable getError() { + return error; + } + + public void cancel() { + isCancelled.set(true); + if (this.onCancelHandler != null) { + this.onCancelHandler.run(); + } + } + } + + public void testGetStream_SuccessfulFlow() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 3; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); + assertTrue(listener.getDataConsumed()); + 
flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class)); + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + + assertNull(listener.getError()); + assertEquals(3, listener.getPutNextCount()); + assertEquals(3, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithSlowClient() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicBoolean isCancelled = new AtomicBoolean(false); + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + try { + listener.setReady(false); + Thread.sleep(100); + listener.setReady(true); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(300)); // Longer than client sleep + if (isCancelled.get()) { + break; + } + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(batchedJob).run(any(), any()); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(batchedJob).onCancel(); + + 
baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + assertNull(listener.getError()); + assertEquals(5, listener.getPutNextCount()); + assertEquals(5, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithSlowClientTimeout() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicBoolean isCancelled = new AtomicBoolean(false); + + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + Thread clientThread = new Thread(() -> { + try { + listener.setReady(false); + Thread.sleep(400); // Longer than timeout + listener.setReady(true); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); // Shorter than client sleep + return null; + }).when(batchedJob).run(any(), any()); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(batchedJob).onCancel(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener) + ); + + assertEquals("Stream deadline exceeded", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream deadline exceeded", listener.getError().getMessage()); + assertEquals(0, 
listener.getPutNextCount()); + assertEquals(0, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithClientCancel() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicBoolean isCancelled = new AtomicBoolean(false); + + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + int finalI = i; + Thread clientThread = new Thread(() -> { + if (finalI == 4) { + listener.cancel(); + } else { + listener.setReady(true); + } + }); + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); + if (isCancelled.get()) { + break; + } + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(batchedJob).run(any(), any()); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(batchedJob).onCancel(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener) + ); + + assertEquals("Stream cancelled by client", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream cancelled by client", listener.getError().getMessage()); + assertEquals(4, listener.getPutNextCount()); + assertEquals(4, flushCount.get()); + + 
verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithUnresponsiveClient() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + AtomicInteger flushCount = new AtomicInteger(0); + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicBoolean isCancelled = new AtomicBoolean(false); + + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + Thread clientThread = new Thread(() -> listener.setReady(false)); // Never sets ready + listener.setReady(false); + clientThread.start(); + flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); + return null; + }).when(batchedJob).run(any(), any()); + doAnswer(invocation -> { + isCancelled.set(true); + return null; + }).when(batchedJob).onCancel(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener) + ); + + assertEquals("Stream deadline exceeded", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream deadline exceeded", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + assertEquals(0, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithServerBackpressure() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + 
when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicInteger flushCount = new AtomicInteger(0); + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + Thread.sleep(100); // Simulate server backpressure + flushSignal.awaitConsumption(TimeValue.timeValueMillis(200)); // Longer than sleep + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class)); + + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener); + + assertNull(listener.getError()); + assertEquals(5, listener.getPutNextCount()); + assertEquals(5, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_WithServerError() throws Exception { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + TestServerStreamListener listener = new TestServerStreamListener(); + AtomicInteger 
flushCount = new AtomicInteger(0); + doAnswer(invocation -> { + StreamProducer.FlushSignal flushSignal = invocation.getArgument(1); + for (int i = 0; i < 5; i++) { + Thread clientThread = new Thread(() -> { + listener.setReady(false); + listener.setReady(true); + }); + listener.setReady(false); + clientThread.start(); + if (i == 4) { + throw new RuntimeException("Server error"); + } + flushSignal.awaitConsumption(TimeValue.timeValueMillis(100)); + assertTrue(listener.getDataConsumed()); + flushCount.incrementAndGet(); + listener.resetConsumptionLatch(); + } + return null; + }).when(batchedJob).run(any(VectorSchemaRoot.class), any(StreamProducer.FlushSignal.class)); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener) + ); + + assertEquals("Unexpected server error", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Unexpected server error", listener.getError().getMessage()); + assertEquals(4, listener.getPutNextCount()); + assertEquals(4, flushCount.get()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + verify(root).close(); + } + + public void testGetStream_StreamNotFound() throws Exception { + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn(Optional.empty()); + TestServerStreamListener listener = new TestServerStreamListener(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), ticket, listener) + ); + + assertEquals("Stream not found", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream not found", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + + verify(streamManager).removeStreamProducer(any(FlightStreamTicket.class)); + } + + public void 
testGetStreamRemoteNode() throws Exception { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + FlightStream mockFlightStream = mock(FlightStream.class); + + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getStream(any(Ticket.class))).thenReturn(mockFlightStream); + TestServerStreamListener listener = new TestServerStreamListener(); + + baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), new Ticket(remoteTicket.toBytes()), listener); + verify(remoteClient).getStream(any(Ticket.class)); + } + + public void testGetStreamRemoteNodeWithNonExistentClient() throws Exception { + final String remoteNodeId = "remote-node-5"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.empty()); + TestServerStreamListener listener = new TestServerStreamListener(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), new Ticket(remoteTicket.toBytes()), listener) + ); + + assertEquals("Stream not found", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream not found", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + } + + public void testGetFlightInfo() { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.getStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + + Location location = Location.forGrpcInsecure(LOCAL_NODE_ID, 8815); + 
when(flightClientManager.getFlightClientLocation(LOCAL_NODE_ID)).thenReturn(Optional.of(location)); + when(streamProducer.estimatedRowCount()).thenReturn(100); + FlightDescriptor descriptor = FlightDescriptor.command(ticket.getBytes()); + FlightInfo flightInfo = baseFlightProducer.getFlightInfo(null, descriptor); + + assertNotNull(flightInfo); + assertEquals(100L, flightInfo.getRecords()); + assertEquals(1, flightInfo.getEndpoints().size()); + assertEquals(location, flightInfo.getEndpoints().getFirst().getLocations().getFirst()); + } + + public void testGetFlightInfo_NotFound() { + when(streamManager.getStreamProducer(any(FlightStreamTicket.class))).thenReturn(Optional.empty()); + FlightDescriptor descriptor = FlightDescriptor.command(ticket.getBytes()); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(null, descriptor) + ); + + assertEquals("FlightInfo not found", exception.getMessage()); + } + + public void testGetFlightInfo_LocationNotFound() { + final VectorSchemaRoot root = mock(VectorSchemaRoot.class); + when(streamManager.getStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createRoot(any(BufferAllocator.class))).thenReturn(root); + when(flightClientManager.getFlightClientLocation(LOCAL_NODE_ID)).thenReturn(Optional.empty()); + + FlightDescriptor descriptor = FlightDescriptor.command(ticket.getBytes()); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(null, descriptor) + ); + + assertEquals("Internal error determining location", exception.getMessage()); + } + + public void testGetFlightInfo_SchemaError() { + when(streamManager.getStreamProducer(any(FlightStreamTicket.class))) + .thenReturn(Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator))); + 
Location location = Location.forGrpcInsecure("localhost", 8815); + when(flightClientManager.getFlightClientLocation(LOCAL_NODE_ID)).thenReturn(Optional.of(location)); + when(streamProducer.createRoot(allocator)).thenReturn(mock(VectorSchemaRoot.class)); + when(streamProducer.estimatedRowCount()).thenThrow(new RuntimeException("Schema error")); + + FlightDescriptor descriptor = FlightDescriptor.command(ticket.getBytes()); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(null, descriptor) + ); + + assertTrue(exception.getMessage(), exception.getMessage().contains("Error creating FlightInfo: Schema error")); + } + + public void testGetFlightInfo_NonLocalNode() { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + FlightInfo mockFlightInfo = mock(FlightInfo.class); + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getInfo(any(FlightDescriptor.class))).thenReturn(mockFlightInfo); + + FlightDescriptor descriptor = FlightDescriptor.command(remoteTicket.toBytes()); + FlightInfo flightInfo = baseFlightProducer.getFlightInfo(null, descriptor); + assertEquals(mockFlightInfo, flightInfo); + } + + public void testGetFlightInfo_NonLocalNode_LocationNotFound() { + final String remoteNodeId = "remote-node-2"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.empty()); + FlightDescriptor descriptor = FlightDescriptor.command(remoteTicket.toBytes()); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(null, descriptor) + ); + assertEquals("Client doesn't support Stream", exception.getMessage()); + } + + public void 
testGetStream_InvalidTicketFormat() throws Exception { + Ticket invalidTicket = new Ticket(new byte[] { 1, 2, 3 }); // Invalid byte array + TestServerStreamListener listener = new TestServerStreamListener(); + + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), invalidTicket, listener) + ); + + assertTrue(exception.getMessage().contains("Invalid ticket format")); + assertNotNull(listener.getError()); + assertTrue(listener.getError().getMessage().contains("Invalid ticket format")); + assertEquals(0, listener.getPutNextCount()); + } + + public void testGetFlightInfo_InvalidDescriptorFormat() { + FlightDescriptor invalidDescriptor = FlightDescriptor.command(new byte[] { 1, 2, 3 }); // Invalid byte array + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(mock(FlightProducer.CallContext.class), invalidDescriptor) + ); + + assertTrue(exception.getMessage().contains("Invalid descriptor format")); + } + + public void testGetStream_FailedToCreateStreamProducer_Remote() throws Exception { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getStream(any(Ticket.class))).thenThrow(new RuntimeException("Remote stream error")); + + TestServerStreamListener listener = new TestServerStreamListener(); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), new Ticket(remoteTicket.toBytes()), listener) + ); + + assertTrue(exception.getMessage().contains("Unable to create proxy stream: Remote stream error")); + assertNotNull(listener.getError()); + 
assertTrue(listener.getError().getMessage().contains("Unable to create proxy stream: Remote stream error")); + assertEquals(0, listener.getPutNextCount()); + } + + public void testGetStream_RemoteFlightStreamNull() throws Exception { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getStream(any(Ticket.class))).thenReturn(null); // Simulate null FlightStream + + TestServerStreamListener listener = new TestServerStreamListener(); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), new Ticket(remoteTicket.toBytes()), listener) + ); + + assertEquals("Stream not found", exception.getMessage()); + assertNotNull(listener.getError()); + assertEquals("Stream not found", listener.getError().getMessage()); + assertEquals(0, listener.getPutNextCount()); + } + + public void testGetStream_CreateProxyProducerException() throws Exception { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getStream(any(Ticket.class))).thenThrow(new RuntimeException("Proxy creation error")); + + TestServerStreamListener listener = new TestServerStreamListener(); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(mock(FlightProducer.CallContext.class), new Ticket(remoteTicket.toBytes()), listener) + ); + + assertTrue(exception.getMessage().contains("Unable to create proxy stream: Proxy creation error")); + 
assertNotNull(listener.getError()); + assertTrue(listener.getError().getMessage().contains("Unable to create proxy stream: Proxy creation error")); + assertEquals(0, listener.getPutNextCount()); + } + + public void testGetStream_CancellationException() throws Exception { + FlightProducer.CallContext context = mock(FlightProducer.CallContext.class); + when(context.isCancelled()).thenReturn(true); // Simulate cancellation + + when(streamManager.removeStreamProducer(any(FlightStreamTicket.class))).thenReturn( + Optional.of(FlightStreamManager.StreamProducerHolder.create(streamProducer, allocator)) + ); + when(streamProducer.createJob(any(BufferAllocator.class))).thenReturn(batchedJob); + doThrow(new RuntimeException("Cancellation error")).when(batchedJob).onCancel(); + + TestServerStreamListener listener = new TestServerStreamListener(); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getStream(context, ticket, listener) + ); + + assertTrue(exception.getMessage().contains("Error during cancellation: Cancellation error")); + assertNotNull(listener.getError()); + assertTrue(listener.getError().getMessage().contains("Error during cancellation: Cancellation error")); + } + + public void testGetFlightInfo_RemoteFlightInfoException() { + final String remoteNodeId = "remote-node"; + FlightStreamTicket remoteTicket = new FlightStreamTicket("test-id", remoteNodeId); + FlightClient remoteClient = mock(FlightClient.class); + + when(flightClientManager.getFlightClient(remoteNodeId)).thenReturn(Optional.of(remoteClient)); + when(remoteClient.getInfo(any(FlightDescriptor.class))).thenThrow(new RuntimeException("Remote info error")); + + FlightDescriptor descriptor = FlightDescriptor.command(remoteTicket.toBytes()); + FlightRuntimeException exception = expectThrows( + FlightRuntimeException.class, + () -> baseFlightProducer.getFlightInfo(mock(FlightProducer.CallContext.class), descriptor) + ); + + 
assertTrue(exception.getMessage().contains("Error retrieving remote FlightInfo: Remote info error")); + } +} diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java new file mode 100644 index 0000000000000..f194f9ba0860a --- /dev/null +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamManagerTests.java @@ -0,0 +1,176 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.Schema; +import org.opensearch.arrow.flight.bootstrap.FlightClientManager; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.arrow.spi.StreamReader; +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Optional; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class FlightStreamManagerTests extends OpenSearchTestCase { + + private FlightClient flightClient; + private FlightStreamManager flightStreamManager; + private static final String NODE_ID = "testNodeId"; + private static final String TICKET_ID = "testTicketId"; + + @Override + public void setUp() throws Exception { + super.setUp(); + flightClient = mock(FlightClient.class); + FlightClientManager 
clientManager = mock(FlightClientManager.class); + when(clientManager.getLocalNodeId()).thenReturn(NODE_ID); + when(clientManager.getFlightClient(NODE_ID)).thenReturn(Optional.of(flightClient)); + BufferAllocator allocator = mock(BufferAllocator.class); + flightStreamManager = new FlightStreamManager(); + flightStreamManager.setAllocatorSupplier(() -> allocator); + flightStreamManager.setClientManager(clientManager); + } + + public void testGetStreamReader() throws Exception { + StreamTicket ticket = new FlightStreamTicket(TICKET_ID, NODE_ID); + FlightStream mockFlightStream = mock(FlightStream.class); + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + when(flightClient.getStream(new Ticket(ticket.toBytes()))).thenReturn(mockFlightStream); + when(mockFlightStream.getRoot()).thenReturn(mockRoot); + when(mockRoot.getSchema()).thenReturn(new Schema(Collections.emptyList())); + + StreamReader streamReader = flightStreamManager.getStreamReader(ticket); + + assertNotNull(streamReader); + assertNotNull(streamReader.getRoot()); + assertEquals(new Schema(Collections.emptyList()), streamReader.getRoot().getSchema()); + verify(flightClient).getStream(new Ticket(ticket.toBytes())); + } + + public void testGetVectorSchemaRootWithException() { + StreamTicket ticket = new FlightStreamTicket(TICKET_ID, NODE_ID); + when(flightClient.getStream(new Ticket(ticket.toBytes()))).thenThrow(new RuntimeException("Test exception")); + + expectThrows(RuntimeException.class, () -> flightStreamManager.getStreamReader(ticket)); + } + + public void testRegisterStream() throws IOException { + try (TestStreamProducer producer = new TestStreamProducer()) { + assertNotNull(flightStreamManager.getStreamTicketFactory()); + StreamTicket resultTicket = flightStreamManager.registerStream(producer, null); + assertNotNull(resultTicket); + assertTrue(resultTicket instanceof FlightStreamTicket); + FlightStreamTicket flightTicket = (FlightStreamTicket) resultTicket; + assertEquals(NODE_ID, 
flightTicket.getNodeId()); + assertNotNull(flightTicket.getTicketId()); + Optional<FlightStreamManager.StreamProducerHolder> retrievedProducer = flightStreamManager.getStreamProducer(resultTicket); + assertTrue(retrievedProducer.isPresent()); + assertEquals(producer, retrievedProducer.get().producer()); + assertNotNull(retrievedProducer.get().getRoot()); + } + } + + public void testGetStreamProducerNotFound() { + StreamTicket ticket = new FlightStreamTicket("nonexistent", NODE_ID); + assertFalse(flightStreamManager.getStreamProducer(ticket).isPresent()); + StreamTicket ticket2 = new FlightStreamTicket("nonexistent", "unknown"); + try { + flightStreamManager.getStreamReader(ticket2); + fail("RuntimeException expected"); + } catch (RuntimeException e) { + assertEquals("Flight client not found for node [unknown].", e.getMessage()); + } + } + + public void testRemoveStreamProducer() throws IOException { + try (TestStreamProducer producer = new TestStreamProducer()) { + StreamTicket resultTicket = flightStreamManager.registerStream(producer, null); + assertNotNull(resultTicket); + assertTrue(resultTicket instanceof FlightStreamTicket); + FlightStreamTicket flightTicket = (FlightStreamTicket) resultTicket; + assertEquals(NODE_ID, flightTicket.getNodeId()); + assertNotNull(flightTicket.getTicketId()); + + Optional<FlightStreamManager.StreamProducerHolder> retrievedProducer = flightStreamManager.removeStreamProducer(resultTicket); + assertTrue(retrievedProducer.isPresent()); + assertEquals(producer, retrievedProducer.get().producer()); + assertNotNull(retrievedProducer.get().getRoot()); + assertFalse(flightStreamManager.getStreamProducer(resultTicket).isPresent()); + } + } + + public void testRemoveNonExistentStreamProducer() { + StreamTicket ticket = new FlightStreamTicket("nonexistent", NODE_ID); + Optional<FlightStreamManager.StreamProducerHolder> removedProducer = flightStreamManager.removeStreamProducer(ticket); + assertFalse(removedProducer.isPresent()); + } + + public void testStreamProducerExpired() { + TestStreamProducer producer = new TestStreamProducer() { + @Override + public 
TimeValue getJobDeadline() { + return TimeValue.timeValueMillis(0); + } + }; + StreamTicket ticket = flightStreamManager.registerStream(producer, null); + Optional expiredProducer = flightStreamManager.getStreamProducer(ticket); + assertFalse(expiredProducer.isPresent()); + } + + public void testClose() throws Exception { + TestStreamProducer producer = new TestStreamProducer(); + StreamTicket ticket = flightStreamManager.registerStream(producer, null); + flightStreamManager.close(); + assertFalse(flightStreamManager.getStreamProducer(ticket).isPresent()); + } + + static class TestStreamProducer implements StreamProducer { + @Override + public VectorSchemaRoot createRoot(BufferAllocator bufferAllocator) { + return mock(VectorSchemaRoot.class); + } + + @Override + public BatchedJob createJob(BufferAllocator bufferAllocator) { + return null; + } + + @Override + public TimeValue getJobDeadline() { + return TimeValue.timeValueMillis(1000); + } + + @Override + public int estimatedRowCount() { + return 0; + } + + @Override + public String getAction() { + return ""; + } + + @Override + public void close() throws IOException { + + } + } +} diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java new file mode 100644 index 0000000000000..f8bb592662a85 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamReaderTests.java @@ -0,0 +1,86 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; +import org.opensearch.arrow.flight.bootstrap.ServerConfig; +import org.opensearch.common.settings.Settings; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; + +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class FlightStreamReaderTests extends OpenSearchTestCase { + + private FlightStream mockFlightStream; + + private FlightStreamReader iterator; + private VectorSchemaRoot root; + private BufferAllocator allocator; + + @Override + public void setUp() throws Exception { + super.setUp(); + ServerConfig.init(Settings.EMPTY); + mockFlightStream = mock(FlightStream.class); + allocator = new RootAllocator(100000); + Field field = new Field("id", FieldType.nullable(new ArrowType.Int(32, true)), null); + Schema schema = new Schema(List.of(field)); + root = VectorSchemaRoot.create(schema, allocator); + when(mockFlightStream.getRoot()).thenReturn(root); + iterator = new FlightStreamReader(mockFlightStream); + } + + @Override + public void tearDown() throws Exception { + super.tearDown(); + root.close(); + allocator.close(); + } + + public void testNext_ReturnsTrue_WhenFlightStreamHasNext() throws Exception { + when(mockFlightStream.next()).thenReturn(true); + assertTrue(iterator.next()); + verify(mockFlightStream).next(); + } + + public void testNext_ReturnsFalse_WhenFlightStreamHasNoNext() throws Exception { + when(mockFlightStream.next()).thenReturn(false); + 
assertFalse(iterator.next()); + verify(mockFlightStream).next(); + } + + public void testGetRoot_ReturnsRootFromFlightStream() throws Exception { + VectorSchemaRoot returnedRoot = iterator.getRoot(); + assertEquals(root, returnedRoot); + verify(mockFlightStream).getRoot(); + } + + public void testClose_CallsCloseOnFlightStream() throws Exception { + iterator.close(); + verify(mockFlightStream).close(); + } + + public void testClose_WrapsExceptionInRuntimeException() throws Exception { + doThrow(new Exception("Test exception")).when(mockFlightStream).close(); + assertThrows(RuntimeException.class, () -> iterator.close()); + verify(mockFlightStream).close(); + } +} diff --git a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java new file mode 100644 index 0000000000000..819da2826c173 --- /dev/null +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/FlightStreamTicketTests.java @@ -0,0 +1,111 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.arrow.flight.impl; + +import org.opensearch.arrow.spi.StreamTicket; +import org.opensearch.test.OpenSearchTestCase; + +import java.nio.charset.StandardCharsets; +import java.util.Base64; + +public class FlightStreamTicketTests extends OpenSearchTestCase { + + public void testConstructorAndGetters() { + String ticketID = "ticket123"; + String nodeID = "node456"; + StreamTicket ticket = new FlightStreamTicket(ticketID, nodeID); + + assertEquals(ticketID, ticket.getTicketId()); + assertEquals(nodeID, ticket.getNodeId()); + } + + public void testToBytes() { + StreamTicket ticket = new FlightStreamTicket("ticket123", "node456"); + byte[] bytes = ticket.toBytes(); + + assertNotNull(bytes); + assertTrue(bytes.length > 0); + + // Decode the Base64 and check the structure + byte[] decoded = Base64.getDecoder().decode(bytes); + assertEquals(2 + 9 + 2 + 7, decoded.length); // 2 shorts + "ticket123" + "node456" + } + + public void testFromBytes() { + StreamTicket original = new FlightStreamTicket("ticket123", "node456"); + byte[] bytes = original.toBytes(); + + StreamTicket reconstructed = FlightStreamTicket.fromBytes(bytes); + + assertEquals(original.getTicketId(), reconstructed.getTicketId()); + assertEquals(original.getNodeId(), reconstructed.getNodeId()); + } + + public void testToBytesWithLongStrings() { + String longString = randomAlphaOfLength(Short.MAX_VALUE + 1); + StreamTicket ticket = new FlightStreamTicket(longString, "node456"); + + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, ticket::toBytes); + assertEquals("Field lengths exceed the maximum allowed size.", exception.getMessage()); + } + + public void testNullInput() { + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(null)); + assertEquals("Invalid byte array input.", e.getMessage()); + } + + public void testEmptyInput() { + IllegalArgumentException e = 
expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(new byte[0])); + assertEquals("Invalid byte array input.", e.getMessage()); + } + + public void testMalformedBase64() { + byte[] invalidBase64 = "Invalid Base64!@#$".getBytes(StandardCharsets.UTF_8); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(invalidBase64)); + assertEquals("Illegal base64 character 20", e.getMessage()); + } + + public void testModifiedLengthFields() { + StreamTicket original = new FlightStreamTicket("ticket123", "node456"); + byte[] bytes = original.toBytes(); + byte[] decoded = Base64.getDecoder().decode(bytes); + + // Modify the length field to be larger than actual data + decoded[0] = (byte) 0xFF; + decoded[1] = (byte) 0xFF; + + byte[] modified = Base64.getEncoder().encode(decoded); + IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> FlightStreamTicket.fromBytes(modified)); + assertEquals("Invalid ticketID length: -1", e.getMessage()); + } + + public void testEquals() { + StreamTicket ticket1 = new FlightStreamTicket("ticket123", "node456"); + StreamTicket ticket2 = new FlightStreamTicket("ticket123", "node456"); + StreamTicket ticket3 = new FlightStreamTicket("ticket789", "node456"); + + assertEquals(ticket1, ticket2); + assertNotEquals(ticket1, ticket3); + assertNotEquals(null, ticket1); + assertNotEquals("Not a StreamTicket", ticket1); + } + + public void testHashCode() { + StreamTicket ticket1 = new FlightStreamTicket("ticket123", "node456"); + StreamTicket ticket2 = new FlightStreamTicket("ticket123", "node456"); + + assertEquals(ticket1.hashCode(), ticket2.hashCode()); + } + + public void testToString() { + StreamTicket ticket = new FlightStreamTicket("ticket123", "node456"); + String expected = "FlightStreamTicket{ticketID='ticket123', nodeID='node456'}"; + assertEquals(expected, ticket.toString()); + } +} diff --git 
a/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java new file mode 100644 index 0000000000000..55905c435365d --- /dev/null +++ b/plugins/arrow-flight-rpc/src/test/java/org/opensearch/arrow/flight/impl/ProxyStreamProducerTests.java @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.arrow.flight.impl; + +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.opensearch.arrow.spi.StreamProducer; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.After; + +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class ProxyStreamProducerTests extends OpenSearchTestCase { + + private FlightStream mockRemoteStream; + private BufferAllocator mockAllocator; + private ProxyStreamProducer proxyStreamProducer; + + @Override + public void setUp() throws Exception { + super.setUp(); + mockRemoteStream = mock(FlightStream.class); + mockAllocator = mock(BufferAllocator.class); + proxyStreamProducer = new ProxyStreamProducer(new FlightStreamReader(mockRemoteStream)); + } + + public void testCreateRoot() throws Exception { + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + when(mockRemoteStream.getRoot()).thenReturn(mockRoot); + + VectorSchemaRoot result = proxyStreamProducer.createRoot(mockAllocator); + + assertEquals(mockRoot, result); + verify(mockRemoteStream).getRoot(); + } + + public void 
testDefaults() { + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + when(mockRoot.getRowCount()).thenReturn(100); + when(mockRemoteStream.getRoot()).thenReturn(mockRoot); + assertEquals(100, proxyStreamProducer.estimatedRowCount()); + try { + proxyStreamProducer.getAction(); + fail("Expected UnsupportedOperationException"); + } catch (UnsupportedOperationException e) { + assertEquals("Not implemented yet", e.getMessage()); + } + } + + public void testCreateJob() { + StreamProducer.BatchedJob job = proxyStreamProducer.createJob(mockAllocator); + + assertNotNull(job); + assertTrue(job instanceof ProxyStreamProducer.ProxyBatchedJob); + } + + public void testProxyBatchedJob() throws Exception { + StreamProducer.BatchedJob job = proxyStreamProducer.createJob(mockAllocator); + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + StreamProducer.FlushSignal mockFlushSignal = mock(StreamProducer.FlushSignal.class); + + when(mockRemoteStream.next()).thenReturn(true, true, false); + + job.run(mockRoot, mockFlushSignal); + + verify(mockRemoteStream, times(3)).next(); + verify(mockFlushSignal, times(2)).awaitConsumption(TimeValue.timeValueMillis(1000)); + } + + public void testProxyBatchedJobWithException() throws Exception { + StreamProducer.BatchedJob job = proxyStreamProducer.createJob(mockAllocator); + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + StreamProducer.FlushSignal mockFlushSignal = mock(StreamProducer.FlushSignal.class); + + doThrow(new RuntimeException("Test exception")).when(mockRemoteStream).next(); + + try { + job.run(mockRoot, mockFlushSignal); + fail("Expected RuntimeException"); + } catch (RuntimeException e) { + assertEquals("Test exception", e.getMessage()); + } + + verify(mockRemoteStream, times(1)).next(); + } + + public void testProxyBatchedJobOnCancel() throws Exception { + StreamProducer.BatchedJob job = proxyStreamProducer.createJob(mockAllocator); + VectorSchemaRoot mockRoot = mock(VectorSchemaRoot.class); + 
StreamProducer.FlushSignal mockFlushSignal = mock(StreamProducer.FlushSignal.class); + when(mockRemoteStream.next()).thenReturn(true, true, false); + + // cancel the job + job.onCancel(); + job.run(mockRoot, mockFlushSignal); + verify(mockRemoteStream, times(0)).next(); + verify(mockFlushSignal, times(0)).awaitConsumption(TimeValue.timeValueMillis(1000)); + assertTrue(job.isCancelled()); + } + + @After + public void tearDown() throws Exception { + if (proxyStreamProducer != null) { + proxyStreamProducer.close(); + } + super.tearDown(); + } +} diff --git a/server/build.gradle b/server/build.gradle index d3e55c4d8f784..faf49c88c3505 100644 --- a/server/build.gradle +++ b/server/build.gradle @@ -69,7 +69,6 @@ dependencies { api project(":libs:opensearch-geo") api project(":libs:opensearch-telemetry") api project(":libs:opensearch-task-commons") - implementation project(':libs:opensearch-arrow-spi') compileOnly project(":libs:agent-sm:bootstrap") compileOnly project(':libs:opensearch-plugin-classloader') diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java b/server/src/main/java/org/opensearch/arrow/spi/StreamManager.java similarity index 100% rename from libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamManager.java rename to server/src/main/java/org/opensearch/arrow/spi/StreamManager.java diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java b/server/src/main/java/org/opensearch/arrow/spi/StreamProducer.java similarity index 94% rename from libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java rename to server/src/main/java/org/opensearch/arrow/spi/StreamProducer.java index 6ca5b8944319b..955ae9ed8913a 100644 --- a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamProducer.java +++ b/server/src/main/java/org/opensearch/arrow/spi/StreamProducer.java @@ -9,6 +9,7 @@ package org.opensearch.arrow.spi; import org.opensearch.common.annotation.ExperimentalApi; +import 
org.opensearch.common.unit.TimeValue; import org.opensearch.core.tasks.TaskId; import java.io.Closeable; @@ -95,6 +96,14 @@ public interface StreamProducer extends Closeable { */ BatchedJob createJob(Allocator allocator); + /** + * Returns the deadline for the job execution. + * After this deadline, the job should be considered expired. + * + * @return TimeValue representing the job's deadline + */ + TimeValue getJobDeadline(); + /** * Provides an estimate of the total number of rows that will be produced. * @@ -111,6 +120,7 @@ public interface StreamProducer extends Closeable { /** * BatchedJob interface for producing stream data in batches. */ + @ExperimentalApi interface BatchedJob { /** @@ -144,12 +154,13 @@ interface BatchedJob { * Functional interface for managing stream consumption signals. */ @FunctionalInterface + @ExperimentalApi interface FlushSignal { /** * Blocks until the current batch has been consumed or timeout occurs. * * @param timeout Maximum milliseconds to wait */ - void awaitConsumption(int timeout); + void awaitConsumption(TimeValue timeout); } } diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamReader.java b/server/src/main/java/org/opensearch/arrow/spi/StreamReader.java similarity index 100% rename from libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamReader.java rename to server/src/main/java/org/opensearch/arrow/spi/StreamReader.java diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java b/server/src/main/java/org/opensearch/arrow/spi/StreamTicket.java similarity index 100% rename from libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicket.java rename to server/src/main/java/org/opensearch/arrow/spi/StreamTicket.java diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java b/server/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java similarity index 100% rename from 
libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java rename to server/src/main/java/org/opensearch/arrow/spi/StreamTicketFactory.java diff --git a/libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/package-info.java b/server/src/main/java/org/opensearch/arrow/spi/package-info.java similarity index 100% rename from libs/arrow-spi/src/main/java/org/opensearch/arrow/spi/package-info.java rename to server/src/main/java/org/opensearch/arrow/spi/package-info.java diff --git a/server/src/main/java/org/opensearch/common/cache/Cache.java b/server/src/main/java/org/opensearch/common/cache/Cache.java index e01a1223955ed..679c402434c15 100644 --- a/server/src/main/java/org/opensearch/common/cache/Cache.java +++ b/server/src/main/java/org/opensearch/common/cache/Cache.java @@ -566,6 +566,19 @@ private void put(K key, V value, long now) { } }; + private final Consumer<CompletableFuture<Entry<K, V>>> removalConsumer = f -> { + try { + Entry<K, V> entry = f.get(); + try (ReleasableLock ignored = lruLock.acquire()) { + delete(entry, RemovalReason.EXPLICIT); + } + } catch (ExecutionException e) { + // ok + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + }; + /** * Invalidate the association for the specified key. A removal notification will be issued for invalidated * entries with {@link RemovalReason} INVALIDATED. @@ -577,6 +590,17 @@ public void invalidate(K key) { segment.remove(key, invalidationConsumer); } + /** + * Removes the association for the specified key. A removal notification will be issued for removed + * entry with {@link RemovalReason} EXPLICIT. + * + * @param key the key whose mapping is to be removed from the cache + */ + public void remove(K key) { + CacheSegment<K, V> segment = getCacheSegment(key); + segment.remove(key, removalConsumer); + } + /** * Invalidate the entry for the specified key and value. If the value provided is not equal to the value in * the cache, no removal will occur. 
A removal notification will be issued for invalidated diff --git a/server/src/main/java/org/opensearch/node/Node.java b/server/src/main/java/org/opensearch/node/Node.java index 8037f90653d89..d5683703b2df1 100644 --- a/server/src/main/java/org/opensearch/node/Node.java +++ b/server/src/main/java/org/opensearch/node/Node.java @@ -56,6 +56,7 @@ import org.opensearch.action.search.SearchTransportService; import org.opensearch.action.support.TransportAction; import org.opensearch.action.update.UpdateHelper; +import org.opensearch.arrow.spi.StreamManager; import org.opensearch.bootstrap.BootstrapCheck; import org.opensearch.bootstrap.BootstrapContext; import org.opensearch.cluster.ClusterInfoService; @@ -218,6 +219,7 @@ import org.opensearch.plugins.SearchPipelinePlugin; import org.opensearch.plugins.SearchPlugin; import org.opensearch.plugins.SecureSettingsFactory; +import org.opensearch.plugins.StreamManagerPlugin; import org.opensearch.plugins.SystemIndexPlugin; import org.opensearch.plugins.TaskManagerClientPlugin; import org.opensearch.plugins.TelemetryAwarePlugin; @@ -314,6 +316,7 @@ import java.util.stream.Stream; import static java.util.stream.Collectors.toList; +import static org.opensearch.common.util.FeatureFlags.ARROW_STREAMS_SETTING; import static org.opensearch.common.util.FeatureFlags.BACKGROUND_TASK_EXECUTION_EXPERIMENTAL; import static org.opensearch.common.util.FeatureFlags.TELEMETRY; import static org.opensearch.env.NodeEnvironment.collectFileCacheDataPath; @@ -1387,6 +1390,25 @@ protected Node( cacheService ); + if (FeatureFlags.isEnabled(ARROW_STREAMS_SETTING)) { + final List<StreamManagerPlugin> streamManagerPlugins = pluginsService.filterPlugins(StreamManagerPlugin.class); + + final List<StreamManager> streamManagers = streamManagerPlugins.stream() + .map(StreamManagerPlugin::getStreamManager) + .filter(Optional::isPresent) + .map(Optional::get) + .toList(); + + if (streamManagers.size() > 1) { + throw new IllegalStateException( + String.format(Locale.ROOT, "Only one 
StreamManagerPlugin can be installed. Found: %d", streamManagers.size()) + ); + } else if (streamManagers.isEmpty() == false) { + StreamManager streamManager = streamManagers.getFirst(); + streamManagerPlugins.forEach(plugin -> plugin.onStreamManagerInitialized(streamManager)); + } + } + + final SearchService searchService = newSearchService( clusterService, indicesService, diff --git a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java index 60bdb789b3750..929ec96950f08 100644 --- a/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java +++ b/server/src/main/java/org/opensearch/plugins/StreamManagerPlugin.java @@ -10,11 +10,13 @@ import org.opensearch.arrow.spi.StreamManager; -import java.util.function.Supplier; +import java.util.Optional; /** * An interface for OpenSearch plugins to implement to provide a StreamManager. - * Plugins can implement this interface to provide custom StreamManager implementation. + * Plugins can implement this interface to provide custom StreamManager implementation + * or get a reference to the StreamManager instance provided by OpenSearch. + * + * @see StreamManager */ public interface StreamManagerPlugin { @@ -23,5 +25,13 @@ public interface StreamManagerPlugin { * * @return The StreamManager instance */ - Supplier getStreamManager(); + default Optional<StreamManager> getStreamManager() { + return Optional.empty(); + } + + /** + * Called when the StreamManager is initialized. + * @param streamManager The StreamManager instance + */ + default void onStreamManagerInitialized(StreamManager streamManager) {} }