diff --git a/src/LLL.DurableTask.Core/Serializing/HistoryEventConverter.cs b/src/LLL.DurableTask.Core/Serializing/HistoryEventConverter.cs index 26ed223..5f0c1cd 100644 --- a/src/LLL.DurableTask.Core/Serializing/HistoryEventConverter.cs +++ b/src/LLL.DurableTask.Core/Serializing/HistoryEventConverter.cs @@ -17,6 +17,7 @@ public class HistoryEventConverter : JsonConverter //{ EventType.ExecutionFailed, typeof(ExecutionFailedEvent) }, { EventType.ExecutionStarted, typeof(ExecutionStartedEvent) }, { EventType.ExecutionTerminated, typeof(ExecutionTerminatedEvent) }, + { EventType.ExecutionRewound, typeof(ExecutionRewoundEvent) }, { EventType.GenericEvent, typeof(GenericEvent) }, { EventType.HistoryState, typeof(HistoryStateEvent) }, { EventType.OrchestratorCompleted, typeof(OrchestratorCompletedEvent) }, @@ -43,15 +44,24 @@ public override object ReadJson(JsonReader reader, Type objectType, object exist { var jObject = JObject.Load(reader); - var eventTypeToken = jObject.GetValue("EventType", StringComparison.OrdinalIgnoreCase); + var eventType = jObject.GetValue("EventType", StringComparison.OrdinalIgnoreCase) + ?.ToObject() + ?? throw new Exception("Expected EventType field in HistoryEvent"); - if (eventTypeToken is null) - throw new Exception("Expected EventType field in HistoryEvent"); - - var eventType = eventTypeToken.ToObject(); + var eventId = jObject.GetValue("EventId", StringComparison.OrdinalIgnoreCase) + ?.ToObject() + ?? throw new Exception("Expected EventId field in HistoryEvent"); var type = _typesMap[eventType]; + if (type == typeof(ExecutionRewoundEvent)) + { + // Handles multiple constructors present in ExecutionRewoundEvent + var @event = new ExecutionRewoundEvent(eventId); + serializer.Populate(jObject.CreateReader(), @event); + return @event; + } + return jObject.ToObject(type, serializer); } diff --git a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationOptions.cs b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationOptions.cs index 8cec339..be2dd84 100644 --- a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationOptions.cs +++ b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationOptions.cs @@ -17,4 +17,5 @@ public class EFCoreOrchestrationOptions public TimeSpan ActivtyLockTimeout { get; set; } = TimeSpan.FromMinutes(1); public TimeSpan FetchNewMessagesPollingTimeout { get; set; } = TimeSpan.FromSeconds(10); public int DelayInSecondsAfterFailure { get; set; } = 5; + public bool UseDTFxRewind { get; set; } = true; } diff --git a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationServiceClient.cs b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationServiceClient.cs index 1612fd2..aa609d3 100644 --- a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationServiceClient.cs +++ b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationServiceClient.cs @@ -7,6 +7,7 @@ using DurableTask.Core.Exceptions; using DurableTask.Core.History; using DurableTask.Core.Query; +using DurableTask.Core.Tracing; using LLL.DurableTask.Core; using LLL.DurableTask.EFCore.Extensions; using LLL.DurableTask.EFCore.Mappers; @@ -240,9 +241,25 @@ public async Task GetOrchestrationWithQueryAsync(Orche public async Task RewindTaskOrchestrationAsync(string instanceId, string reason) { - using var dbContext = _dbContextFactory.CreateDbContext(); - await RewindInstanceAsync(dbContext, instanceId, reason, true, FindLastErrorOrCompletionRewindPoint); - await dbContext.SaveChangesAsync(); + if (_options.UseDTFxRewind) + { + var taskMessage = new TaskMessage + { + OrchestrationInstance = new OrchestrationInstance { InstanceId = instanceId }, + Event = new ExecutionRewoundEvent(-1, reason) + { + // Set a dummy trace context to avoid an exception in DTFx + ParentTraceContext = new DistributedTraceContext($"{instanceId}") + } + }; + await SendTaskOrchestrationMessageAsync(taskMessage); + } + else + { + using var dbContext = _dbContextFactory.CreateDbContext(); + await RewindInstanceAsync(dbContext, instanceId, reason, true, FindLastErrorOrCompletionRewindPoint); + await dbContext.SaveChangesAsync(); + } } private async Task RewindInstanceAsync(OrchestrationDbContext dbContext, string instanceId, string reason, bool rewindParents, Func, HistoryEvent> findRewindPoint) diff --git a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationSession.cs b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationSession.cs index 09c176d..c956b2c 100644 --- a/src/LLL.DurableTask.EFCore/EFCoreOrchestrationSession.cs +++ b/src/LLL.DurableTask.EFCore/EFCoreOrchestrationSession.cs @@ -70,54 +70,38 @@ public async Task> FetchNewMessagesAsync( .OrderBy(w => w.AvailableAt) .ThenBy(w => w.SequenceNumber) .AsNoTracking() - .ToArrayAsync(cancellationToken); - - var messagesToDiscard = newDbMessages - .Where(m => m.ExecutionId is not null && m.ExecutionId != Instance.LastExecutionId) - .ToArray(); - - if (messagesToDiscard.Length > 0) - { - foreach (var message in messagesToDiscard) - { - dbContext.OrchestrationMessages.Attach(message); - dbContext.OrchestrationMessages.Remove(message); - } - - newDbMessages = newDbMessages - .Except(messagesToDiscard) - .ToArray(); - } + .ToListAsync(cancellationToken); var deserializedMessages = newDbMessages .Select(w => _options.DataConverter.Deserialize(w.Message)) .ToList(); - if (RuntimeState.ExecutionStartedEvent is not null) + if (RuntimeState.ExecutionStartedEvent is not null + && RuntimeState.OrchestrationStatus is OrchestrationStatus.Completed + && deserializedMessages.Any(m => m.Event.EventType == EventType.EventRaised)) { - if (RuntimeState.OrchestrationStatus is OrchestrationStatus.Completed - && deserializedMessages.Any(m => m.Event.EventType == EventType.EventRaised)) - { - // Reopen completed orchestrations after receiving an event raised - RuntimeState = new OrchestrationRuntimeState( - RuntimeState.Events.Reopen(_options.DataConverter) - ); - } + // Reopen completed orchestrations after receiving an event raised + RuntimeState = new OrchestrationRuntimeState( + RuntimeState.Events.Reopen(_options.DataConverter) + ); + } + + var isRunning = RuntimeState.ExecutionStartedEvent is null + || RuntimeState.OrchestrationStatus is OrchestrationStatus.Running + or OrchestrationStatus.Suspended + or OrchestrationStatus.Pending; - var isRunning = RuntimeState.OrchestrationStatus is OrchestrationStatus.Running - or OrchestrationStatus.Suspended - or OrchestrationStatus.Pending; + for (var i = newDbMessages.Count - 1; i >= 0; i--) + { + var dbMessage = newDbMessages[i]; + var deserializedMessage = deserializedMessages[i]; - if (!isRunning) + if (ShouldDropNewMessage(isRunning, dbMessage, deserializedMessage)) { - // Discard all messages if not running - foreach (var message in newDbMessages) - { - dbContext.OrchestrationMessages.Attach(message); - dbContext.OrchestrationMessages.Remove(message); - } - newDbMessages = []; - deserializedMessages = []; + dbContext.OrchestrationMessages.Attach(dbMessage); + dbContext.OrchestrationMessages.Remove(dbMessage); + newDbMessages.RemoveAt(i); + deserializedMessages.RemoveAt(i); } } @@ -126,6 +110,22 @@ or OrchestrationStatus.Suspended return deserializedMessages; } + private bool ShouldDropNewMessage( + bool isRunning, + OrchestrationMessage dbMessage, + TaskMessage taskMessage) + { + // Drop messages to previous executions + if (dbMessage.ExecutionId is not null && dbMessage.ExecutionId != Instance.LastExecutionId) + return true; + + // When not running, drop anything that is not execution rewound + if (!isRunning && taskMessage.Event.EventType != EventType.ExecutionRewound) + return true; + + return false; + } + public void ClearMessages() { Messages.Clear();