Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// ------------------------------------------------------------
// ------------------------------------------------------------
// Copyright (c) Microsoft Corporation. All rights reserved.
// ------------------------------------------------------------

Expand All @@ -7,6 +7,7 @@ namespace Microsoft.Azure.Cosmos
using System;
using System.Collections.Generic;
using System.IO;
using System.Net;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Azure.Cosmos.Core.Trace;
Expand All @@ -15,15 +16,29 @@ namespace Microsoft.Azure.Cosmos

internal class DistributedTransactionCommitter
{
private const int MaxRetryAttempts = 3;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a timeout for each retry? is it same for all retries?

private static readonly TimeSpan DefaultRetryBaseDelay = TimeSpan.FromSeconds(1);
private static readonly string ResourceUri = Paths.OperationsPathSegment + "/" + Paths.Operations_Dtc;

private readonly IReadOnlyList<DistributedTransactionOperation> operations;
private readonly CosmosClientContext clientContext;
private readonly TimeSpan retryBaseDelay;

public DistributedTransactionCommitter(
IReadOnlyList<DistributedTransactionOperation> operations,
CosmosClientContext clientContext)
: this(operations, clientContext, DefaultRetryBaseDelay)
{
}

internal DistributedTransactionCommitter(
IReadOnlyList<DistributedTransactionOperation> operations,
CosmosClientContext clientContext,
TimeSpan retryBaseDelay)
{
this.operations = operations ?? throw new ArgumentNullException(nameof(operations));
this.clientContext = clientContext ?? throw new ArgumentNullException(nameof(clientContext));
this.retryBaseDelay = retryBaseDelay;
}

public async Task<DistributedTransactionResponse> CommitTransactionAsync(CancellationToken cancellationToken)
Expand All @@ -41,27 +56,69 @@ await DistributedTransactionCommitterUtils.ResolveCollectionRidsAsync(
this.clientContext.SerializerCore,
cancellationToken);

return await this.ExecuteCommitAsync(serverRequest, cancellationToken);
return await this.ExecuteCommitWithRetryAsync(serverRequest, cancellationToken);
}
catch (Exception ex)
catch (Exception ex) when (ex is not OperationCanceledException)
{
DefaultTrace.TraceError($"Distributed transaction failed: {ex.Message}");
// await this.AbortTransactionAsync(cancellationToken);
throw;
}
}

private async Task<DistributedTransactionResponse> ExecuteCommitWithRetryAsync(
DistributedTransactionServerRequest serverRequest,
CancellationToken cancellationToken)
{
for (int attempt = 0; attempt <= MaxRetryAttempts; attempt++)
{
cancellationToken.ThrowIfCancellationRequested();
bool canRetry = attempt < MaxRetryAttempts;

DistributedTransactionResponse response;
try
{
response = await this.ExecuteCommitAsync(serverRequest, cancellationToken);
}
catch (CosmosException cosmosEx) when (
!cancellationToken.IsCancellationRequested
&& canRetry
&& cosmosEx.StatusCode == HttpStatusCode.RequestTimeout)
{
DefaultTrace.TraceWarning(
$"Distributed transaction commit timed out (attempt {attempt + 1}/{MaxRetryAttempts + 1}). " +
$"Retrying with idempotency token {serverRequest.IdempotencyToken}.");
await Task.Delay(TimeSpan.FromTicks((long)(this.retryBaseDelay.Ticks * Math.Pow(2, attempt))), cancellationToken);
continue;
}

if (canRetry
&& !response.IsSuccessStatusCode
&& (response.IsRetriable || response.StatusCode == HttpStatusCode.RequestTimeout))
{
DefaultTrace.TraceWarning(
$"Distributed transaction commit retriable (StatusCode={response.StatusCode}, IsRetriable={response.IsRetriable}, " +
$"attempt {attempt + 1}/{MaxRetryAttempts + 1}). Retrying with idempotency token {serverRequest.IdempotencyToken}.");
response.Dispose();
await Task.Delay(TimeSpan.FromTicks((long)(this.retryBaseDelay.Ticks * Math.Pow(2, attempt))), cancellationToken);
continue;
}

return response;
}
throw new InvalidOperationException("Unexpected state: retry loop exhausted without returning.");
}

private async Task<DistributedTransactionResponse> ExecuteCommitAsync(
DistributedTransactionServerRequest serverRequest,
CancellationToken cancellationToken)
{
cancellationToken.ThrowIfCancellationRequested();
using (ITrace trace = Trace.GetRootTrace("Execute Distributed Transaction Commit", TraceComponent.Batch, TraceLevel.Info))
{
using (MemoryStream bodyStream = serverRequest.TransferBodyStream())
using (MemoryStream bodyStream = serverRequest.CreateBodyStream())
{
ResponseMessage responseMessage = await this.clientContext.ProcessResourceOperationStreamAsync(
resourceUri: DistributedTransactionCommitter.GetResourceUri(),
resourceUri: DistributedTransactionCommitter.ResourceUri,
resourceType: ResourceType.DistributedTransactionBatch,
operationType: OperationType.CommitDistributedTransaction,
requestOptions: null,
Expand All @@ -73,24 +130,19 @@ private async Task<DistributedTransactionResponse> ExecuteCommitAsync(
trace: trace,
cancellationToken: cancellationToken);

cancellationToken.ThrowIfCancellationRequested();

return await DistributedTransactionResponse.FromResponseMessageAsync(
responseMessage,
serverRequest,
this.clientContext.SerializerCore,
serverRequest.IdempotencyToken,
trace,
cancellationToken);
using (responseMessage)
{
return await DistributedTransactionResponse.FromResponseMessageAsync(
responseMessage,
serverRequest,
this.clientContext.SerializerCore,
trace,
cancellationToken);
}
}
}
}

private static string GetResourceUri()
{
return Paths.OperationsPathSegment + "/" + Paths.Operations_Dtc;
}

private static void EnrichRequestMessage(RequestMessage requestMessage, DistributedTransactionServerRequest serverRequest)
{
// Set DTC-specific headers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ internal string ResourceBodyBase64
/// <returns>The deserialized operation result.</returns>
internal static DistributedTransactionOperationResult FromJson(JsonElement json)
{
return JsonSerializer.Deserialize<DistributedTransactionOperationResult>(json);
return JsonSerializer.Deserialize<DistributedTransactionOperationResult>(json)
?? throw new JsonException($"Failed to deserialize {nameof(DistributedTransactionOperationResult)}.");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ private DistributedTransactionResponse(
CosmosSerializerCore serializer,
ITrace trace,
Guid idempotencyToken,
string serverDiagnostics = null)
bool isRetriable = false)
{
this.Headers = headers;
this.StatusCode = statusCode;
Expand All @@ -47,7 +47,7 @@ private DistributedTransactionResponse(
this.SerializerCore = serializer;
this.Trace = trace;
this.IdempotencyToken = idempotencyToken;
this.ServerDiagnostics = serverDiagnostics;
this.IsRetriable = isRetriable;
}

/// <summary>
Expand Down Expand Up @@ -110,7 +110,14 @@ public virtual DistributedTransactionOperationResult this[int index]
/// <summary>
/// Gets the number of operation results in the distributed transaction response.
/// </summary>
public virtual int Count => this.results?.Count ?? 0;
public virtual int Count
{
get
{
this.ThrowIfDisposed();
return this.results?.Count ?? 0;
}
}

/// <summary>
/// Gets the idempotency token associated with this distributed transaction.
Expand All @@ -122,6 +129,11 @@ public virtual DistributedTransactionOperationResult this[int index]
/// </summary>
public virtual string ServerDiagnostics { get; }

/// <summary>
/// Gets a value indicating whether the transaction is safe to retry with the same idempotency token.
/// </summary>
public virtual bool IsRetriable { get; }

internal virtual SubStatusCodes SubStatusCode { get; }

internal virtual CosmosSerializerCore SerializerCore { get; }
Expand All @@ -136,6 +148,7 @@ public virtual DistributedTransactionOperationResult this[int index]
/// <returns>An enumerator for the operation results.</returns>
public virtual IEnumerator<DistributedTransactionOperationResult> GetEnumerator()
{
this.ThrowIfDisposed();
return this.results?.GetEnumerator()
?? ((IList<DistributedTransactionOperationResult>)Array.Empty<DistributedTransactionOperationResult>()).GetEnumerator();
}
Expand Down Expand Up @@ -165,7 +178,6 @@ internal static async Task<DistributedTransactionResponse> FromResponseMessageAs
ResponseMessage responseMessage,
DistributedTransactionServerRequest serverRequest,
CosmosSerializerCore serializer,
Guid requestIdempotencyToken,
ITrace trace,
CancellationToken cancellationToken)
{
Expand All @@ -174,7 +186,7 @@ internal static async Task<DistributedTransactionResponse> FromResponseMessageAs
cancellationToken.ThrowIfCancellationRequested();

// Extract idempotency token from response headers, fallback to request token if not present
Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, requestIdempotencyToken);
Guid idempotencyToken = GetIdempotencyTokenFromHeaders(responseMessage.Headers, serverRequest.IdempotencyToken);

DistributedTransactionResponse response = null;
MemoryStream memoryStream = null;
Expand Down Expand Up @@ -297,16 +309,37 @@ private static async Task<DistributedTransactionResponse> PopulateFromJsonConten
CancellationToken cancellationToken)
{
List<DistributedTransactionOperationResult> results = new List<DistributedTransactionOperationResult>();
bool isRetriable = false;

// Scope the JsonException catch to document parse only so that isRetriable and
// serverDiagnostics already extracted from the root are not silently discarded
// when only the operationResponses array fails to deserialize.
JsonDocument responseJson;
try
{
using (JsonDocument responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken))
responseJson = await JsonDocument.ParseAsync(content, cancellationToken: cancellationToken);
}
catch (JsonException)
{
// Unparseable body — fall back to default response construction.
return null;
}

using (responseJson)
{
JsonElement root = responseJson.RootElement;

if (root.TryGetProperty("isRetriable", out JsonElement isRetriableElement) &&
isRetriableElement.ValueKind == JsonValueKind.True)
{
JsonElement root = responseJson.RootElement;
isRetriable = true;
}

// Parse operation results from "operationResponses" array
if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) &&
operationResponses.ValueKind == JsonValueKind.Array)
// Parse operation results from "operationResponses" array.
if (root.TryGetProperty("operationResponses", out JsonElement operationResponses) &&
operationResponses.ValueKind == JsonValueKind.Array)
{
try
{
foreach (JsonElement operationElement in operationResponses.EnumerateArray())
{
Expand All @@ -319,13 +352,12 @@ private static async Task<DistributedTransactionResponse> PopulateFromJsonConten
results.Add(operationResult);
}
}
catch (JsonException)
{
results.Clear();
}
}
}
catch (JsonException)
{
// If JSON parsing fails, return null to fall back to default response
return null;
}

HttpStatusCode finalStatusCode = responseMessage.StatusCode;
SubStatusCodes finalSubStatusCode = responseMessage.Headers.SubStatusCode;
Expand Down Expand Up @@ -353,7 +385,8 @@ private static async Task<DistributedTransactionResponse> PopulateFromJsonConten
serverRequest.Operations,
serializer,
trace,
idempotencyToken)
idempotencyToken,
isRetriable)
{
results = results
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace Microsoft.Azure.Cosmos
internal class DistributedTransactionServerRequest
{
private readonly CosmosSerializerCore serializerCore;
private MemoryStream bodyStream;
private byte[] serializedBody;

private DistributedTransactionServerRequest(
IReadOnlyList<DistributedTransactionOperation> operations,
Expand All @@ -26,7 +26,7 @@ private DistributedTransactionServerRequest(

public IReadOnlyList<DistributedTransactionOperation> Operations { get; }

public Guid IdempotencyToken { get; private set; }
public Guid IdempotencyToken { get; }

public static async Task<DistributedTransactionServerRequest> CreateAsync(
IReadOnlyList<DistributedTransactionOperation> operations,
Expand All @@ -38,11 +38,21 @@ public static async Task<DistributedTransactionServerRequest> CreateAsync(
return request;
}

public MemoryStream TransferBodyStream()
/// <summary>
/// Returns a new <see cref="MemoryStream"/> backed by the pre-serialized request bytes.
/// Each call returns an independent, non-writable stream positioned at offset zero so
/// that the caller can safely wrap it in a <c>using</c> block and dispose it without
/// affecting subsequent retry attempts.
/// </summary>
/// <returns>Body stream.</returns>
public MemoryStream CreateBodyStream()
{
MemoryStream bodyStream = this.bodyStream;
this.bodyStream = null;
return bodyStream;
if (this.serializedBody == null)
{
throw new InvalidOperationException("Request body has not been initialized. Use CreateAsync to construct a request.");
}

return new MemoryStream(this.serializedBody, writable: false);
}

private async Task CreateBodyStreamAsync(CancellationToken cancellationToken)
Expand All @@ -53,7 +63,10 @@ private async Task CreateBodyStreamAsync(CancellationToken cancellationToken)
operation.PartitionKeyJson ??= operation.PartitionKey.ToJsonString();
}

this.bodyStream = DistributedTransactionSerializer.SerializeRequest(this.Operations);
using (MemoryStream stream = DistributedTransactionSerializer.SerializeRequest(this.Operations))
{
this.serializedBody = stream.ToArray();
}
}
}
}
Loading
Loading