Skip to content

Commit f335993

Browse files
feat: add cancellation token support
1 parent 9b4ef86 commit f335993

File tree

5 files changed

+48
-39
lines changed

5 files changed

+48
-39
lines changed

Sources/EventViewerX/SearchEvents.QueryLog.cs

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using System.Diagnostics.Eventing.Reader;
22
using System.IO;
33
using System.Net;
4+
using System.Threading;
45

56
namespace EventViewerX;
67

@@ -67,7 +68,7 @@ private static string GetFQDN() {
6768
/// <param name="namedDataFilter">Optional hashtable containing named data filters to include events.</param>
6869
/// <param name="namedDataExcludeFilter">Optional hashtable containing named data filters to exclude events.</param>
6970
/// <returns>An enumerable collection of EventObject instances representing the filtered events from the log file.</returns>
70-
public static IEnumerable<EventObject> QueryLogFile(string filePath, List<int> eventIds = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null, bool oldest = false, System.Collections.Hashtable namedDataFilter = null, System.Collections.Hashtable namedDataExcludeFilter = null) {
71+
public static IEnumerable<EventObject> QueryLogFile(string filePath, List<int> eventIds = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null, bool oldest = false, System.Collections.Hashtable namedDataFilter = null, System.Collections.Hashtable namedDataExcludeFilter = null, CancellationToken cancellationToken = default) {
7172

7273
string absolutePath = Path.GetFullPath(filePath);
7374

@@ -142,7 +143,7 @@ public static IEnumerable<EventObject> QueryLogFile(string filePath, List<int> e
142143
using (EventLogReader reader = CreateEventLogReader(query, filePath)) {
143144
if (reader != null) {
144145
int eventCount = 0;
145-
while ((record = reader.ReadEvent()) != null) {
146+
while (!cancellationToken.IsCancellationRequested && (record = reader.ReadEvent()) != null) {
146147
// using (record) {
147148
EventObject eventObject = new EventObject(record, filePath);
148149
yield return eventObject;
@@ -172,7 +173,7 @@ public static IEnumerable<EventObject> QueryLogFile(string filePath, List<int> e
172173
/// <param name="eventRecordId">The event record identifier.</param>
173174
/// <param name="timePeriod">The time period.</param>
174175
/// <returns></returns>
175-
public static IEnumerable<EventObject> QueryLog(string logName, List<int> eventIds = null, string machineName = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
176+
public static IEnumerable<EventObject> QueryLog(string logName, List<int> eventIds = null, string machineName = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null, CancellationToken cancellationToken = default) {
176177
if (eventIds != null && eventIds.Any(id => id <= 0)) {
177178
throw new ArgumentException("Event IDs must be positive.", nameof(eventIds));
178179
}
@@ -207,7 +208,7 @@ public static IEnumerable<EventObject> QueryLog(string logName, List<int> eventI
207208
using (EventLogReader reader = CreateEventLogReader(query, machineName)) {
208209
if (reader != null) {
209210
int eventCount = 0;
210-
while ((record = reader.ReadEvent()) != null) {
211+
while (!cancellationToken.IsCancellationRequested && (record = reader.ReadEvent()) != null) {
211212
// using (record) {
212213
EventObject eventObject = new EventObject(record, queriedMachine);
213214
yield return eventObject;
@@ -218,11 +219,11 @@ public static IEnumerable<EventObject> QueryLog(string logName, List<int> eventI
218219
// }
219220
}
220221
}
221-
}
222+
}
222223
}
223224

224-
public static IEnumerable<EventObject> QueryLog(KnownLog logName, List<int> eventIds = null, string machineName = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
225-
return QueryLog(LogNameToString(logName), eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId, timePeriod);
225+
public static IEnumerable<EventObject> QueryLog(KnownLog logName, List<int> eventIds = null, string machineName = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, List<long> eventRecordId = null, TimePeriod? timePeriod = null, CancellationToken cancellationToken = default) {
226+
return QueryLog(LogNameToString(logName), eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId, timePeriod, cancellationToken);
226227
}
227228

228229
/// <summary>
@@ -345,7 +346,7 @@ private static void AddCondition(StringBuilder queryString, string condition) {
345346
_ => logName.ToString()
346347
};
347348

348-
public static IEnumerable<EventObject> QueryLogsParallel(string logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 8, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
349+
public static IEnumerable<EventObject> QueryLogsParallel(string logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 8, List<long> eventRecordId = null, TimePeriod? timePeriod = null, CancellationToken cancellationToken = default) {
349350
if (machineNames == null || !machineNames.Any()) {
350351
machineNames = new List<string> { null };
351352
_logger.WriteVerbose("No machine names provided, querying the local machine.");
@@ -364,7 +365,7 @@ public static IEnumerable<EventObject> QueryLogsParallel(string logName, List<in
364365
.ToList();
365366

366367
foreach (var chunk in eventIdsChunks) {
367-
tasks.Add(CreateTask(machineName, logName, chunk, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, timePeriod: timePeriod));
368+
tasks.Add(CreateTask(machineName, logName, chunk, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, cancellationToken, timePeriod: timePeriod));
368369
}
369370
} else if (eventRecordId != null) {
370371
var eventRecordIdChunks = eventRecordId.Select((x, i) => new { Index = i, Value = x })
@@ -373,11 +374,11 @@ public static IEnumerable<EventObject> QueryLogsParallel(string logName, List<in
373374
.ToList();
374375

375376
foreach (var chunk in eventRecordIdChunks) {
376-
tasks.Add(CreateTask(machineName, logName, null, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, chunk, timePeriod: timePeriod));
377+
tasks.Add(CreateTask(machineName, logName, null, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, cancellationToken, chunk, timePeriod: timePeriod));
377378
}
378379
} else {
379380
// event ids are null, so we don't need to chunk them
380-
tasks.Add(CreateTask(machineName, logName, eventIds, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, timePeriod: timePeriod));
381+
tasks.Add(CreateTask(machineName, logName, eventIds, providerName, keywords, level, startTime, endTime, userId, maxEvents, semaphore, results, cancellationToken, timePeriod: timePeriod));
381382
}
382383
}
383384

@@ -386,30 +387,31 @@ public static IEnumerable<EventObject> QueryLogsParallel(string logName, List<in
386387
results.CompleteAdding();
387388
});
388389

389-
return results.GetConsumingEnumerable();
390+
return results.GetConsumingEnumerable(cancellationToken);
390391
}
391392

392-
public static IEnumerable<EventObject> QueryLogsParallel(KnownLog logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 8, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
393-
return QueryLogsParallel(LogNameToString(logName), eventIds, machineNames, providerName, keywords, level, startTime, endTime, userId, maxEvents, maxThreads, eventRecordId, timePeriod);
393+
public static IEnumerable<EventObject> QueryLogsParallel(KnownLog logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 8, List<long> eventRecordId = null, TimePeriod? timePeriod = null, CancellationToken cancellationToken = default) {
394+
return QueryLogsParallel(LogNameToString(logName), eventIds, machineNames, providerName, keywords, level, startTime, endTime, userId, maxEvents, maxThreads, eventRecordId, timePeriod, cancellationToken);
394395
}
395396

396-
private static Task CreateTask(string machineName, string logName, List<int> eventIds, string providerName, Keywords? keywords, Level? level, DateTime? startTime, DateTime? endTime, string userId, int maxEvents, SemaphoreSlim semaphore, BlockingCollection<EventObject> results, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
397+
private static Task CreateTask(string machineName, string logName, List<int> eventIds, string providerName, Keywords? keywords, Level? level, DateTime? startTime, DateTime? endTime, string userId, int maxEvents, SemaphoreSlim semaphore, BlockingCollection<EventObject> results, CancellationToken cancellationToken, List<long> eventRecordId = null, TimePeriod? timePeriod = null) {
397398
return Task.Run(async () => {
398399
_logger.WriteVerbose($"Querying log on machine: {machineName}, logName: {logName}, event ids: " + string.Join(", ", eventIds ?? new List<int>()));
399-
await semaphore.WaitAsync();
400+
await semaphore.WaitAsync(cancellationToken);
400401
try {
401-
var queryResults = QueryLog(logName, eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId, timePeriod);
402+
var queryResults = QueryLog(logName, eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId, timePeriod, cancellationToken);
402403
foreach (var result in queryResults) {
403-
results.Add(result);
404+
if (cancellationToken.IsCancellationRequested) break;
405+
results.Add(result, cancellationToken);
404406
}
405407
_logger.WriteVerbose("Querying log on machine: " + machineName + " completed.");
406408
} finally {
407409
semaphore.Release();
408410
}
409-
});
411+
}, cancellationToken);
410412
}
411413

412-
public static IEnumerable<EventObject> QueryLogsParallelForEach(string logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 4, List<long> eventRecordId = null) {
414+
public static IEnumerable<EventObject> QueryLogsParallelForEach(string logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 4, List<long> eventRecordId = null, CancellationToken cancellationToken = default) {
413415
if (machineNames == null || !machineNames.Any()) {
414416
throw new ArgumentException("At least one machine name must be provided", nameof(machineNames));
415417
}
@@ -420,19 +422,20 @@ public static IEnumerable<EventObject> QueryLogsParallelForEach(string logName,
420422
Task.Factory.StartNew(() => {
421423
Parallel.ForEach(machineNames, options, machineName => {
422424
_logger.WriteVerbose("Starting task for machine: " + machineName);
423-
var queryResults = QueryLog(logName, eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId);
425+
var queryResults = QueryLog(logName, eventIds, machineName, providerName, keywords, level, startTime, endTime, userId, maxEvents, eventRecordId, cancellationToken: cancellationToken);
424426
foreach (var result in queryResults) {
425-
results.Add(result);
427+
if (cancellationToken.IsCancellationRequested) break;
428+
results.Add(result, cancellationToken);
426429
}
427430
_logger.WriteVerbose("Finished task for machine: " + machineName);
428431
});
429432
results.CompleteAdding();
430-
});
433+
}, cancellationToken);
431434

432-
return results.GetConsumingEnumerable();
435+
return results.GetConsumingEnumerable(cancellationToken);
433436
}
434437

435-
public static IEnumerable<EventObject> QueryLogsParallelForEach(KnownLog logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 4, List<long> eventRecordId = null) {
436-
return QueryLogsParallelForEach(LogNameToString(logName), eventIds, machineNames, providerName, keywords, level, startTime, endTime, userId, maxEvents, maxThreads, eventRecordId);
438+
public static IEnumerable<EventObject> QueryLogsParallelForEach(KnownLog logName, List<int> eventIds = null, List<string> machineNames = null, string providerName = null, Keywords? keywords = null, Level? level = null, DateTime? startTime = null, DateTime? endTime = null, string userId = null, int maxEvents = 0, int maxThreads = 4, List<long> eventRecordId = null, CancellationToken cancellationToken = default) {
439+
return QueryLogsParallelForEach(LogNameToString(logName), eventIds, machineNames, providerName, keywords, level, startTime, endTime, userId, maxEvents, maxThreads, eventRecordId, cancellationToken);
437440
}
438441
}

Sources/EventViewerX/SearchEvents.QueryNamedEvents.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@
22
using System.Collections.Generic;
33
using System.Linq;
44

5+
using System.Threading;
56
namespace EventViewerX {
67
public partial class SearchEvents : Settings {
78

8-
public static IEnumerable<EventObjectSlim> FindEventsByNamedEvents(List<NamedEvents> typeEventsList, List<string> machineNames = null, DateTime? startTime = null, DateTime? endTime = null, TimePeriod? timePeriod = null, int maxThreads = 8, int maxEvents = 0) {
9+
public static IEnumerable<EventObjectSlim> FindEventsByNamedEvents(List<NamedEvents> typeEventsList, List<string> machineNames = null, DateTime? startTime = null, DateTime? endTime = null, TimePeriod? timePeriod = null, int maxThreads = 8, int maxEvents = 0, CancellationToken cancellationToken = default) {
910
// Create a dictionary to store unique event IDs and log names
1011
var eventInfoDict = new Dictionary<string, HashSet<int>>();
1112

@@ -31,7 +32,7 @@ public static IEnumerable<EventObjectSlim> FindEventsByNamedEvents(List<NamedEve
3132
var logName = kvp.Key;
3233
var eventIds = kvp.Value.ToList();
3334

34-
foreach (var foundEvent in SearchEvents.QueryLogsParallel(logName, eventIds, machineNames, startTime: startTime, endTime: endTime, timePeriod: timePeriod, maxThreads: maxThreads, maxEvents: maxEvents)) {
35+
foreach (var foundEvent in SearchEvents.QueryLogsParallel(logName, eventIds, machineNames, startTime: startTime, endTime: endTime, timePeriod: timePeriod, maxThreads: maxThreads, maxEvents: maxEvents, cancellationToken: cancellationToken)) {
3536
_logger.WriteDebug($"Found event: {foundEvent.Id} {foundEvent.LogName} {foundEvent.ComputerName}");
3637
// yield return BuildTargetEvents(foundEvent, typeEventsList);
3738
var targetEvent = BuildTargetEvents(foundEvent, typeEventsList);

Sources/EventViewerX/WatchEvents.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public WatchEvents(InternalLogger internalLogger = null) {
3131
}
3232
}
3333

34-
public void Watch(string machineName, string logName, List<int> eventId, Action<EventObject> eventAction = null) {
34+
public void Watch(string machineName, string logName, List<int> eventId, Action<EventObject> eventAction = null, CancellationToken cancellationToken = default) {
3535
_machineName = machineName;
3636
_watchEventId = new ConcurrentBag<int>(eventId);
3737
_eventAction = eventAction;
@@ -42,6 +42,11 @@ public void Watch(string machineName, string logName, List<int> eventId, Action<
4242
});
4343
eventLogWatcher.EventRecordWritten += DetectEventsLogCallback;
4444
eventLogWatcher.Enabled = true;
45+
cancellationToken.Register(() => {
46+
eventLogWatcher.EventRecordWritten -= DetectEventsLogCallback;
47+
eventLogWatcher.Enabled = false;
48+
eventLogWatcher.Dispose();
49+
});
4550
_logger.WriteVerbose("Created event log subscription to {0}.", machineName);
4651
} catch (Exception ex) {
4752
_logger.WriteWarning("Failed to create event log subscription to Target Machine {0}. Verify network connectivity, firewall settings, permissions, etc. Continuing on to next DC if applicable... ({1})", machineName, ex.Message.Trim());

0 commit comments

Comments
 (0)