diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index 786364bd98b19..f058f458989d5 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -40,6 +40,7 @@ import software.amazon.awssdk.services.s3.model.AbortMultipartUploadRequest; import software.amazon.awssdk.services.s3.model.CommonPrefix; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; +import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; @@ -53,6 +54,7 @@ import software.amazon.awssdk.services.s3.model.NoSuchKeyException; import software.amazon.awssdk.services.s3.model.ObjectAttributes; import software.amazon.awssdk.services.s3.model.PutObjectRequest; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.UploadPartRequest; import software.amazon.awssdk.services.s3.model.UploadPartResponse; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; @@ -62,6 +64,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; +import org.opensearch.OpenSearchException; import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.Nullable; import org.opensearch.common.SetOnce; @@ -72,6 +75,8 @@ import org.opensearch.common.blobstore.BlobMetadata; import org.opensearch.common.blobstore.BlobPath; import org.opensearch.common.blobstore.BlobStoreException; +import 
org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.DeleteResult; import org.opensearch.common.blobstore.InputStreamWithMetadata; import org.opensearch.common.blobstore.stream.read.ReadContext; @@ -96,6 +101,7 @@ import java.io.InputStream; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; @@ -117,6 +123,7 @@ class S3BlobContainer extends AbstractBlobContainer implements AsyncMultiStreamB private final S3BlobStore blobStore; private final String keyPath; + public static final int HTTP_STATUS_PRECONDITION_FAILED = 412; S3BlobContainer(BlobPath path, S3BlobStore blobStore) { super(path); @@ -209,8 +216,25 @@ public void writeBlobWithMetadata( }); } + @Override + public void asyncBlobUploadConditionally( + WriteContext writeContext, + ConditionalWriteOptions options, + ActionListener completionListener + ) throws IOException { + executeAsyncUpload(writeContext, options, completionListener); + } + @Override public void asyncBlobUpload(WriteContext writeContext, ActionListener completionListener) throws IOException { + executeAsyncUpload(writeContext, null, new ConditionalResponseToVoidListener(completionListener)); + } + + private void executeAsyncUpload( + WriteContext writeContext, + @Nullable ConditionalWriteOptions options, + ActionListener completionListener + ) throws IOException { UploadRequest uploadRequest = new UploadRequest( blobStore.bucket(), buildKey(writeContext.getFileName()), @@ -221,12 +245,14 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp writeContext.getExpectedChecksum(), blobStore.isUploadRetryEnabled(), writeContext.getMetadata(), + options, blobStore.serverSideEncryptionType(), blobStore.serverSideEncryptionKmsKey(), 
blobStore.serverSideEncryptionBucketKey(), blobStore.serverSideEncryptionEncryptionContext(), blobStore.expectedBucketOwner() ); + try { // If file size is greater than the queue capacity than SizeBasedBlockingQ will always reject the upload. // Therefore, redirecting it to slow client. @@ -236,19 +262,33 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp && uploadRequest.getWritePriority() != WritePriority.URGENT && blobStore.getNormalPrioritySizeBasedBlockingQ() .isMaxCapacityBelowContentLength(uploadRequest.getContentLength()) == false)) { + StreamContext streamContext = SocketAccess.doPrivileged( () -> writeContext.getStreamProvider(uploadRequest.getContentLength()) ); InputStreamContainer inputStream = streamContext.provideStream(0); + try { - executeMultipartUpload( - blobStore, - uploadRequest.getKey(), - inputStream.getInputStream(), - uploadRequest.getContentLength(), - uploadRequest.getMetadata() - ); - completionListener.onResponse(null); + if (options != null) { + executeMultipartUploadConditionally( + blobStore, + uploadRequest.getKey(), + inputStream.getInputStream(), + uploadRequest.getContentLength(), + uploadRequest.getMetadata(), + options, + completionListener + ); + } else { + executeMultipartUpload( + blobStore, + uploadRequest.getKey(), + inputStream.getInputStream(), + uploadRequest.getContentLength(), + uploadRequest.getMetadata() + ); + completionListener.onResponse(ConditionalWriteResponse.success(null)); + } } catch (Exception ex) { logger.error( () -> new ParameterizedMessage( @@ -279,23 +319,29 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp if (writeContext.getWritePriority() == WritePriority.URGENT || writeContext.getWritePriority() == WritePriority.HIGH || blobStore.isPermitBackedTransferEnabled() == false) { - createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener); + + if (options != null) { + 
createFileCompletableFutureConditionally(s3AsyncClient, uploadRequest, streamContext, completionListener); + } else { + createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener); + } + } else if (writeContext.getWritePriority() == WritePriority.LOW) { - blobStore.getLowPrioritySizeBasedBlockingQ() - .produce( - new SizeBasedBlockingQ.Item( - writeContext.getFileSize(), - () -> createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener) - ) - ); + blobStore.getLowPrioritySizeBasedBlockingQ().produce(new SizeBasedBlockingQ.Item(writeContext.getFileSize(), () -> { + if (options != null) { + createFileCompletableFutureConditionally(s3AsyncClient, uploadRequest, streamContext, completionListener); + } else { + createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener); + } + })); } else if (writeContext.getWritePriority() == WritePriority.NORMAL) { - blobStore.getNormalPrioritySizeBasedBlockingQ() - .produce( - new SizeBasedBlockingQ.Item( - writeContext.getFileSize(), - () -> createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener) - ) - ); + blobStore.getNormalPrioritySizeBasedBlockingQ().produce(new SizeBasedBlockingQ.Item(writeContext.getFileSize(), () -> { + if (options != null) { + createFileCompletableFutureConditionally(s3AsyncClient, uploadRequest, streamContext, completionListener); + } else { + createFileCompletableFuture(s3AsyncClient, uploadRequest, streamContext, completionListener); + } + })); } else { throw new IllegalStateException("Cannot perform upload for other priority types."); } @@ -306,14 +352,38 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp } } - private CompletableFuture createFileCompletableFuture( + private CompletableFuture createFileCompletableFuture( S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext, - ActionListener completionListener + 
ActionListener completionListener ) { - CompletableFuture completableFuture = blobStore.getAsyncTransferManager() + CompletableFuture standardFuture = blobStore.getAsyncTransferManager() .uploadObject(s3AsyncClient, uploadRequest, streamContext, blobStore.getStatsMetricPublisher()); + + CompletableFuture convertedFuture = standardFuture.thenApply( + result -> ConditionalWriteResponse.success(null) + ); + + return convertedFuture.whenComplete((response, throwable) -> { + if (throwable == null) { + completionListener.onResponse(response); + } else { + Exception ex = throwable instanceof Error ? new Exception(throwable) : (Exception) throwable; + completionListener.onFailure(ex); + } + }); + } + + private CompletableFuture createFileCompletableFutureConditionally( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + StreamContext streamContext, + ActionListener completionListener + ) { + CompletableFuture completableFuture = blobStore.getAsyncTransferManager() + .uploadObjectConditionally(s3AsyncClient, uploadRequest, streamContext, blobStore.getStatsMetricPublisher()); + return completableFuture.whenComplete((response, throwable) -> { if (throwable == null) { completionListener.onResponse(response); @@ -324,6 +394,27 @@ private CompletableFuture createFileCompletableFuture( }); } + /** + * Helper class to convert ConditionalWriteResponse to Void + */ + private static class ConditionalResponseToVoidListener implements ActionListener { + private final ActionListener delegate; + + ConditionalResponseToVoidListener(ActionListener delegate) { + this.delegate = delegate; + } + + @Override + public void onResponse(ConditionalWriteResponse response) { + delegate.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + delegate.onFailure(e); + } + } + @ExperimentalApi @Override public void readBlobAsync(String blobName, ActionListener listener) { @@ -521,6 +612,177 @@ private String buildKey(String blobName) { return keyPath + blobName; } + 
public void executeMultipartUploadConditionally( + final S3BlobStore blobStore, + final String blobName, + final InputStream input, + final long blobSize, + final Map metadata, + final ConditionalWriteOptions options, + final ActionListener listener + ) throws IOException { + + ensureMultiPartUploadSize(blobSize); + + final long partSize = blobStore.bufferSizeInBytes(); + final Tuple multiparts = numberOfMultiparts(blobSize, partSize); + if (multiparts.v1() > Integer.MAX_VALUE) { + throw new IllegalArgumentException("Too many multipart upload parts; consider a larger buffer size."); + } + final int nbParts = multiparts.v1().intValue(); + final long lastPartSize = multiparts.v2(); + assert blobSize == (((nbParts - 1) * partSize) + lastPartSize) : "blobSize does not match multipart sizes"; + + CreateMultipartUploadRequest.Builder createRequestBuilder = CreateMultipartUploadRequest.builder() + .bucket(blobStore.bucket()) + .key(blobName) + .storageClass(blobStore.getStorageClass()) + .acl(blobStore.getCannedACL()) + .overrideConfiguration(o -> o.addMetricPublisher(blobStore.getStatsMetricPublisher().multipartUploadMetricCollector)) + .expectedBucketOwner(blobStore.expectedBucketOwner()); + + if (CollectionUtils.isNotEmpty(metadata)) { + createRequestBuilder.metadata(metadata); + } + configureEncryptionSettings(createRequestBuilder, blobStore); + + final CreateMultipartUploadRequest createMultipartUploadRequest = createRequestBuilder.build(); + final SetOnce uploadId = new SetOnce<>(); + final String bucketName = blobStore.bucket(); + boolean success = false; + + final InputStream requestInputStream = blobStore.isUploadRetryEnabled() + ? 
new BufferedInputStream(input, (int) (partSize + 1)) + : input; + + try (AmazonS3Reference clientReference = blobStore.clientReference()) { + uploadId.set( + SocketAccess.doPrivileged(() -> clientReference.get().createMultipartUpload(createMultipartUploadRequest).uploadId()) + ); + if (Strings.isEmpty(uploadId.get())) { + IOException exception = new IOException("Failed to initialize multipart upload for " + blobName); + listener.onFailure(exception); + throw exception; + } + + final List parts = new ArrayList<>(nbParts); + long bytesCount = 0; + + for (int i = 1; i <= nbParts; i++) { + long currentPartSize = (i < nbParts) ? partSize : lastPartSize; + final UploadPartRequest uploadPartRequest = UploadPartRequest.builder() + .bucket(bucketName) + .key(blobName) + .uploadId(uploadId.get()) + .partNumber(i) + .contentLength(currentPartSize) + .overrideConfiguration(o -> o.addMetricPublisher(blobStore.getStatsMetricPublisher().multipartUploadMetricCollector)) + .expectedBucketOwner(blobStore.expectedBucketOwner()) + .build(); + + bytesCount += currentPartSize; + + final UploadPartResponse uploadResponse = SocketAccess.doPrivileged( + () -> clientReference.get() + .uploadPart(uploadPartRequest, RequestBody.fromInputStream(requestInputStream, currentPartSize)) + ); + + String partETag = uploadResponse.eTag(); + if (partETag == null) { + IOException exception = new IOException( + String.format(Locale.ROOT, "S3 part upload for [%s] part [%d] returned null ETag", blobName, i) + ); + listener.onFailure(exception); + throw exception; + } + + parts.add(CompletedPart.builder().partNumber(i).eTag(partETag).build()); + } + + if (bytesCount != blobSize) { + IOException exception = new IOException( + String.format(Locale.ROOT, "Multipart upload for [%s] sent %d bytes; expected %d bytes", blobName, bytesCount, blobSize) + ); + listener.onFailure(exception); + throw exception; + } + + CompleteMultipartUploadRequest.Builder completeRequestBuilder = 
CompleteMultipartUploadRequest.builder() + .bucket(bucketName) + .key(blobName) + .uploadId(uploadId.get()) + .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build()) + .overrideConfiguration(o -> o.addMetricPublisher(blobStore.getStatsMetricPublisher().multipartUploadMetricCollector)) + .expectedBucketOwner(blobStore.expectedBucketOwner()); + + if (options.isIfMatch()) { + completeRequestBuilder.ifMatch(options.getVersionIdentifier()); + } else if (options.isIfNotExists()) { + completeRequestBuilder.ifNoneMatch("*"); + } + + CompleteMultipartUploadRequest completeRequest = completeRequestBuilder.build(); + + CompleteMultipartUploadResponse completeResponse = SocketAccess.doPrivileged( + () -> clientReference.get().completeMultipartUpload(completeRequest) + ); + + if (completeResponse.eTag() != null) { + success = true; + listener.onResponse(ConditionalWriteResponse.success(completeResponse.eTag())); + } else { + IOException exception = new IOException( + "S3 multipart upload for [" + blobName + "] returned null ETag, violating data integrity expectations" + ); + listener.onFailure(exception); + throw exception; + } + + } catch (S3Exception e) { + if (e.statusCode() == HTTP_STATUS_PRECONDITION_FAILED) { + listener.onFailure(new OpenSearchException("Precondition Failed : Etag Mismatch", e, blobName)); + throw new IOException("Unable to upload object [" + blobName + "] due to ETag mismatch", e); + } else { + IOException exception = new IOException( + String.format(Locale.ROOT, "S3 error during multipart upload [%s]: %s", blobName, e.getMessage()), + e + ); + listener.onFailure(exception); + throw exception; + } + } catch (SdkException e) { + IOException exception = new IOException(String.format(Locale.ROOT, "S3 multipart upload failed for [%s]", blobName), e); + listener.onFailure(exception); + throw exception; + } catch (Exception e) { + IOException exception = new IOException( + String.format(Locale.ROOT, "Unexpected error during multipart upload 
[%s]: %s", blobName, e.getMessage()), + e + ); + listener.onFailure(exception); + throw exception; + } finally { + if (!success && Strings.hasLength(uploadId.get())) { + AbortMultipartUploadRequest abortRequest = AbortMultipartUploadRequest.builder() + .bucket(bucketName) + .key(blobName) + .uploadId(uploadId.get()) + .expectedBucketOwner(blobStore.expectedBucketOwner()) + .build(); + try (AmazonS3Reference abortClient = blobStore.clientReference()) { + SocketAccess.doPrivilegedVoid(() -> abortClient.get().abortMultipartUpload(abortRequest)); + } catch (Exception abortException) { + logger.warn( + "Failed to abort incomplete multipart upload [{}] with ID [{}]. " + + "This may result in orphaned S3 data and charges.", + new Object[] { blobName, uploadId.get() }, + abortException + ); + } + } + } + } + /** * Uploads a blob using a single upload request */ diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java index 1c5c12fe799cb..e7b8f57b474dd 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/AsyncTransferManager.java @@ -14,7 +14,6 @@ import software.amazon.awssdk.services.s3.S3AsyncClient; import software.amazon.awssdk.services.s3.model.ChecksumAlgorithm; import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadRequest; -import software.amazon.awssdk.services.s3.model.CompleteMultipartUploadResponse; import software.amazon.awssdk.services.s3.model.CompletedMultipartUpload; import software.amazon.awssdk.services.s3.model.CompletedPart; import software.amazon.awssdk.services.s3.model.CreateMultipartUploadRequest; @@ -31,6 +30,8 @@ import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.ExceptionsHelper; import 
org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.exception.CorruptFileException; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.io.InputStreamContainer; @@ -51,7 +52,6 @@ import java.util.concurrent.ExecutorService; import java.util.concurrent.Semaphore; import java.util.concurrent.atomic.AtomicReferenceArray; -import java.util.function.BiFunction; import java.util.function.Supplier; import java.util.stream.IntStream; @@ -70,6 +70,9 @@ public final class AsyncTransferManager { private final long minimumPartSize; private final long maxRetryablePartSize; + private static final int HTTP_STATUS_PRECONDITION_FAILED = 412; + private static final int HTTP_STATUS_CONFLICT = 409; + @SuppressWarnings("rawtypes") private final TransferSemaphoresHolder transferSemaphoresHolder; @@ -114,11 +117,48 @@ public CompletableFuture uploadObject( StreamContext streamContext, StatsMetricPublisher statsMetricPublisher ) { + return processUploadRequest(s3AsyncClient, uploadRequest, streamContext, statsMetricPublisher, false).thenApply(response -> null); + } - CompletableFuture returnFuture = new CompletableFuture<>(); + /** + * Upload an object to S3 conditionally using the async client + * + * @param s3AsyncClient S3 client to use for upload + * @param uploadRequest The {@link UploadRequest} object encapsulating all relevant details for upload + * @param streamContext The {@link StreamContext} to supply streams during upload + * @param statsMetricPublisher Metric publisher for collecting stats + * @return A {@link CompletableFuture} that will complete with the ConditionalWriteResponse or an exception + */ + public CompletableFuture uploadObjectConditionally( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + StreamContext streamContext, + 
StatsMetricPublisher statsMetricPublisher + ) { + ConditionalWriteOptions options = uploadRequest.getConditionalOptions(); + if (options == null) { + throw new IllegalArgumentException("Cannot perform conditional upload with null options"); + } + + return processUploadRequest(s3AsyncClient, uploadRequest, streamContext, statsMetricPublisher, true); + } + + /** + internal upload execution for both regular and conditional uploads + */ + private CompletableFuture processUploadRequest( + S3AsyncClient s3AsyncClient, + UploadRequest uploadRequest, + StreamContext streamContext, + StatsMetricPublisher statsMetricPublisher, + boolean isConditional + ) { + CompletableFuture returnFuture = new CompletableFuture<>(); try { if (streamContext.getNumberOfParts() == 1) { - log.debug(() -> "Starting the upload as a single upload part request"); + log.debug( + () -> "Starting " + (isConditional ? "conditional " : "") + "single part upload for key: " + uploadRequest.getKey() + ); TransferSemaphoresHolder.RequestContext requestContext = transferSemaphoresHolder.createRequestContext(); Semaphore semaphore = AsyncPartsHandler.maybeAcquireSemaphore( transferSemaphoresHolder, @@ -127,7 +167,15 @@ public CompletableFuture uploadObject( uploadRequest.getKey() ); try { - uploadInOneChunk(s3AsyncClient, uploadRequest, streamContext, returnFuture, statsMetricPublisher, semaphore); + uploadInOneChunk( + s3AsyncClient, + uploadRequest, + streamContext, + returnFuture, + statsMetricPublisher, + semaphore, + isConditional + ); } catch (Exception ex) { if (semaphore != null) { semaphore.release(); @@ -135,8 +183,10 @@ public CompletableFuture uploadObject( throw ex; } } else { - log.debug(() -> "Starting the upload as multipart upload request"); - uploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, statsMetricPublisher); + log.debug( + () -> "Starting " + (isConditional ? 
"conditional " : "") + "multipart upload for key: " + uploadRequest.getKey() + ); + uploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, statsMetricPublisher, isConditional); } } catch (Throwable throwable) { returnFuture.completeExceptionally(throwable); @@ -145,12 +195,16 @@ public CompletableFuture uploadObject( return returnFuture; } + /** + multipart upload initiation for both regular and conditional async uploads + */ private void uploadInParts( S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext, - CompletableFuture returnFuture, - StatsMetricPublisher statsMetricPublisher + CompletableFuture returnFuture, + StatsMetricPublisher statsMetricPublisher, + boolean isConditional ) { CreateMultipartUploadRequest.Builder createMultipartUploadRequestBuilder = CreateMultipartUploadRequest.builder() @@ -183,22 +237,25 @@ private void uploadInParts( uploadId = createMultipartUploadResponse.uploadId(); log.debug(() -> "Initiated new multipart upload, uploadId: " + createMultipartUploadResponse.uploadId()); } catch (Exception ex) { - handleException(returnFuture, () -> "Failed to initiate multipart upload", ex); + handleUploadException(returnFuture, ex, uploadRequest.getKey(), isConditional); return; } - doUploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, uploadId, statsMetricPublisher); + doUploadInParts(s3AsyncClient, uploadRequest, streamContext, returnFuture, uploadId, statsMetricPublisher, isConditional); } + /** + multipart upload execution for both regular and conditional uploads + */ private void doUploadInParts( S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext, - CompletableFuture returnFuture, + CompletableFuture returnFuture, String uploadId, - StatsMetricPublisher statsMetricPublisher + StatsMetricPublisher statsMetricPublisher, + boolean isConditional ) { - // The list of completed parts must be sorted AtomicReferenceArray completedParts = new 
AtomicReferenceArray<>(streamContext.getNumberOfParts()); AtomicReferenceArray inputStreamContainers = new AtomicReferenceArray<>(streamContext.getNumberOfParts()); @@ -241,123 +298,97 @@ private void doUploadInParts( mergeAndVerifyChecksum(inputStreamContainers, uploadRequest.getKey(), uploadRequest.getExpectedChecksum()); } return null; - }) - .thenCompose(ignore -> completeMultipartUpload(s3AsyncClient, uploadRequest, uploadId, completedParts, statsMetricPublisher)) - .handle(handleExceptionOrResponse(s3AsyncClient, uploadRequest, returnFuture, uploadId)) - .exceptionally(throwable -> { - handleException(returnFuture, () -> "Unexpected exception occurred", throwable); - return null; - }); - } - - private void mergeAndVerifyChecksum( - AtomicReferenceArray inputStreamContainers, - String fileName, - long expectedChecksum - ) { - long resultantChecksum = fromBase64String(inputStreamContainers.get(0).getChecksum()); - for (int index = 1; index < inputStreamContainers.length(); index++) { - long curChecksum = fromBase64String(inputStreamContainers.get(index).getChecksum()); - resultantChecksum = JZlib.crc32_combine(resultantChecksum, curChecksum, inputStreamContainers.get(index).getContentLength()); - } - - if (resultantChecksum != expectedChecksum) { - throw new RuntimeException(new CorruptFileException("File level checksums didn't match combined part checksums", fileName)); - } - } - - private BiFunction handleExceptionOrResponse( - S3AsyncClient s3AsyncClient, - UploadRequest uploadRequest, - CompletableFuture returnFuture, - String uploadId - ) { - - return (response, throwable) -> { + }).thenCompose(ignore -> { + log.debug(() -> "Completing " + (isConditional ? 
"conditional " : "") + "multipart upload, uploadId: " + uploadId); + return completeMultipartUpload(s3AsyncClient, uploadRequest, uploadId, completedParts, statsMetricPublisher, isConditional); + }).handle((response, throwable) -> { if (throwable != null) { AsyncPartsHandler.cleanUpParts(s3AsyncClient, uploadRequest, uploadId); - handleException(returnFuture, () -> "Failed to send multipart upload requests.", throwable); + + if (isConditional) { + Throwable unwrappedThrowable = ExceptionsHelper.unwrap(throwable, S3Exception.class); + if (unwrappedThrowable != null) { + S3Exception s3Exception = (S3Exception) unwrappedThrowable; + if (s3Exception.statusCode() == HTTP_STATUS_PRECONDITION_FAILED) { + returnFuture.completeExceptionally( + S3Exception.builder() + .message("Conditional write failed: condition not met for " + uploadRequest.getKey()) + .statusCode(HTTP_STATUS_PRECONDITION_FAILED) + .cause(s3Exception) + .build() + ); + return null; + } else if (s3Exception.statusCode() == HTTP_STATUS_CONFLICT) { + returnFuture.completeExceptionally( + S3Exception.builder() + .message("Blob already exists: " + uploadRequest.getKey()) + .statusCode(HTTP_STATUS_CONFLICT) + .cause(s3Exception) + .build() + ); + return null; + } + } + } + handleUploadException(returnFuture, throwable, uploadRequest.getKey(), isConditional); + return null; } else { - returnFuture.complete(null); + returnFuture.complete(response); + return null; } + }).exceptionally(throwable -> { + handleUploadException(returnFuture, throwable, uploadRequest.getKey(), isConditional); return null; - }; + }); } - private CompletableFuture completeMultipartUpload( + /** + complete multipart upload for both regular and conditional uploads + */ + private CompletableFuture completeMultipartUpload( S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, String uploadId, AtomicReferenceArray completedParts, - StatsMetricPublisher statsMetricPublisher + StatsMetricPublisher statsMetricPublisher, + boolean isConditional 
) { log.debug(() -> new ParameterizedMessage("Sending completeMultipartUploadRequest, uploadId: {}", uploadId)); CompletedPart[] parts = IntStream.range(0, completedParts.length()).mapToObj(completedParts::get).toArray(CompletedPart[]::new); - CompleteMultipartUploadRequest completeMultipartUploadRequest = CompleteMultipartUploadRequest.builder() + + CompleteMultipartUploadRequest.Builder completeRequestBuilder = CompleteMultipartUploadRequest.builder() .bucket(uploadRequest.getBucket()) .key(uploadRequest.getKey()) .uploadId(uploadId) .overrideConfiguration(o -> o.addMetricPublisher(statsMetricPublisher.multipartUploadMetricCollector)) .multipartUpload(CompletedMultipartUpload.builder().parts(parts).build()) - .expectedBucketOwner(uploadRequest.getExpectedBucketOwner()) - .build(); - - return SocketAccess.doPrivileged(() -> s3AsyncClient.completeMultipartUpload(completeMultipartUploadRequest)); - } - - private static String base64StringFromLong(Long val) { - return Base64.getEncoder().encodeToString(Arrays.copyOfRange(ByteUtils.toByteArrayBE(val), 4, 8)); - } + .expectedBucketOwner(uploadRequest.getExpectedBucketOwner()); - private static long fromBase64String(String base64String) { - byte[] decodedBytes = Base64.getDecoder().decode(base64String); - if (decodedBytes.length != 4) { - throw new IllegalArgumentException("Invalid Base64 encoded CRC32 checksum"); - } - long result = 0; - for (int i = 0; i < 4; i++) { - result <<= 8; - result |= (decodedBytes[i] & 0xFF); + if (isConditional && uploadRequest.getConditionalOptions() != null) { + applyConditionalHeaders(completeRequestBuilder, uploadRequest.getConditionalOptions()); } - return result; - } - private static void handleException(CompletableFuture returnFuture, Supplier message, Throwable throwable) { - Throwable cause = throwable instanceof CompletionException ? 
throwable.getCause() : throwable; + CompleteMultipartUploadRequest completeRequest = completeRequestBuilder.build(); - if (cause instanceof Error) { - returnFuture.completeExceptionally(cause); - } else { - SdkClientException exception = SdkClientException.create(message.get(), cause); - returnFuture.completeExceptionally(exception); - } + return SocketAccess.doPrivileged(() -> s3AsyncClient.completeMultipartUpload(completeRequest)) + .thenApply(response -> ConditionalWriteResponse.success(response.eTag())); } /** - * Calculates the optimal part size of each part request if the upload operation is carried out as multipart upload. - */ - public long calculateOptimalPartSize(long contentLengthOfSource, WritePriority writePriority, boolean uploadRetryEnabled) { - if (contentLengthOfSource < ByteSizeUnit.MB.toBytes(5)) { - return contentLengthOfSource; - } - if (uploadRetryEnabled && (writePriority == WritePriority.HIGH || writePriority == WritePriority.URGENT)) { - return new ByteSizeValue(5, ByteSizeUnit.MB).getBytes(); - } - double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; - optimalPartSize = Math.ceil(optimalPartSize); - return (long) Math.max(optimalPartSize, minimumPartSize); - } + single chunk upload for both regular and conditional uploads + */ @SuppressWarnings("unchecked") private void uploadInOneChunk( S3AsyncClient s3AsyncClient, UploadRequest uploadRequest, StreamContext streamContext, - CompletableFuture returnFuture, + CompletableFuture returnFuture, StatsMetricPublisher statsMetricPublisher, - Semaphore semaphore + Semaphore semaphore, + boolean isConditional ) { PutObjectRequest.Builder putObjectRequestBuilder = PutObjectRequest.builder() .bucket(uploadRequest.getBucket()) @@ -376,6 +407,10 @@ private void uploadInOneChunk( configureEncryptionSettings(putObjectRequestBuilder, uploadRequest); + if (isConditional && uploadRequest.getConditionalOptions() != null) { + applyConditionalHeaders(putObjectRequestBuilder, 
uploadRequest.getConditionalOptions()); + } + PutObjectRequest putObjectRequest = putObjectRequestBuilder.build(); ExecutorService streamReadExecutor; if (uploadRequest.getWritePriority() == WritePriority.URGENT) { @@ -386,7 +421,7 @@ private void uploadInOneChunk( streamReadExecutor = executorService; } - CompletableFuture putObjectFuture = SocketAccess.doPrivileged(() -> { + CompletableFuture putObjectFuture = SocketAccess.doPrivileged(() -> { InputStream inputStream = null; CompletableFuture putObjectRespFuture; try { @@ -410,6 +445,7 @@ private void uploadInOneChunk( } InputStream finalInputStream = inputStream; + return putObjectRespFuture.handle((resp, throwable) -> { releaseResourcesSafely(semaphore, finalInputStream, uploadRequest.getKey()); @@ -429,7 +465,7 @@ private void uploadInOneChunk( } catch (IOException e) { throw new RuntimeException(e); } - returnFuture.complete(null); + returnFuture.complete(ConditionalWriteResponse.success(resp.eTag())); } return null; @@ -447,6 +483,155 @@ private void uploadInOneChunk( CompletableFutureUtils.forwardResultTo(putObjectFuture, returnFuture); } + /** + exception handling for both regular and conditional uploads + */ + private void handleUploadException( + CompletableFuture returnFuture, + Throwable throwable, + String resourceName, + boolean isConditional + ) { + if (isConditional) { + handleConditionalException(returnFuture, throwable, resourceName); + } else { + CompletableFuture voidFuture = new CompletableFuture<>(); + handleException(voidFuture, () -> "Upload failed for " + resourceName, throwable); + voidFuture.whenComplete((result, ex) -> { + if (ex != null) { + returnFuture.completeExceptionally(ex); + } else { + returnFuture.complete(ConditionalWriteResponse.success("")); + } + }); + } + } + + /** + + Apply conditional headers to a request builder + */ + private void applyConditionalHeaders(Object builder, ConditionalWriteOptions options) { + if (options == null) { + return; + } + + if (builder 
instanceof PutObjectRequest.Builder putBuilder) { + if (options.isIfNotExists()) { + putBuilder.ifNoneMatch("*"); + } else if (options.isIfMatch()) { + putBuilder.ifMatch(options.getVersionIdentifier()); + } + } else if (builder instanceof CompleteMultipartUploadRequest.Builder completeBuilder) { + if (options.isIfNotExists()) { + completeBuilder.ifNoneMatch("*"); + } else if (options.isIfMatch()) { + completeBuilder.ifMatch(options.getVersionIdentifier()); + } + } + } + + /** + + Error handler for conditional uploads + */ + private void handleConditionalException(CompletableFuture returnFuture, Throwable throwable, String resourceName) { + Throwable cause = throwable; + while (cause instanceof CompletionException && cause.getCause() != null) { + cause = cause.getCause(); + } + + if (cause instanceof S3Exception s3e) { + if (s3e.statusCode() == HTTP_STATUS_PRECONDITION_FAILED) { + returnFuture.completeExceptionally( + S3Exception.builder() + .message("Conditional write failed: condition not met for " + resourceName) + .statusCode(HTTP_STATUS_PRECONDITION_FAILED) + .cause(s3e) + .build() + ); + return; + } else if (s3e.statusCode() == HTTP_STATUS_CONFLICT) { + returnFuture.completeExceptionally( + S3Exception.builder() + .message("Blob already exists: " + resourceName) + .statusCode(HTTP_STATUS_CONFLICT) + .cause(s3e) + .build() + ); + return; + } + } + if (cause instanceof Error) { + returnFuture.completeExceptionally(cause); + } else { + SdkClientException exception = SdkClientException.create("Failed conditional upload of " + resourceName, cause); + returnFuture.completeExceptionally(exception); + } + } + + private void mergeAndVerifyChecksum( + AtomicReferenceArray inputStreamContainers, + String fileName, + long expectedChecksum + ) { + long resultantChecksum = fromBase64String(inputStreamContainers.get(0).getChecksum()); + for (int index = 1; index < inputStreamContainers.length(); index++) { + long curChecksum = 
fromBase64String(inputStreamContainers.get(index).getChecksum()); + resultantChecksum = JZlib.crc32_combine(resultantChecksum, curChecksum, inputStreamContainers.get(index).getContentLength()); + } + + if (resultantChecksum != expectedChecksum) { + throw new RuntimeException(new CorruptFileException("File level checksums didn't match combined part checksums", fileName)); + } + + } + + private static String base64StringFromLong(Long val) { + return Base64.getEncoder().encodeToString(Arrays.copyOfRange(ByteUtils.toByteArrayBE(val), 4, 8)); + } + + private static long fromBase64String(String base64String) { + byte[] decodedBytes = Base64.getDecoder().decode(base64String); + if (decodedBytes.length != 4) { + throw new IllegalArgumentException("Invalid Base64 encoded CRC32 checksum"); + } + long result = 0; + for (int i = 0; i < 4; i++) { + result <<= 8; + result |= (decodedBytes[i] & 0xFF); + } + return result; + } + + private static void handleException(CompletableFuture returnFuture, Supplier message, Throwable throwable) { + Throwable cause = throwable instanceof CompletionException ? throwable.getCause() : throwable; + + if (cause instanceof Error) { + returnFuture.completeExceptionally(cause); + } else { + SdkClientException exception = SdkClientException.create(message.get(), cause); + returnFuture.completeExceptionally(exception); + } + + } + + /** + + Calculates the optimal part size of each part request if the upload operation is carried out as multipart upload. 
+ */ + public long calculateOptimalPartSize(long contentLengthOfSource, WritePriority writePriority, boolean uploadRetryEnabled) { + if (contentLengthOfSource < ByteSizeUnit.MB.toBytes(5)) { + return contentLengthOfSource; + } + if (uploadRetryEnabled && (writePriority == WritePriority.HIGH || writePriority == WritePriority.URGENT)) { + return new ByteSizeValue(5, ByteSizeUnit.MB).getBytes(); + } + double optimalPartSize = contentLengthOfSource / (double) MAX_UPLOAD_PARTS; + optimalPartSize = Math.ceil(optimalPartSize); + return (long) Math.max(optimalPartSize, minimumPartSize); + } + private void releaseResourcesSafely(Semaphore semaphore, InputStream inputStream, String file) { if (semaphore != null) { semaphore.release(); diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java index 40ce391fe562d..f3b1e9271d781 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/async/UploadRequest.java @@ -10,6 +10,7 @@ import org.opensearch.common.CheckedConsumer; import org.opensearch.common.Nullable; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; import org.opensearch.common.blobstore.stream.write.WritePriority; import java.io.IOException; @@ -28,6 +29,7 @@ public class UploadRequest { private final Long expectedChecksum; private final Map metadata; private final boolean uploadRetryEnabled; + private final ConditionalWriteOptions conditionalOptions; private volatile String serverSideEncryptionType; private volatile String serverSideEncryptionKmsKey; private volatile boolean serverSideEncryptionBucketKey; @@ -45,6 +47,12 @@ public class UploadRequest { * @param doRemoteDataIntegrityCheck A boolean to inform vendor plugins whether remote data integrity checks need to be done * 
@param expectedChecksum Checksum of the file being uploaded for remote data integrity check * @param metadata Metadata of the file being uploaded + * @param conditionalOptions Conditions that must be satisfied for the write to succeed + * @param serverSideEncryptionType Type of server-side encryption + * @param serverSideEncryptionKmsKey KMS key for server-side encryption + * @param serverSideEncryptionBucketKey Whether to use bucket keys for server-side encryption + * @param serverSideEncryptionEncryptionContext Encryption context for server-side encryption + * @param expectedBucketOwner Expected owner of the bucket */ public UploadRequest( String bucket, @@ -56,6 +64,7 @@ public UploadRequest( Long expectedChecksum, boolean uploadRetryEnabled, @Nullable Map metadata, + @Nullable ConditionalWriteOptions conditionalOptions, String serverSideEncryptionType, String serverSideEncryptionKmsKey, boolean serverSideEncryptionBucketKey, @@ -71,6 +80,7 @@ public UploadRequest( this.expectedChecksum = expectedChecksum; this.uploadRetryEnabled = uploadRetryEnabled; this.metadata = metadata; + this.conditionalOptions = conditionalOptions; this.serverSideEncryptionType = serverSideEncryptionType; this.serverSideEncryptionKmsKey = serverSideEncryptionKmsKey; this.serverSideEncryptionBucketKey = serverSideEncryptionBucketKey; @@ -117,6 +127,14 @@ public Map getMetadata() { return metadata; } + /** + * @return conditional write options for this upload, or null if none are specified + */ + @Nullable + public ConditionalWriteOptions getConditionalOptions() { + return conditionalOptions; + } + public String getServerSideEncryptionType() { return serverSideEncryptionType; } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java index 8b96c6a6867a5..b77d6634a8995 100644 --- 
a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerMockClientTests.java @@ -30,10 +30,13 @@ import software.amazon.awssdk.services.s3.model.UploadPartResponse; import org.apache.lucene.store.IndexInput; +import org.opensearch.OpenSearchException; import org.opensearch.cluster.metadata.RepositoryMetadata; import org.opensearch.common.CheckedConsumer; import org.opensearch.common.StreamContext; import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.stream.write.StreamContextSupplier; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; @@ -76,6 +79,7 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.stream.IntStream; +import org.mockito.ArgumentMatchers; import org.mockito.Mockito; import org.mockito.invocation.InvocationOnMock; @@ -90,6 +94,7 @@ import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doCallRealMethod; import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -774,4 +779,224 @@ private void testLargeFilesRedirectedToSlowSyncClient(boolean expectException, W } }); } + + private void testLargeFilesRedirectedToSlowSyncClientConditional( + boolean expectException, + WritePriority writePriority, + int conditionalResponseCode + ) throws IOException, InterruptedException { + + ByteSizeValue capacity = new ByteSizeValue(1, ByteSizeUnit.GB); + int numberOfParts = 20; + final ByteSizeValue partSize = new ByteSizeValue(capacity.getBytes() / 
numberOfParts + 1, ByteSizeUnit.BYTES); + + GenericStatsMetricPublisher genericStatsMetricPublisher = new GenericStatsMetricPublisher(10000L, 10, 10000L, 10); + SizeBasedBlockingQ sizeBasedBlockingQ = new SizeBasedBlockingQ( + capacity, + transferQueueConsumerService, + 10, + genericStatsMetricPublisher, + SizeBasedBlockingQ.QueueEventType.NORMAL + ); + + final long lastPartSize = new ByteSizeValue(200, ByteSizeUnit.MB).getBytes(); + final long blobSize = ((numberOfParts - 1) * partSize.getBytes()) + lastPartSize; + CountDownLatch countDownLatch = new CountDownLatch(1); + AtomicReference responseRef = new AtomicReference<>(); + AtomicReference exceptionRef = new AtomicReference<>(); + + ActionListener completionListener = ActionListener.wrap(resp -> { + responseRef.set(resp); + countDownLatch.countDown(); + }, ex -> { + exceptionRef.set(ex); + countDownLatch.countDown(); + }); + + final String bucketName = randomAlphaOfLengthBetween(1, 10); + + final BlobPath blobPath = new BlobPath(); + if (randomBoolean()) { + IntStream.of(randomIntBetween(1, 5)).forEach(value -> blobPath.add("path_" + value)); + } + + final long bufferSize = ByteSizeUnit.MB.toBytes(randomIntBetween(5, 1024)); + + final S3BlobStore blobStore = mock(S3BlobStore.class); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.getStatsMetricPublisher()).thenReturn(new StatsMetricPublisher()); + when(blobStore.bufferSizeInBytes()).thenReturn(bufferSize); + + when(blobStore.getLowPrioritySizeBasedBlockingQ()).thenReturn(sizeBasedBlockingQ); + when(blobStore.getNormalPrioritySizeBasedBlockingQ()).thenReturn(sizeBasedBlockingQ); + + final StorageClass storageClass = randomFrom(StorageClass.values()); + when(blobStore.getStorageClass()).thenReturn(storageClass); + when(blobStore.isRedirectLargeUploads()).thenReturn(true); + boolean uploadRetryEnabled = randomBoolean(); + when(blobStore.isUploadRetryEnabled()).thenReturn(uploadRetryEnabled); + + final ObjectCannedACL cannedAccessControlList = 
randomBoolean() ? randomFrom(ObjectCannedACL.values()) : null; + if (cannedAccessControlList != null) { + when(blobStore.getCannedACL()).thenReturn(cannedAccessControlList); + } + + if (randomBoolean()) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(randomAlphaOfLength(10)); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(randomBoolean()); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(randomAlphaOfLength(10)); + } + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = Mockito.spy(new AmazonS3Reference(client)); + doNothing().when(clientReference).close(); + when(blobStore.clientReference()).thenReturn(clientReference); + + final String uploadId = randomAlphaOfLength(10); + final CreateMultipartUploadResponse createMultipartUploadResponse = CreateMultipartUploadResponse.builder() + .uploadId(uploadId) + .build(); + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn(createMultipartUploadResponse); + + if (expectException) { + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenThrow( + SdkException.create("Expected upload part request to fail", new RuntimeException()) + ); + } else { + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + UploadPartResponse.builder().eTag("part-etag-" + randomAlphaOfLength(5)).build() + ); + } + + if (conditionalResponseCode == 412) { + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenThrow( + software.amazon.awssdk.services.s3.model.S3Exception.builder().statusCode(412).message("Precondition Failed").build() + ); + } else if (conditionalResponseCode == 409) { + 
when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenThrow( + software.amazon.awssdk.services.s3.model.S3Exception.builder().statusCode(409).message("Resource Already Exists").build() + ); + } else { + String eTag = "\"multipart-etag-" + randomAlphaOfLength(5) + "\""; + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn( + CompleteMultipartUploadResponse.builder().eTag(eTag).build() + ); + } + + when(client.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn( + AbortMultipartUploadResponse.builder().build() + ); + + List openInputStreams = new ArrayList<>(); + final S3BlobContainer s3BlobContainer = Mockito.spy(new S3BlobContainer(blobPath, blobStore)); + + doCallRealMethod().when(s3BlobContainer) + .executeMultipartUploadConditionally( + any(S3BlobStore.class), + anyString(), + any(InputStream.class), + anyLong(), + any(), + any(ConditionalWriteOptions.class), + ArgumentMatchers.>any() + ); + + StreamContextSupplier streamContextSupplier = partSize1 -> new StreamContext((partNo, size, position) -> { + InputStream inputStream = new OffsetRangeIndexInputStream(new ZeroIndexInput("desc", blobSize), size, position); + openInputStreams.add(inputStream); + return new InputStreamContainer(inputStream, size, position); + }, partSize1, calculateLastPartSize(blobSize, partSize1), calculateNumberOfParts(blobSize, partSize1)); + + WriteContext writeContext = new WriteContext.Builder().fileName("write_large_blob_conditional") + .streamContextSupplier(streamContextSupplier) + .fileSize(blobSize) + .failIfAlreadyExists(false) + .writePriority(writePriority) + .uploadFinalizer(success -> { + Assert.assertTrue(success); + }) + .doRemoteDataIntegrityCheck(false) + .metadata(new HashMap<>()) + .build(); + + ConditionalWriteOptions conditionalOptions; + if (conditionalResponseCode == 412) { + conditionalOptions = ConditionalWriteOptions.ifMatch("invalid-etag"); + } else if 
(conditionalResponseCode == 409) { + conditionalOptions = ConditionalWriteOptions.ifNotExists(); + } else { + conditionalOptions = ConditionalWriteOptions.ifMatch("valid-etag"); + } + + s3BlobContainer.asyncBlobUploadConditionally(writeContext, conditionalOptions, completionListener); + + boolean awaitSuccess = countDownLatch.await(5, TimeUnit.SECONDS); + assertTrue(awaitSuccess); + + if (expectException || conditionalResponseCode != 0) { + assertNotNull("Should have received an exception", exceptionRef.get()); + + if (conditionalResponseCode == 412 && !expectException) { + assertTrue( + "Should have conditional error", + exceptionRef.get() instanceof OpenSearchException + || exceptionRef.get().getMessage().contains("Precondition Failed") + || exceptionRef.get().getMessage().contains("ETag mismatch") + ); + } else if (conditionalResponseCode == 409 && !expectException) { + assertTrue( + "Should have conflict error", + exceptionRef.get().getMessage().contains("Resource Already Exists") + || exceptionRef.get().getMessage().contains("already exists") + ); + } + } else { + assertNull("Should not have received an exception", exceptionRef.get()); + assertNotNull("Should have received a response", responseRef.get()); + assertNotNull("Response should have version identifier", responseRef.get().getVersionIdentifier()); + } + + verify(s3BlobContainer, times(1)).executeMultipartUploadConditionally( + any(S3BlobStore.class), + anyString(), + any(InputStream.class), + anyLong(), + anyMap(), + any(ConditionalWriteOptions.class), + ArgumentMatchers.<ActionListener<ConditionalWriteResponse>>any() + ); + + boolean shouldAbort = expectException || (conditionalResponseCode != 0 && !expectException); + verify(client, times(shouldAbort ? 
1 : 0)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + + openInputStreams.forEach(inputStream -> { + try { + inputStream.close(); + } catch (IOException ex) {} + }); + } + + public void testFailureWhenLargeFileRedirectedConditional() throws IOException, InterruptedException { + testLargeFilesRedirectedToSlowSyncClientConditional(true, WritePriority.LOW, 0); + testLargeFilesRedirectedToSlowSyncClientConditional(true, WritePriority.NORMAL, 0); + } + + public void testLargeFileRedirectedConditional() throws IOException, InterruptedException { + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.LOW, 0); + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.NORMAL, 0); + } + + public void testLargeFileRedirectedConditionalPreconditionFailed() throws IOException, InterruptedException { + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.LOW, 412); + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.NORMAL, 412); + } + + public void testLargeFileRedirectedConditionalConflict() throws IOException, InterruptedException { + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.LOW, 409); + testLargeFilesRedirectedToSlowSyncClientConditional(false, WritePriority.NORMAL, 409); + } + } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java index 4193609ac520d..774eab80503fe 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobContainerRetriesTests.java @@ -33,6 +33,7 @@ import software.amazon.awssdk.core.exception.SdkClientException; import software.amazon.awssdk.core.io.SdkDigestInputStream; +import software.amazon.awssdk.services.s3.model.S3Exception; import 
software.amazon.awssdk.utils.internal.Base16; import org.apache.http.HttpStatus; @@ -43,6 +44,8 @@ import org.opensearch.common.blobstore.AsyncMultiStreamBlobContainer; import org.opensearch.common.blobstore.BlobContainer; import org.opensearch.common.blobstore.BlobPath; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.stream.write.StreamContextSupplier; import org.opensearch.common.blobstore.stream.write.WriteContext; import org.opensearch.common.blobstore.stream.write.WritePriority; @@ -427,6 +430,68 @@ public void testWriteBlobByStreamsWithRetries() throws Exception { }); } + public void testWriteBlobByStreamsConditionalFailure() throws Exception { + byte[] bytes = randomBlobContent(); + + httpServer.createContext("/bucket/write_blob_conditionally_failure", exchange -> { + if ("PUT".equals(exchange.getRequestMethod()) && exchange.getRequestURI().getQuery() == null) { + Streams.readFully(exchange.getRequestBody()); + exchange.sendResponseHeaders(HttpStatus.SC_PRECONDITION_FAILED, -1); + exchange.close(); + } + }); + + AsyncMultiStreamBlobContainer blobContainer = createBlobContainer(0, null, true, null); + List streams = new ArrayList<>(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference exRef = new AtomicReference<>(); + + ActionListener listener = ActionListener.wrap(resp -> { + fail("Expected failure"); + latch.countDown(); + }, ex -> { + exRef.set(ex); + latch.countDown(); + }); + + int partSize = 5 * 1024 * 1024; + StreamContextSupplier supplier = ps -> new StreamContext((part, size, pos) -> { + InputStream in = new OffsetRangeIndexInputStream(new ByteArrayIndexInput("desc", bytes), size, pos); + streams.add(in); + return new InputStreamContainer(in, size, pos); + }, partSize, calculateLastPartSize(bytes.length, partSize), calculateNumberOfParts(bytes.length, partSize)); + + 
ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch("non-matching-etag"); + WriteContext context = new WriteContext.Builder().fileName("write_blob_conditionally_failure") + .streamContextSupplier(supplier) + .fileSize(bytes.length) + .failIfAlreadyExists(false) + .writePriority(WritePriority.NORMAL) + .uploadFinalizer(Assert::assertTrue) + .doRemoteDataIntegrityCheck(false) + .build(); + + blobContainer.asyncBlobUploadConditionally(context, options, listener); + assertTrue(latch.await(5, TimeUnit.SECONDS)); + + Exception actual = exRef.get(); + assertNotNull(actual); + String msg = actual.getMessage(); + assertTrue( + msg.contains("Precondition Failed") + || msg.contains("412") + || (actual instanceof S3Exception && ((S3Exception) actual).statusCode() == 412) + ); + + for (InputStream in : streams) { + try { + in.close(); + } catch (IOException e) { + fail("Stream close failure"); + } + } + } + private long calculateLastPartSize(long totalSize, long partSize) { return totalSize % partSize == 0 ? 
partSize : totalSize % partSize; } diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java index 7a56541a50364..ec2a05e33476c 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java @@ -64,6 +64,7 @@ import software.amazon.awssdk.services.s3.model.PutObjectRequest; import software.amazon.awssdk.services.s3.model.PutObjectResponse; import software.amazon.awssdk.services.s3.model.S3Error; +import software.amazon.awssdk.services.s3.model.S3Exception; import software.amazon.awssdk.services.s3.model.S3Object; import software.amazon.awssdk.services.s3.model.ServerSideEncryption; import software.amazon.awssdk.services.s3.model.StorageClass; @@ -72,11 +73,14 @@ import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable; import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Publisher; +import org.opensearch.OpenSearchException; import org.opensearch.action.LatchedActionListener; import org.opensearch.common.blobstore.BlobContainer; import org.opensearch.common.blobstore.BlobMetadata; import org.opensearch.common.blobstore.BlobPath; import org.opensearch.common.blobstore.BlobStoreException; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.DeleteResult; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.collect.Tuple; @@ -97,6 +101,7 @@ import java.util.List; import java.util.Map; import java.util.NoSuchElementException; +import java.util.Random; import java.util.Set; import java.util.UUID; import java.util.concurrent.CompletableFuture; @@ -120,6 
+125,8 @@ import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoInteractions; @@ -985,6 +992,1148 @@ public void testNumberOfMultiparts() { assertNumberOfMultiparts(factor + 1, remaining, (size * factor) + remaining, size); } + public void testExecuteMultipartUploadConditionallyWithEtagMatchSuccess() throws IOException { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String inputETag = randomAlphaOfLengthBetween(8, 32); + final String finalETag = randomAlphaOfLengthBetween(8, 32); + final String uploadId = randomAlphaOfLengthBetween(10, 20); + + final Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", "value2"); + + final BlobPath blobPath = new BlobPath(); + if (randomBoolean()) { + IntStream.of(randomIntBetween(1, 5)).forEach(value -> blobPath.add("path_" + value)); + } + + final long partSize = ByteSizeUnit.MB.toBytes(5); + final int partCount = randomIntBetween(2, 5); + final long lastPartSize = randomIntBetween(1, (int) partSize); + final long blobSize = partSize * (partCount - 1) + lastPartSize; + + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + + final StorageClass storageClass = randomFrom(StorageClass.values()); + when(blobStore.getStorageClass()).thenReturn(storageClass); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final 
String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + final ObjectCannedACL cannedAccessControlList = randomBoolean() ? randomFrom(ObjectCannedACL.values()) : null; + if (cannedAccessControlList != null) { + when(blobStore.getCannedACL()).thenReturn(cannedAccessControlList); + } + + final boolean isUploadRetryEnabled = randomBoolean(); + when(blobStore.isUploadRetryEnabled()).thenReturn(isUploadRetryEnabled); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + when(blobStore.clientReference()).thenReturn(clientReference); + when(clientReference.get()).thenReturn(client); + + final ArgumentCaptor createRequestCaptor = ArgumentCaptor.forClass( + CreateMultipartUploadRequest.class + ); + final ArgumentCaptor uploadPartRequestCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); + final ArgumentCaptor requestBodyCaptor = ArgumentCaptor.forClass(RequestBody.class); + final ArgumentCaptor completeRequestCaptor = ArgumentCaptor.forClass( + CompleteMultipartUploadRequest.class + ); + + when(client.createMultipartUpload(createRequestCaptor.capture())).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + + final List partETags = new ArrayList<>(); + for (int i = 0; i < partCount; i++) { + partETags.add("etag-part-" + (i + 1)); + } + + 
when(client.uploadPart(uploadPartRequestCaptor.capture(), requestBodyCaptor.capture())).thenAnswer(invocation -> { + UploadPartRequest request = (UploadPartRequest) invocation.getArguments()[0]; + int partNumber = request.partNumber(); + return UploadPartResponse.builder().eTag(partETags.get(partNumber - 1)).build(); + }); + + when(client.completeMultipartUpload(completeRequestCaptor.capture())).thenReturn( + CompleteMultipartUploadResponse.builder().eTag(finalETag).build() + ); + + @SuppressWarnings("unchecked") + ActionListener responseListener = mock(ActionListener.class); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[(int) blobSize]); + + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + metadata, + ConditionalWriteOptions.ifMatch(inputETag), + responseListener + ); + + final CreateMultipartUploadRequest createRequest = createRequestCaptor.getValue(); + assertEquals(bucketName, createRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, createRequest.key()); + assertEquals(storageClass, createRequest.storageClass()); + assertEquals(cannedAccessControlList, createRequest.acl()); + assertEquals(metadata, createRequest.metadata()); + + // ENCRYPTION VERIFICATION: Updated encryption verification + if (useSseKms) { + assertEquals(ServerSideEncryption.AWS_KMS, createRequest.serverSideEncryption()); + assertEquals(kmsKeyId, createRequest.ssekmsKeyId()); + assertEquals(kmsContext, createRequest.ssekmsEncryptionContext()); + assertEquals(useBucketKey, createRequest.bucketKeyEnabled()); + } else { + assertEquals(ServerSideEncryption.AES256, createRequest.serverSideEncryption()); + } + + List partRequests = uploadPartRequestCaptor.getAllValues(); + assertEquals(partCount, partRequests.size()); + + List requestBodies = requestBodyCaptor.getAllValues(); + assertEquals(partCount, requestBodies.size()); + + for (int i = 0; i < partCount; i++) { + UploadPartRequest 
partRequest = partRequests.get(i); + assertEquals(bucketName, partRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, partRequest.key()); + assertEquals(uploadId, partRequest.uploadId()); + assertEquals(Integer.valueOf(i + 1), partRequest.partNumber()); + long expectedPartSize = (i < partCount - 1) ? partSize : lastPartSize; + assertEquals(expectedPartSize, partRequest.contentLength().longValue()); + + if (i == 0) { + RequestBody body = requestBodies.get(i); + try (InputStream is = body.contentStreamProvider().newStream()) { + assertNotNull(is); + } + } + } + + CompleteMultipartUploadRequest completeRequest = completeRequestCaptor.getValue(); + assertEquals(bucketName, completeRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, completeRequest.key()); + assertEquals(uploadId, completeRequest.uploadId()); + + assertEquals(inputETag, completeRequest.ifMatch()); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(responseListener).onResponse(responseCaptor.capture()); + assertEquals(finalETag, responseCaptor.getValue().getVersionIdentifier()); + + verify(responseListener, never()).onFailure(any()); + verify(clientReference).close(); + } + + public void testExecuteMultipartUploadConditionallyWithMetadataAndSSE() throws IOException { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String inputETag = randomAlphaOfLengthBetween(8, 32); + final String finalETag = randomAlphaOfLengthBetween(8, 32); + final String uploadId = randomAlphaOfLengthBetween(10, 20); + + final Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", "value2"); + + final BlobPath blobPath = new BlobPath(); + + final long partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * 2; + + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher 
= new StatsMetricPublisher(); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + + // ENCRYPTION CHANGES: Replace hard-coded encryption with enhanced configuration + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + final StorageClass storageClass = randomFrom(StorageClass.values()); + when(blobStore.getStorageClass()).thenReturn(storageClass); + + final ObjectCannedACL cannedAccessControlList = randomFrom(ObjectCannedACL.values()); + when(blobStore.getCannedACL()).thenReturn(cannedAccessControlList); + + final boolean isUploadRetryEnabled = randomBoolean(); + when(blobStore.isUploadRetryEnabled()).thenReturn(isUploadRetryEnabled); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + when(blobStore.clientReference()).thenReturn(clientReference); + when(clientReference.get()).thenReturn(client); + + final ArgumentCaptor createRequestCaptor = ArgumentCaptor.forClass( + CreateMultipartUploadRequest.class + ); + final ArgumentCaptor uploadPartRequestCaptor = ArgumentCaptor.forClass(UploadPartRequest.class); + final ArgumentCaptor requestBodyCaptor = 
ArgumentCaptor.forClass(RequestBody.class); + final ArgumentCaptor completeRequestCaptor = ArgumentCaptor.forClass( + CompleteMultipartUploadRequest.class + ); + + when(client.createMultipartUpload(createRequestCaptor.capture())).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + + final List partETags = List.of("etag-part-1", "etag-part-2"); + + when(client.uploadPart(uploadPartRequestCaptor.capture(), requestBodyCaptor.capture())).thenAnswer(invocation -> { + UploadPartRequest request = (UploadPartRequest) invocation.getArguments()[0]; + int partNumber = request.partNumber(); + return UploadPartResponse.builder().eTag(partETags.get(partNumber - 1)).build(); + }); + + when(client.completeMultipartUpload(completeRequestCaptor.capture())).thenReturn( + CompleteMultipartUploadResponse.builder().eTag(finalETag).build() + ); + + @SuppressWarnings("unchecked") + ActionListener responseListener = mock(ActionListener.class); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[(int) blobSize]); + + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + metadata, + ConditionalWriteOptions.ifMatch(inputETag), + responseListener + ); + + final CreateMultipartUploadRequest createRequest = createRequestCaptor.getValue(); + assertEquals(bucketName, createRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, createRequest.key()); + assertEquals(storageClass, createRequest.storageClass()); + assertEquals(cannedAccessControlList, createRequest.acl()); + assertEquals(metadata, createRequest.metadata()); + + // ENCRYPTION VERIFICATION: Updated encryption verification + if (useSseKms) { + assertEquals(ServerSideEncryption.AWS_KMS, createRequest.serverSideEncryption()); + assertEquals(kmsKeyId, createRequest.ssekmsKeyId()); + assertEquals(kmsContext, createRequest.ssekmsEncryptionContext()); + assertEquals(useBucketKey, createRequest.bucketKeyEnabled()); + 
} else { + assertEquals(ServerSideEncryption.AES256, createRequest.serverSideEncryption()); + } + + List partRequests = uploadPartRequestCaptor.getAllValues(); + assertEquals(2, partRequests.size()); + + for (int i = 0; i < 2; i++) { + UploadPartRequest partRequest = partRequests.get(i); + assertEquals(bucketName, partRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, partRequest.key()); + assertEquals(uploadId, partRequest.uploadId()); + assertEquals(Integer.valueOf(i + 1), partRequest.partNumber()); + assertEquals(partSize, partRequest.contentLength().longValue()); + + if (i == 0) { + RequestBody body = requestBodyCaptor.getAllValues().get(i); + try (InputStream is = body.contentStreamProvider().newStream()) { + assertNotNull(is); + assertTrue("Content stream should be available", is.available() > 0); + } + } + } + + CompleteMultipartUploadRequest completeRequest = completeRequestCaptor.getValue(); + assertEquals(bucketName, completeRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, completeRequest.key()); + assertEquals(uploadId, completeRequest.uploadId()); + + assertEquals(inputETag, completeRequest.ifMatch()); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(responseListener).onResponse(responseCaptor.capture()); + assertEquals(finalETag, responseCaptor.getValue().getVersionIdentifier()); + + verify(responseListener, never()).onFailure(any()); + + verify(clientReference).close(); + } + + public void testExecuteMultipartUploadConditionallyContentIntegrity() throws IOException { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String inputETag = randomAlphaOfLengthBetween(8, 32); + final String finalETag = randomAlphaOfLengthBetween(8, 32); + final String uploadId = randomAlphaOfLengthBetween(10, 20); + + final BlobPath blobPath = new BlobPath(); + + final int partCount = 3; + final long 
partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * partCount; + + final byte[] blobContent = new byte[(int) blobSize]; + Random random = new Random(0); + random.nextBytes(blobContent); + + final S3BlobStore blobStore = mock(S3BlobStore.class); + + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.getStorageClass()).thenReturn(StorageClass.STANDARD); + + // ENCRYPTION CHANGES: Replace serverSideEncryption with enhanced configuration + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + when(blobStore.isUploadRetryEnabled()).thenReturn(false); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + when(blobStore.clientReference()).thenReturn(clientReference); + when(clientReference.get()).thenReturn(client); + + final ArgumentCaptor createRequestCaptor = ArgumentCaptor.forClass( + CreateMultipartUploadRequest.class + ); + final ArgumentCaptor completeRequestCaptor = ArgumentCaptor.forClass( + CompleteMultipartUploadRequest.class + ); + final ArgumentCaptor 
requestBodyCaptor = ArgumentCaptor.forClass(RequestBody.class); + + when(client.createMultipartUpload(createRequestCaptor.capture())).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + + final List capturedPartContents = new ArrayList<>(); + + when(client.uploadPart(any(UploadPartRequest.class), requestBodyCaptor.capture())).thenAnswer(invocation -> { + RequestBody requestBody = requestBodyCaptor.getValue(); + try (InputStream contentStream = requestBody.contentStreamProvider().newStream()) { + byte[] partContent = contentStream.readAllBytes(); + capturedPartContents.add(partContent); + } + return UploadPartResponse.builder().eTag("etag-for-part").build(); + }); + + when(client.completeMultipartUpload(completeRequestCaptor.capture())).thenReturn( + CompleteMultipartUploadResponse.builder().eTag(finalETag).build() + ); + + @SuppressWarnings("unchecked") + ActionListener responseListener = mock(ActionListener.class); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(blobContent); + + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + null, + ConditionalWriteOptions.ifMatch(inputETag), + responseListener + ); + + final CreateMultipartUploadRequest createRequest = createRequestCaptor.getValue(); + assertEquals(bucketName, createRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, createRequest.key()); + + // No explicit encryption verification needed for this test as it focuses on content integrity + + final CompleteMultipartUploadRequest completeRequest = completeRequestCaptor.getValue(); + assertEquals(inputETag, completeRequest.ifMatch()); + assertEquals(bucketName, completeRequest.bucket()); + assertEquals(blobPath.buildAsString() + blobName, completeRequest.key()); + assertEquals(uploadId, completeRequest.uploadId()); + + assertEquals(partCount, capturedPartContents.size()); + + byte[] reassembledContent = new byte[(int) blobSize]; + int 
offset = 0; + for (byte[] partContent : capturedPartContents) { + System.arraycopy(partContent, 0, reassembledContent, offset, partContent.length); + offset += partContent.length; + } + + assertArrayEquals("Uploaded content should match original content", blobContent, reassembledContent); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(responseListener).onResponse(responseCaptor.capture()); + assertEquals(finalETag, responseCaptor.getValue().getVersionIdentifier()); + + verify(responseListener, never()).onFailure(any()); + + verify(clientReference).close(); + } + + public void testExecuteMultipartUploadConditionallySizeValidation() { + final S3BlobStore blobStore = mock(S3BlobStore.class); + final S3BlobContainer blobContainer = new S3BlobContainer(mock(BlobPath.class), blobStore); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String inputETag = randomAlphaOfLengthBetween(8, 32); + final String finalETag = randomAlphaOfLengthBetween(8, 32); + + @SuppressWarnings("unchecked") + ActionListener invalidSizeListener = mock(ActionListener.class); + + { + final long tooSmallSize = ByteSizeUnit.MB.toBytes(5) - 1024; + + final IllegalArgumentException tooSmallException = expectThrows( + IllegalArgumentException.class, + () -> blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + new ByteArrayInputStream(new byte[0]), + tooSmallSize, + null, + ConditionalWriteOptions.ifMatch(inputETag), + invalidSizeListener + ) + ); + + assertTrue(tooSmallException.getMessage().contains("can't be smaller than")); + verify(invalidSizeListener, never()).onResponse(any()); + verify(invalidSizeListener, never()).onFailure(any()); + } + + { + final long tooLargeSize = ByteSizeUnit.TB.toBytes(5) + 1; + + final IllegalArgumentException tooLargeException = expectThrows( + IllegalArgumentException.class, + () -> blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + new 
ByteArrayInputStream(new byte[0]), + tooLargeSize, + null, + ConditionalWriteOptions.ifMatch(inputETag), + invalidSizeListener + ) + ); + + assertTrue(tooLargeException.getMessage().contains("can't be larger than")); + verify(invalidSizeListener, never()).onResponse(any()); + verify(invalidSizeListener, never()).onFailure(any()); + } + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.clientReference()).thenReturn(clientReference); + when(clientReference.get()).thenReturn(client); + when(blobStore.bucket()).thenReturn("test-bucket"); + + when(blobStore.bufferSizeInBytes()).thenReturn(ByteSizeUnit.MB.toBytes(5)); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + ArgumentCaptor completeRequestCaptor = ArgumentCaptor.forClass( + CompleteMultipartUploadRequest.class + ); + + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId("test-upload-id").build() + ); + + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + UploadPartResponse.builder().eTag("test-etag").build() + ); + + 
when(client.completeMultipartUpload(completeRequestCaptor.capture())).thenReturn( + CompleteMultipartUploadResponse.builder().eTag(finalETag).build() + ); + + { + final long exactMinimumSize = ByteSizeUnit.MB.toBytes(5); + @SuppressWarnings("unchecked") + ActionListener validSizeListener = mock(ActionListener.class); + + InputStream zeroStream = new InputStream() { + long remaining = exactMinimumSize; + + @Override + public int read() { + if (remaining > 0) { + remaining--; + return 0; + } else { + return -1; + } + } + + @Override + public int read(byte[] b, int off, int len) { + if (remaining <= 0) { + return -1; + } + int toRead = (int) Math.min(len, remaining); + Arrays.fill(b, off, off + toRead, (byte) 0); + remaining -= toRead; + return toRead; + } + }; + + try { + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + zeroStream, + exactMinimumSize, + null, + ConditionalWriteOptions.ifMatch(inputETag), + validSizeListener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(validSizeListener).onResponse(responseCaptor.capture()); + assertEquals(finalETag, responseCaptor.getValue().getVersionIdentifier()); + verify(validSizeListener, never()).onFailure(any()); + + verify(clientReference).close(); + + CompleteMultipartUploadRequest completeRequest = completeRequestCaptor.getValue(); + assertEquals(inputETag, completeRequest.ifMatch()); + + } catch (IOException e) { + fail("Should not throw exception for exact minimum size: " + e); + } + } + + reset(clientReference); + when(blobStore.clientReference()).thenReturn(clientReference); + when(clientReference.get()).thenReturn(client); + + { + final long testSize = ByteSizeUnit.MB.toBytes(10); + @SuppressWarnings("unchecked") + ActionListener validSizeListener = mock(ActionListener.class); + + InputStream zeroStream = new InputStream() { + long remaining = testSize; + + @Override + public int read() { + if (remaining > 0) { + remaining--; + 
return 0; + } else { + return -1; + } + } + + @Override + public int read(byte[] b, int off, int len) { + if (remaining <= 0) { + return -1; + } + int toRead = (int) Math.min(len, remaining); + Arrays.fill(b, off, off + toRead, (byte) 0); + remaining -= toRead; + return toRead; + } + }; + + try { + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + zeroStream, + testSize, + null, + ConditionalWriteOptions.ifMatch(inputETag), + validSizeListener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(validSizeListener).onResponse(responseCaptor.capture()); + assertEquals(finalETag, responseCaptor.getValue().getVersionIdentifier()); + verify(validSizeListener, never()).onFailure(any()); + + verify(clientReference).close(); + + CompleteMultipartUploadRequest completeRequest = completeRequestCaptor.getValue(); + assertEquals(inputETag, completeRequest.ifMatch()); + + } catch (IOException e) { + fail("Should not fail with size validation error: " + e); + } + } + } + + public void testExecuteMultipartUploadConditionallyPreconditionFailed() { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String eTag = randomAlphaOfLengthBetween(8, 32); + + final BlobPath blobPath = new BlobPath(); + final long partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * 2; + + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.getStorageClass()).thenReturn(StorageClass.STANDARD); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = 
randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + when(blobStore.isUploadRetryEnabled()).thenReturn(false); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + final AmazonS3Reference abortClientReference = mock(AmazonS3Reference.class); + + when(blobStore.clientReference()).thenReturn(clientReference).thenReturn(abortClientReference); + when(clientReference.get()).thenReturn(client); + when(abortClientReference.get()).thenReturn(client); + + S3Exception preconditionFailedException = (S3Exception) S3Exception.builder() + .message("Precondition Failed") + .statusCode(S3BlobContainer.HTTP_STATUS_PRECONDITION_FAILED) + .build(); + + final String uploadId = randomAlphaOfLengthBetween(10, 20); + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + UploadPartResponse.builder().eTag("part-etag").build() + ); + + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenThrow(preconditionFailedException); + + when(client.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn( + AbortMultipartUploadResponse.builder().build() + ); + + final AtomicReference capturedException = new 
AtomicReference<>(); + ActionListener responseListener = ActionListener.wrap( + r -> fail("Should have failed with precondition failure"), + capturedException::set + ); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[(int) blobSize]); + + IOException ioException = expectThrows( + IOException.class, + () -> blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + null, + ConditionalWriteOptions.ifMatch(eTag), + responseListener + ) + ); + + assertEquals("Unable to upload object [" + blobName + "] due to ETag mismatch", ioException.getMessage()); + assertEquals(preconditionFailedException, ioException.getCause()); + + Exception exception = capturedException.get(); + assertNotNull("Expected an exception to be captured", exception); + assertTrue("Exception should be an OpenSearchException", exception instanceof OpenSearchException); + assertEquals("Precondition Failed : Etag Mismatch", exception.getMessage()); + + verify(client).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(client).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + verify(client).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + + verify(clientReference).close(); + verify(abortClientReference).close(); + } + + public void testExecuteMultipartUploadConditionallyS3ExceptionTypes() { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String eTag = randomAlphaOfLengthBetween(8, 32); + + final BlobPath blobPath = new BlobPath(); + final long partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * 2; + + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + + Object[][] testCases = { + { "S3Exception", 0, 403 }, + { "S3Exception", 1, 404 }, + { "S3Exception", 2, 412 }, + { "SdkException", 1, 0 }, }; + + for (Object[] testCase : testCases) { + String 
exceptionType = (String) testCase[0]; + int errorPhase = (int) testCase[1]; + int statusCode = (int) testCase[2]; + + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.getStorageClass()).thenReturn(StorageClass.STANDARD); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + final AmazonS3Reference abortClientReference = mock(AmazonS3Reference.class); + + when(blobStore.clientReference()).thenReturn(clientReference).thenReturn(abortClientReference); + when(clientReference.get()).thenReturn(client); + when(abortClientReference.get()).thenReturn(client); + + Exception testException; + if ("S3Exception".equals(exceptionType)) { + testException = (S3Exception) S3Exception.builder() + .message("S3 Error with status code " + statusCode) + .statusCode(statusCode) + .build(); + } else { + testException = SdkException.builder().message("SDK Error occurred").build(); + } + + 
final String uploadId = "test-upload-id"; + if (errorPhase >= 1) { + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + } else { + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenThrow(testException); + } + + if (errorPhase >= 2) { + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + UploadPartResponse.builder().eTag("part-etag").build() + ); + } else if (errorPhase == 1) { + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenThrow(testException); + } + + if (errorPhase == 2) { + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenThrow(testException); + } + + when(client.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn( + AbortMultipartUploadResponse.builder().build() + ); + + final AtomicReference capturedException = new AtomicReference<>(); + ActionListener responseListener = ActionListener.wrap( + r -> fail("Should have failed with exception"), + capturedException::set + ); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[(int) blobSize]); + IOException exception = expectThrows( + IOException.class, + () -> blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + metadata, + ConditionalWriteOptions.ifMatch(eTag), + responseListener + ) + ); + + assertEquals(testException, exception.getCause()); + + Exception listenerException = capturedException.get(); + assertNotNull("Expected a listener exception", listenerException); + + if ("S3Exception".equals(exceptionType) && statusCode == 412) { + assertTrue(listenerException instanceof OpenSearchException); + assertEquals("Precondition Failed : Etag Mismatch", listenerException.getMessage()); + } else { + assertTrue(listenerException instanceof IOException); + } + + 
verify(clientReference).close(); + + if (errorPhase >= 1) { + verify(client).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + verify(abortClientReference).close(); + } else { + verify(client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + } + } + + public void testExecuteMultipartUploadConditionallySdkException() { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String eTag = randomAlphaOfLengthBetween(8, 32); + + final BlobPath blobPath = new BlobPath(); + final long partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * 2; + + Object[][] testScenarios = { + { "initialization error", 0, false }, + { "part upload error", 1, false }, + { "abort failure", 1, true } }; + + for (Object[] scenario : testScenarios) { + String scenarioName = (String) scenario[0]; + int errorStage = (int) scenario[1]; + boolean abortFails = (boolean) scenario[2]; + + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher = new StatsMetricPublisher(); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.getStorageClass()).thenReturn(StorageClass.STANDARD); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + 
when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + final S3Client client = mock(S3Client.class); + final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); + final AmazonS3Reference abortClientReference = mock(AmazonS3Reference.class); + when(blobStore.clientReference()).thenReturn(clientReference).thenReturn(abortClientReference); + when(clientReference.get()).thenReturn(client); + when(abortClientReference.get()).thenReturn(client); + + SdkException primaryException = SdkException.builder().message("SDK error during " + scenarioName).build(); + + final String uploadId = "test-upload-id"; + if (errorStage == 0) { + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenThrow(primaryException); + } else { + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenThrow(primaryException); + } + + ArgumentCaptor abortCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); + + if (abortFails) { + SdkException abortException = SdkException.builder().message("Abort failure").build(); + when(client.abortMultipartUpload(abortCaptor.capture())).thenThrow(abortException); + } else { + when(client.abortMultipartUpload(abortCaptor.capture())).thenReturn(AbortMultipartUploadResponse.builder().build()); + } + + final AtomicReference capturedException = new AtomicReference<>(); + ActionListener responseListener = ActionListener.wrap( + r -> fail("Should have failed with SdkException"), + capturedException::set + ); + + final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[(int) blobSize]); + IOException exception = expectThrows( + IOException.class, + () -> 
blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + null, + ConditionalWriteOptions.ifMatch(eTag), + responseListener + ) + ); + + assertEquals("Original SdkException should be preserved as cause", primaryException, exception.getCause()); + + Exception listenerException = capturedException.get(); + assertNotNull("Expected an exception", listenerException); + assertTrue("Exception should be IOException", listenerException instanceof IOException); + + verify(clientReference).close(); + + if (errorStage > 0) { + verify(client).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + verify(abortClientReference).close(); + + if (!abortCaptor.getAllValues().isEmpty()) { + AbortMultipartUploadRequest abortRequest = abortCaptor.getValue(); + assertEquals("Abort request should have correct upload ID", uploadId, abortRequest.uploadId()); + assertEquals("Abort request should have correct bucket", bucketName, abortRequest.bucket()); + assertEquals("Abort request should have correct key", blobPath.buildAsString() + blobName, abortRequest.key()); + } + } else { + verify(client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + } + } + + public void testExecuteMultipartUploadConditionallyResourceManagement() throws IOException { + final String bucketName = randomAlphaOfLengthBetween(1, 10); + final String blobName = randomAlphaOfLengthBetween(1, 10); + final String eTag = randomAlphaOfLengthBetween(8, 32); + final String uploadId = randomAlphaOfLengthBetween(10, 20); + + final BlobPath blobPath = new BlobPath(); + final long partSize = ByteSizeUnit.MB.toBytes(5); + final long blobSize = partSize * 2; + + enum ResourceScenario { + SUCCESS_PATH, + INIT_FAILURE, + PART_UPLOAD_FAILURE, + COMPLETION_FAILURE, + ABORT_FAILURE + } + + for (ResourceScenario scenario : ResourceScenario.values()) { + final S3BlobStore blobStore = mock(S3BlobStore.class); + final StatsMetricPublisher metricPublisher = new 
StatsMetricPublisher(); + when(blobStore.bucket()).thenReturn(bucketName); + when(blobStore.bufferSizeInBytes()).thenReturn(partSize); + when(blobStore.getStatsMetricPublisher()).thenReturn(metricPublisher); + when(blobStore.getStorageClass()).thenReturn(StorageClass.STANDARD); + + final boolean useSseKms = randomBoolean(); + final String kmsKeyId = randomAlphaOfLengthBetween(10, 20); + final String kmsContext = randomAlphaOfLengthBetween(10, 20); + final boolean useBucketKey = randomBoolean(); + if (useSseKms) { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AWS_KMS.toString()); + when(blobStore.serverSideEncryptionKmsKey()).thenReturn(kmsKeyId); + when(blobStore.serverSideEncryptionBucketKey()).thenReturn(useBucketKey); + when(blobStore.serverSideEncryptionEncryptionContext()).thenReturn(kmsContext); + } else { + when(blobStore.serverSideEncryptionType()).thenReturn(ServerSideEncryption.AES256.toString()); + } + + final S3Client client = mock(S3Client.class); + + AmazonS3Reference primaryClientReference = mock(AmazonS3Reference.class); + AmazonS3Reference abortClientReference = mock(AmazonS3Reference.class); + + when(primaryClientReference.get()).thenReturn(client); + when(abortClientReference.get()).thenReturn(client); + + when(blobStore.clientReference()).thenReturn(primaryClientReference).thenReturn(abortClientReference); + + final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore); + + SdkException stageException = SdkException.builder().message("Failure during " + scenario.name()).build(); + + ArgumentCaptor abortCaptor = ArgumentCaptor.forClass(AbortMultipartUploadRequest.class); + + switch (scenario) { + case SUCCESS_PATH: + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + 
UploadPartResponse.builder().eTag("test-etag").build() + ); + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn( + CompleteMultipartUploadResponse.builder().eTag("final-etag").build() + ); + break; + + case INIT_FAILURE: + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenThrow(stageException); + break; + + case PART_UPLOAD_FAILURE: + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenThrow(stageException); + when(client.abortMultipartUpload(abortCaptor.capture())).thenReturn(AbortMultipartUploadResponse.builder().build()); + break; + + case COMPLETION_FAILURE: + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenReturn( + UploadPartResponse.builder().eTag("test-etag").build() + ); + when(client.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenThrow(stageException); + when(client.abortMultipartUpload(abortCaptor.capture())).thenReturn(AbortMultipartUploadResponse.builder().build()); + break; + + case ABORT_FAILURE: + when(client.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + CreateMultipartUploadResponse.builder().uploadId(uploadId).build() + ); + when(client.uploadPart(any(UploadPartRequest.class), any(RequestBody.class))).thenThrow(stageException); + when(client.abortMultipartUpload(abortCaptor.capture())).thenThrow( + SdkException.builder().message("Abort failure").build() + ); + break; + } + + @SuppressWarnings("unchecked") + ActionListener responseListener = mock(ActionListener.class); + + final ByteArrayInputStream inputStream = new 
ByteArrayInputStream(new byte[(int) blobSize]); + + if (scenario == ResourceScenario.SUCCESS_PATH) { + blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + null, + ConditionalWriteOptions.ifMatch(eTag), + responseListener + ); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(ConditionalWriteResponse.class); + verify(responseListener).onResponse(responseCaptor.capture()); + assertEquals("final-etag", responseCaptor.getValue().getVersionIdentifier()); + verify(responseListener, never()).onFailure(any()); + + verify(blobStore, times(1)).clientReference(); + verify(primaryClientReference).close(); + } else { + IOException exception = expectThrows( + IOException.class, + () -> blobContainer.executeMultipartUploadConditionally( + blobStore, + blobName, + inputStream, + blobSize, + null, + ConditionalWriteOptions.ifMatch(eTag), + responseListener + ) + ); + + assertEquals("Exception cause should be the original exception", stageException, exception.getCause()); + + verify(responseListener).onFailure(any(Exception.class)); + verify(responseListener, never()).onResponse(any()); + + verify(primaryClientReference).close(); + + if (scenario != ResourceScenario.INIT_FAILURE) { + verify(blobStore, times(2)).clientReference(); + verify(client).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + verify(abortClientReference).close(); + + if (!abortCaptor.getAllValues().isEmpty()) { + AbortMultipartUploadRequest abortRequest = abortCaptor.getValue(); + assertEquals("Upload ID should match", uploadId, abortRequest.uploadId()); + assertEquals("Bucket should match", bucketName, abortRequest.bucket()); + assertEquals("Key should match", blobPath.buildAsString() + blobName, abortRequest.key()); + } + } else { + verify(blobStore, times(1)).clientReference(); + verify(client, never()).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + } + } + } + } + public void testInitCannedACL() { String[] aclList = 
new String[] { "private", diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java index 8b5ab0333997a..91383784a3b7f 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/async/AsyncTransferManagerTests.java @@ -29,6 +29,8 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.common.StreamContext; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.exception.CorruptFileException; import org.opensearch.common.blobstore.stream.write.WritePriority; import org.opensearch.common.io.InputStreamContainer; @@ -97,8 +99,9 @@ public void testOneChunkUpload() { CompletableFuture resultFuture = asyncTransferManager.uploadObject( s3AsyncClient, new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(1), WritePriority.HIGH, uploadSuccess -> { - // do nothing - }, false, null, true, metadata, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + + }, false, null, true, metadata, null, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + new StreamContext((partIdx, partSize, position) -> { streamRef.set(new ZeroInputStream(partSize)); return new InputStreamContainer(streamRef.get(), partSize, position); @@ -145,9 +148,23 @@ public void testOneChunkUploadCorruption() { CompletableFuture resultFuture = asyncTransferManager.uploadObject( s3AsyncClient, - new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(1), WritePriority.HIGH, uploadSuccess -> { - // do nothing - }, false, null, true, metadata, ServerSideEncryption.AWS_KMS.toString(), 
randomAlphaOfLength(10), true, null, null), + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(1), + WritePriority.HIGH, + uploadSuccess -> {}, + false, + null, + true, + metadata, + null, + ServerSideEncryption.AWS_KMS.toString(), + randomAlphaOfLength(10), + true, + null, + null + ), new StreamContext( (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position), ByteSizeUnit.MB.toBytes(1), @@ -203,8 +220,9 @@ public void testMultipartUpload() { CompletableFuture resultFuture = asyncTransferManager.uploadObject( s3AsyncClient, new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(5), WritePriority.HIGH, uploadSuccess -> { - // do nothing - }, true, 3376132981L, true, metadata, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + + }, true, 3376132981L, true, metadata, null, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + new StreamContext((partIdx, partSize, position) -> { InputStream stream = new ZeroInputStream(partSize); streams.add(stream); @@ -267,8 +285,9 @@ public void testMultipartUploadCorruption() { CompletableFuture resultFuture = asyncTransferManager.uploadObject( s3AsyncClient, new UploadRequest("bucket", "key", ByteSizeUnit.MB.toBytes(5), WritePriority.HIGH, uploadSuccess -> { - // do nothing - }, true, 0L, true, metadata, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + + }, true, 0L, true, metadata, null, ServerSideEncryption.AWS_KMS.toString(), randomAlphaOfLength(10), true, null, null), + new StreamContext( (partIdx, partSize, position) -> new InputStreamContainer(new ZeroInputStream(partSize), partSize, position), ByteSizeUnit.MB.toBytes(1), @@ -292,4 +311,349 @@ public void testMultipartUploadCorruption() { verify(s3AsyncClient, times(0)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); verify(s3AsyncClient, 
times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); } + + public void testConditionalOneChunkUpload() { + CompletableFuture putObjectResponseCompletableFuture = new CompletableFuture<>(); + putObjectResponseCompletableFuture.complete(PutObjectResponse.builder().eTag("test-etag-1234").build()); + when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))).thenReturn( + putObjectResponseCompletableFuture + ); + + ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch("old-etag-value"); + AtomicReference streamRef = new AtomicReference<>(); + Map metadata = new HashMap<>(); + metadata.put("key1", "value1"); + metadata.put("key2", "value2"); + + CompletableFuture resultFuture = asyncTransferManager.uploadObjectConditionally( + s3AsyncClient, + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(1), + WritePriority.HIGH, + uploadSuccess -> {}, + false, + null, + true, + metadata, + options, + ServerSideEncryption.UNKNOWN_TO_SDK_VERSION.toString(), + null, + false, + null, + null + ), + new StreamContext((partIdx, partSize, position) -> { + streamRef.set(new ZeroInputStream(partSize)); + return new InputStreamContainer(streamRef.get(), partSize, position); + }, ByteSizeUnit.MB.toBytes(1), ByteSizeUnit.MB.toBytes(1), 1), + new StatsMetricPublisher() + ); + + try { + ConditionalWriteResponse response = resultFuture.get(); + assertNotNull("Response should not be null", response); + assertEquals("ETag should match expected value", "test-etag-1234", response.getVersionIdentifier()); + } catch (ExecutionException | InterruptedException e) { + fail("Did not expect resultFuture to fail: " + e.getMessage()); + } + + verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + + boolean closeError = false; + try { + streamRef.get().available(); + } catch (IOException e) { + closeError = e.getMessage().equals("Stream closed"); + } + assertTrue("InputStream was still 
open after upload", closeError); + } + + public void testConditionalOneChunkUploadPreconditionFailed() { + CompletableFuture putObjectResponseCompletableFuture = new CompletableFuture<>(); + S3Exception mockException = (S3Exception) S3Exception.builder().statusCode(412).message("Precondition Failed").build(); + + putObjectResponseCompletableFuture.completeExceptionally(mockException); + when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))).thenReturn( + putObjectResponseCompletableFuture + ); + + String etag = "non-matching-etag"; + ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch(etag); + + AtomicReference streamRef = new AtomicReference<>(); + CompletableFuture resultFuture = asyncTransferManager.uploadObjectConditionally( + s3AsyncClient, + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(1), + WritePriority.HIGH, + uploadSuccess -> {}, + false, + null, + true, + null, + options, + ServerSideEncryption.UNKNOWN_TO_SDK_VERSION.toString(), + null, + false, + null, + null + ), + new StreamContext((partIdx, partSize, position) -> { + streamRef.set(new ZeroInputStream(partSize)); + return new InputStreamContainer(streamRef.get(), partSize, position); + }, ByteSizeUnit.MB.toBytes(1), ByteSizeUnit.MB.toBytes(1), 1), + new StatsMetricPublisher() + ); + + try { + resultFuture.get(); + fail("Expected an exception for precondition failed"); + } catch (ExecutionException | InterruptedException e) { + Throwable cause = e.getCause(); + + assertTrue("Should be S3Exception", cause instanceof S3Exception); + + S3Exception s3e = (S3Exception) cause; + + assertEquals("Should have 412 status code", 412, s3e.statusCode()); + + assertNotNull("Exception should have a message", s3e.getMessage()); + assertFalse("Exception message should not be empty", s3e.getMessage().isEmpty()); + } + + verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + + boolean closeError = false; + 
try { + streamRef.get().available(); + } catch (IOException e) { + closeError = e.getMessage().equals("Stream closed"); + } + assertTrue("InputStream was still open after upload", closeError); + } + + public void testConditionalOneChunkUploadCorruption() { + CompletableFuture putObjectResponseCompletableFuture = new CompletableFuture<>(); + putObjectResponseCompletableFuture.completeExceptionally( + S3Exception.builder() + .statusCode(HttpStatusCode.BAD_REQUEST) + .awsErrorDetails(AwsErrorDetails.builder().errorCode("BadDigest").build()) + .build() + ); + when(s3AsyncClient.putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class))).thenReturn( + putObjectResponseCompletableFuture + ); + + CompletableFuture deleteObjectResponseCompletableFuture = new CompletableFuture<>(); + deleteObjectResponseCompletableFuture.complete(DeleteObjectResponse.builder().build()); + when(s3AsyncClient.deleteObject(any(DeleteObjectRequest.class))).thenReturn(deleteObjectResponseCompletableFuture); + ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch("test-etag"); + AtomicReference streamRef = new AtomicReference<>(); + CompletableFuture resultFuture = asyncTransferManager.uploadObjectConditionally( + s3AsyncClient, + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(1), + WritePriority.HIGH, + uploadSuccess -> {}, + false, + null, + true, + null, + options, + ServerSideEncryption.UNKNOWN_TO_SDK_VERSION.toString(), + null, + false, + null, + null + ), + new StreamContext((partIdx, partSize, position) -> { + streamRef.set(new ZeroInputStream(partSize)); + return new InputStreamContainer(streamRef.get(), partSize, position); + }, ByteSizeUnit.MB.toBytes(1), ByteSizeUnit.MB.toBytes(1), 1), + new StatsMetricPublisher() + ); + + try { + resultFuture.get(); + fail("Expected a corruption exception"); + } catch (ExecutionException | InterruptedException e) { + Throwable throwable = ExceptionsHelper.unwrap(e, CorruptFileException.class); + 
assertNotNull("Exception should be a CorruptFileException", throwable); + assertTrue("Exception should be a CorruptFileException", throwable instanceof CorruptFileException); + } + + verify(s3AsyncClient, times(1)).putObject(any(PutObjectRequest.class), any(AsyncRequestBody.class)); + verify(s3AsyncClient, times(1)).deleteObject(any(DeleteObjectRequest.class)); + + boolean closeError = false; + try { + streamRef.get().available(); + } catch (IOException e) { + closeError = e.getMessage().equals("Stream closed"); + } + assertTrue("InputStream was still open after upload", closeError); + } + + public void testConditionalMultipartUploadPreconditionFailed() { + CompletableFuture createMultipartUploadRequestCompletableFuture = new CompletableFuture<>(); + createMultipartUploadRequestCompletableFuture.complete(CreateMultipartUploadResponse.builder().uploadId("uploadId").build()); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + createMultipartUploadRequestCompletableFuture + ); + + CompletableFuture uploadPartResponseCompletableFuture = new CompletableFuture<>(); + uploadPartResponseCompletableFuture.complete(UploadPartResponse.builder().checksumCRC32("pzjqHA==").build()); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn( + uploadPartResponseCompletableFuture + ); + + CompletableFuture completeMultipartUploadResponseCompletableFuture = new CompletableFuture<>(); + completeMultipartUploadResponseCompletableFuture.completeExceptionally( + S3Exception.builder().statusCode(412).message("Precondition Failed").build() + ); + when(s3AsyncClient.completeMultipartUpload(any(CompleteMultipartUploadRequest.class))).thenReturn( + completeMultipartUploadResponseCompletableFuture + ); + + CompletableFuture abortMultipartUploadResponseCompletableFuture = new CompletableFuture<>(); + 
abortMultipartUploadResponseCompletableFuture.complete(AbortMultipartUploadResponse.builder().build()); + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn( + abortMultipartUploadResponseCompletableFuture + ); + + ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch("non-matching-etag"); + + List streams = new ArrayList<>(); + CompletableFuture resultFuture = asyncTransferManager.uploadObjectConditionally( + s3AsyncClient, + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(5), + WritePriority.HIGH, + uploadSuccess -> {}, + true, + 3376132981L, + true, + null, + options, + ServerSideEncryption.UNKNOWN_TO_SDK_VERSION.toString(), + null, + false, + null, + null + ), + new StreamContext((partIdx, partSize, position) -> { + InputStream stream = new ZeroInputStream(partSize); + streams.add(stream); + return new InputStreamContainer(stream, partSize, position); + }, ByteSizeUnit.MB.toBytes(1), ByteSizeUnit.MB.toBytes(1), 5), + new StatsMetricPublisher() + ); + + try { + resultFuture.get(); + fail("Expected an exception for precondition failed"); + } catch (ExecutionException | InterruptedException e) { + Throwable cause = e.getCause(); + assertTrue("Should be S3Exception", cause instanceof S3Exception); + S3Exception s3e = (S3Exception) cause; + assertEquals("Should have 412 status code", 412, s3e.statusCode()); + assertTrue("Message should indicate condition failure", s3e.getMessage().contains("Conditional write failed")); + } + + verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3AsyncClient, times(5)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); + verify(s3AsyncClient, times(1)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + verify(s3AsyncClient, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + + for (InputStream stream : streams) { + boolean closeError = false; + try { 
+ stream.available(); + } catch (IOException e) { + closeError = e.getMessage().equals("Stream closed"); + } + assertTrue("InputStream was still open after upload", closeError); + } + } + + public void testConditionalMultipartUploadCorruption() { + CompletableFuture createMultipartUploadRequestCompletableFuture = new CompletableFuture<>(); + createMultipartUploadRequestCompletableFuture.complete(CreateMultipartUploadResponse.builder().uploadId("uploadId").build()); + when(s3AsyncClient.createMultipartUpload(any(CreateMultipartUploadRequest.class))).thenReturn( + createMultipartUploadRequestCompletableFuture + ); + + CompletableFuture uploadPartResponseCompletableFuture = new CompletableFuture<>(); + uploadPartResponseCompletableFuture.complete(UploadPartResponse.builder().checksumCRC32("pzjqHA==").build()); + when(s3AsyncClient.uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class))).thenReturn( + uploadPartResponseCompletableFuture + ); + + CompletableFuture abortMultipartUploadResponseCompletableFuture = new CompletableFuture<>(); + abortMultipartUploadResponseCompletableFuture.complete(AbortMultipartUploadResponse.builder().build()); + when(s3AsyncClient.abortMultipartUpload(any(AbortMultipartUploadRequest.class))).thenReturn( + abortMultipartUploadResponseCompletableFuture + ); + + ConditionalWriteOptions options = ConditionalWriteOptions.ifMatch("test-etag"); + + List streams = new ArrayList<>(); + CompletableFuture resultFuture = asyncTransferManager.uploadObjectConditionally( + s3AsyncClient, + new UploadRequest( + "bucket", + "key", + ByteSizeUnit.MB.toBytes(5), + WritePriority.HIGH, + uploadSuccess -> {}, + true, + 0L, + true, + null, + options, + ServerSideEncryption.UNKNOWN_TO_SDK_VERSION.toString(), + null, + false, + null, + null + ), + new StreamContext((partIdx, partSize, position) -> { + InputStream stream = new ZeroInputStream(partSize); + streams.add(stream); + return new InputStreamContainer(stream, partSize, position); + }, 
ByteSizeUnit.MB.toBytes(1), ByteSizeUnit.MB.toBytes(1), 5), + new StatsMetricPublisher() + ); + + try { + resultFuture.get(); + fail("Expected a corruption exception"); + } catch (ExecutionException | InterruptedException e) { + Throwable throwable = ExceptionsHelper.unwrap(e, CorruptFileException.class); + assertNotNull("Exception should be a CorruptFileException", throwable); + assertTrue("Exception should be a CorruptFileException", throwable instanceof CorruptFileException); + } + verify(s3AsyncClient, times(1)).createMultipartUpload(any(CreateMultipartUploadRequest.class)); + verify(s3AsyncClient, times(5)).uploadPart(any(UploadPartRequest.class), any(AsyncRequestBody.class)); + verify(s3AsyncClient, times(0)).completeMultipartUpload(any(CompleteMultipartUploadRequest.class)); + verify(s3AsyncClient, times(1)).abortMultipartUpload(any(AbortMultipartUploadRequest.class)); + + for (InputStream stream : streams) { + boolean closeError = false; + try { + stream.available(); + } catch (IOException e) { + closeError = e.getMessage().equals("Stream closed"); + } + assertTrue("InputStream was still open after upload", closeError); + } + } } diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java index b769cdc2fe7ab..08bdd90f617af 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamBlobContainer.java @@ -9,6 +9,8 @@ package org.opensearch.common.blobstore; import org.opensearch.common.annotation.ExperimentalApi; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteOptions; +import org.opensearch.common.blobstore.ConditionalWrite.ConditionalWriteResponse; import org.opensearch.common.blobstore.stream.read.ReadContext; import org.opensearch.common.blobstore.stream.write.WriteContext; import 
org.opensearch.core.action.ActionListener; @@ -35,6 +37,24 @@ public interface AsyncMultiStreamBlobContainer extends BlobContainer { */ void asyncBlobUpload(WriteContext writeContext, ActionListener completionListener) throws IOException; + /** + * Uploads blob content subject to the given write preconditions, from multiple streams, each covering a specific part of the file, which is provided by the + * StreamContextSupplier in the WriteContext passed to this method. An {@link IOException} is thrown if reading + * any of the input streams fails, or writing to the target blob fails. + * + * @param writeContext A WriteContext object encapsulating all information needed to perform the upload + * @param options The {@link ConditionalWriteOptions} specifying the preconditions that must be met for the upload to proceed. + * @param completionListener The {@link ActionListener} to which upload events and the result will be published. + * @throws IOException if any of the input streams could not be read, or the target blob could not be written to + */ + default void asyncBlobUploadConditionally( + WriteContext writeContext, + ConditionalWriteOptions options, + ActionListener completionListener + ) throws IOException { + throw new UnsupportedOperationException("asyncBlobUploadConditionally is not implemented yet"); + }; + /** + * Creates an async callback of a {@link ReadContext} containing the multipart streams for a specified blob within the container. + * @param blobName The name of the blob for which the {@link ReadContext} needs to be fetched. 
diff --git a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java index 286c01f9dca44..42019e49175ff 100644 --- a/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java +++ b/server/src/main/java/org/opensearch/common/blobstore/AsyncMultiStreamEncryptedBlobContainer.java @@ -44,6 +44,16 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener comp blobContainer.asyncBlobUpload(encryptedWriteContext, completionListener); } + @Override + public void asyncBlobUploadConditionally( + WriteContext writeContext, + ConditionalWrite.ConditionalWriteOptions options, + ActionListener completionListener + ) throws IOException { + EncryptedWriteContext encryptedWriteContext = new EncryptedWriteContext<>(writeContext, cryptoHandler); + blobContainer.asyncBlobUploadConditionally(encryptedWriteContext, options, completionListener); + } + @Override public void readBlobAsync(String blobName, ActionListener listener) { try { diff --git a/server/src/main/java/org/opensearch/common/blobstore/ConditionalWrite.java b/server/src/main/java/org/opensearch/common/blobstore/ConditionalWrite.java new file mode 100644 index 0000000000000..2865c85e45ac5 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/blobstore/ConditionalWrite.java @@ -0,0 +1,171 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore; + +import java.time.Instant; + +/** + * Utility classes supporting conditional write operations on a {@link BlobContainer}. 
+ * The main entry points are {@link ConditionalWriteOptions} for specifying conditions, and + * {@link ConditionalWriteResponse} for receiving the result of a conditional write. + */ +public final class ConditionalWrite { + private ConditionalWrite() {} + + /** + * Encapsulates options controlling preconditions to be deployed when a blob is to be written to the remote store. + * Immutable and thread-safe. Use the provided static factory methods or the {@link Builder} + * to construct instances with the desired conditions. These options can be supplied to + * blob store write operations to enforce preconditions + * + */ + public static final class ConditionalWriteOptions { + + private final boolean ifNotExists; + private final boolean ifMatch; + private final boolean ifUnmodifiedSince; + private final String versionIdentifier; + private final Instant unmodifiedSince; + + private ConditionalWriteOptions(Builder builder) { + this.ifNotExists = builder.ifNotExists; + this.ifMatch = builder.ifMatch; + this.ifUnmodifiedSince = builder.ifUnmodifiedSince; + this.versionIdentifier = builder.versionIdentifier; + this.unmodifiedSince = builder.unmodifiedSince; + } + + public static ConditionalWriteOptions none() { + return new Builder().build(); + } + + public static ConditionalWriteOptions ifNotExists() { + return new Builder().setIfNotExists(true).build(); + } + + public static ConditionalWriteOptions ifMatch(String versionIdentifier) { + return new Builder().setIfMatch(true).setVersionIdentifier(versionIdentifier).build(); + } + + public static ConditionalWriteOptions ifUnmodifiedSince(Instant ts) { + return new Builder().setIfUnmodifiedSince(true).setUnmodifiedSince(ts).build(); + } + + /** + * Returns a new {@link Builder} for constructing custom conditional write options. 
+ */ + public static Builder builder() { + return new Builder(); + } + + public boolean isIfNotExists() { + return ifNotExists; + } + + public boolean isIfMatch() { + return ifMatch; + } + + public boolean isIfUnmodifiedSince() { + return ifUnmodifiedSince; + } + + public String getVersionIdentifier() { + return versionIdentifier; + } + + public Instant getUnmodifiedSince() { + return unmodifiedSince; + } + + /** + * Builder for {@link ConditionalWriteOptions}. + * Allows fine-grained construction of conditional write criteria. + */ + public static final class Builder { + private boolean ifNotExists = false; + private boolean ifMatch = false; + private boolean ifUnmodifiedSince = false; + private String versionIdentifier = null; + private Instant unmodifiedSince = null; + + private Builder() {} + + /** + * Sets the write to succeed only if the blob does not exist. + * @param flag true to enable this condition + * @return this builder + */ + public Builder setIfNotExists(boolean flag) { + this.ifNotExists = flag; + return this; + } + + /** + * Sets the write to succeed only if the blob matches the expected version. + * @param flag true to enable this condition + * @return this builder + */ + public Builder setIfMatch(boolean flag) { + this.ifMatch = flag; + return this; + } + + /** + * Sets the write to succeed only if the blob was not modified since a given instant. + * @param flag true to enable this condition + * @return this builder + */ + public Builder setIfUnmodifiedSince(boolean flag) { + this.ifUnmodifiedSince = flag; + return this; + } + + /** + * Sets the timestamp before which the blob must remain unmodified. 
+ * @param ts the instant to check + * @return this builder + */ + public Builder setUnmodifiedSince(Instant ts) { + this.unmodifiedSince = ts; + return this; + } + + public Builder setVersionIdentifier(String versionIdentifier) { + this.versionIdentifier = versionIdentifier; + return this; + } + + public ConditionalWriteOptions build() { + return new ConditionalWriteOptions(this); + } + } + } + + /** + * Encapsulates the result of a conditional write operation. + * Contains the new version identifier (such as an ETag or version string) retrieved from the remote store + * after a successful write. + */ + public static final class ConditionalWriteResponse { + private final String newVersionIdentifier; + + private ConditionalWriteResponse(String versionIdentifier) { + this.newVersionIdentifier = versionIdentifier; + } + + public static ConditionalWriteResponse success(String versionIdentifier) { + return new ConditionalWriteResponse(versionIdentifier); + } + + public String getVersionIdentifier() { + return newVersionIdentifier; + } + } +}