Skip to content

Commit dbbfcee

Browse files
Copilotcrickman
andcommitted
Cache derived properties in MessageIndex to avoid repeated group traversals
Co-authored-by: crickman <66376200+crickman@users.noreply.github.com>
1 parent 24001ef commit dbbfcee

File tree

3 files changed

+290
-19
lines changed

3 files changed

+290
-19
lines changed

dotnet/src/Microsoft.Agents.AI/Compaction/CompactionMessageGroup.cs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
// Copyright (c) Microsoft. All rights reserved.
22

3+
using System;
34
using System.Collections.Generic;
45
using System.Diagnostics.CodeAnalysis;
56
using System.Text.Json.Serialization;
@@ -100,14 +101,33 @@ internal CompactionMessageGroup(CompactionGroupKind kind, IReadOnlyList<ChatMess
100101
/// </remarks>
101102
public int? TurnIndex { get; }
102103

104+
private bool _isExcluded;
105+
106+
/// <summary>
107+
/// An optional callback invoked when <see cref="IsExcluded"/> changes value.
108+
/// Used internally by <see cref="CompactionMessageIndex"/> to invalidate cached aggregates.
109+
/// </summary>
110+
internal Action? ExclusionChanged;
111+
103112
/// <summary>
104113
/// Gets or sets a value indicating whether this group is excluded from the projected message list.
105114
/// </summary>
106115
/// <remarks>
107116
/// Excluded groups are preserved in the collection for diagnostics or storage purposes
108117
/// but are not included when calling <see cref="CompactionMessageIndex.GetIncludedMessages"/>.
109118
/// </remarks>
110-
public bool IsExcluded { get; set; }
119+
public bool IsExcluded
120+
{
121+
get => _isExcluded;
122+
set
123+
{
124+
if (_isExcluded != value)
125+
{
126+
_isExcluded = value;
127+
ExclusionChanged?.Invoke();
128+
}
129+
}
130+
}
111131

112132
/// <summary>
113133
/// Gets or sets an optional reason explaining why this group was excluded.

dotnet/src/Microsoft.Agents.AI/Compaction/CompactionMessageIndex.cs

Lines changed: 76 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,20 @@ public sealed class CompactionMessageIndex
2727
private int _currentTurn;
2828
private ChatMessage? _lastProcessedMessage;
2929

30+
// Cached values for derived properties — invalidated whenever groups are added/removed
31+
// or a group's IsExcluded state changes.
32+
private int? _cachedTotalMessageCount;
33+
private int? _cachedTotalByteCount;
34+
private int? _cachedTotalTokenCount;
35+
private int? _cachedIncludedGroupCount;
36+
private int? _cachedIncludedMessageCount;
37+
private int? _cachedIncludedByteCount;
38+
private int? _cachedIncludedTokenCount;
39+
private int? _cachedTotalTurnCount;
40+
private int? _cachedIncludedTurnCount;
41+
private int? _cachedIncludedNonSystemGroupCount;
42+
private int? _cachedRawMessageCount;
43+
3044
/// <summary>
3145
/// Gets the list of message groups in this collection.
3246
/// </summary>
@@ -47,6 +61,12 @@ public CompactionMessageIndex(IList<CompactionMessageGroup> groups, Tokenizer? t
4761
this.Groups = Throw.IfNull(groups, nameof(groups));
4862
this.Tokenizer = tokenizer;
4963

64+
// Register all pre-existing groups so that IsExcluded changes invalidate the cache.
65+
for (int i = 0; i < groups.Count; i++)
66+
{
67+
this.RegisterGroup(groups[i]);
68+
}
69+
5070
// Restore turn counter and last processed message from the groups
5171
for (int index = groups.Count - 1; index >= 0; --index)
5272
{
@@ -123,6 +143,7 @@ internal void Update(IList<ChatMessage> allMessages)
123143
this.Groups.Clear();
124144
this._currentTurn = 0;
125145
this._lastProcessedMessage = null;
146+
this.InvalidateCache();
126147
return;
127148
}
128149

@@ -184,13 +205,13 @@ private void AppendFromMessages(IList<ChatMessage> messages, int startIndex)
184205
if (message.Role == ChatRole.System)
185206
{
186207
// System messages are not part of any turn
187-
this.Groups.Add(CreateGroup(CompactionGroupKind.System, [message], this.Tokenizer, turnIndex: null));
208+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.System, [message], this.Tokenizer, turnIndex: null));
188209
index++;
189210
}
190211
else if (message.Role == ChatRole.User)
191212
{
192213
this._currentTurn++;
193-
this.Groups.Add(CreateGroup(CompactionGroupKind.User, [message], this.Tokenizer, this._currentTurn));
214+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.User, [message], this.Tokenizer, this._currentTurn));
194215
index++;
195216
}
196217
else if (message.Role == ChatRole.Assistant && HasToolCalls(message))
@@ -207,11 +228,11 @@ private void AppendFromMessages(IList<ChatMessage> messages, int startIndex)
207228
index++;
208229
}
209230

210-
this.Groups.Add(CreateGroup(CompactionGroupKind.ToolCall, groupMessages, this.Tokenizer, this._currentTurn));
231+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.ToolCall, groupMessages, this.Tokenizer, this._currentTurn));
211232
}
212233
else if (message.Role == ChatRole.Assistant && IsSummaryMessage(message))
213234
{
214-
this.Groups.Add(CreateGroup(CompactionGroupKind.Summary, [message], this.Tokenizer, this._currentTurn));
235+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.Summary, [message], this.Tokenizer, this._currentTurn));
215236
index++;
216237
}
217238
else if (message.Role == ChatRole.Assistant && HasOnlyReasoning(message))
@@ -247,17 +268,17 @@ private void AppendFromMessages(IList<ChatMessage> messages, int startIndex)
247268
index++;
248269
}
249270

250-
this.Groups.Add(CreateGroup(CompactionGroupKind.ToolCall, groupMessages, this.Tokenizer, this._currentTurn));
271+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.ToolCall, groupMessages, this.Tokenizer, this._currentTurn));
251272
}
252273
else
253274
{
254-
this.Groups.Add(CreateGroup(CompactionGroupKind.AssistantText, [message], this.Tokenizer, this._currentTurn));
275+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.AssistantText, [message], this.Tokenizer, this._currentTurn));
255276
index++;
256277
}
257278
}
258279
else
259280
{
260-
this.Groups.Add(CreateGroup(CompactionGroupKind.AssistantText, [message], this.Tokenizer, this._currentTurn));
281+
this.AddAndRegisterGroup(CreateGroup(CompactionGroupKind.AssistantText, [message], this.Tokenizer, this._currentTurn));
261282
index++;
262283
}
263284
}
@@ -266,6 +287,8 @@ private void AppendFromMessages(IList<ChatMessage> messages, int startIndex)
266287
{
267288
this._lastProcessedMessage = messages[^1];
268289
}
290+
291+
this.InvalidateCache();
269292
}
270293

271294
/// <summary>
@@ -281,6 +304,8 @@ public CompactionMessageGroup InsertGroup(int index, CompactionGroupKind kind, I
281304
{
282305
CompactionMessageGroup group = CreateGroup(kind, messages, this.Tokenizer, turnIndex);
283306
this.Groups.Insert(index, group);
307+
this.RegisterGroup(group);
308+
this.InvalidateCache();
284309
return group;
285310
}
286311

@@ -296,6 +321,8 @@ public CompactionMessageGroup AddGroup(CompactionGroupKind kind, IReadOnlyList<C
296321
{
297322
CompactionMessageGroup group = CreateGroup(kind, messages, this.Tokenizer, turnIndex);
298323
this.Groups.Add(group);
324+
this.RegisterGroup(group);
325+
this.InvalidateCache();
299326
return group;
300327
}
301328

@@ -320,57 +347,57 @@ public IEnumerable<ChatMessage> GetIncludedMessages() =>
320347
/// <summary>
321348
/// Gets the total number of messages across all groups, including excluded ones.
322349
/// </summary>
323-
public int TotalMessageCount => this.Groups.Sum(group => group.MessageCount);
350+
public int TotalMessageCount => _cachedTotalMessageCount ??= this.Groups.Sum(group => group.MessageCount);
324351

325352
/// <summary>
326353
/// Gets the total UTF-8 byte count across all groups, including excluded ones.
327354
/// </summary>
328-
public int TotalByteCount => this.Groups.Sum(group => group.ByteCount);
355+
public int TotalByteCount => _cachedTotalByteCount ??= this.Groups.Sum(group => group.ByteCount);
329356

330357
/// <summary>
331358
/// Gets the total token count across all groups, including excluded ones.
332359
/// </summary>
333-
public int TotalTokenCount => this.Groups.Sum(group => group.TokenCount);
360+
public int TotalTokenCount => _cachedTotalTokenCount ??= this.Groups.Sum(group => group.TokenCount);
334361

335362
/// <summary>
336363
/// Gets the total number of groups that are not excluded.
337364
/// </summary>
338-
public int IncludedGroupCount => this.Groups.Count(group => !group.IsExcluded);
365+
public int IncludedGroupCount => _cachedIncludedGroupCount ??= this.Groups.Count(group => !group.IsExcluded);
339366

340367
/// <summary>
341368
/// Gets the total number of messages across all included (non-excluded) groups.
342369
/// </summary>
343-
public int IncludedMessageCount => this.Groups.Where(group => !group.IsExcluded).Sum(group => group.MessageCount);
370+
public int IncludedMessageCount => _cachedIncludedMessageCount ??= this.Groups.Where(group => !group.IsExcluded).Sum(group => group.MessageCount);
344371

345372
/// <summary>
346373
/// Gets the total UTF-8 byte count across all included (non-excluded) groups.
347374
/// </summary>
348-
public int IncludedByteCount => this.Groups.Where(group => !group.IsExcluded).Sum(group => group.ByteCount);
375+
public int IncludedByteCount => _cachedIncludedByteCount ??= this.Groups.Where(group => !group.IsExcluded).Sum(group => group.ByteCount);
349376

350377
/// <summary>
351378
/// Gets the total token count across all included (non-excluded) groups.
352379
/// </summary>
353-
public int IncludedTokenCount => this.Groups.Where(group => !group.IsExcluded).Sum(group => group.TokenCount);
380+
public int IncludedTokenCount => _cachedIncludedTokenCount ??= this.Groups.Where(group => !group.IsExcluded).Sum(group => group.TokenCount);
354381

355382
/// <summary>
356383
/// Gets the total number of user turns across all groups (including those with excluded groups).
357384
/// </summary>
358-
public int TotalTurnCount => this.Groups.Select(group => group.TurnIndex).Distinct().Count(turnIndex => turnIndex is not null && turnIndex > 0);
385+
public int TotalTurnCount => _cachedTotalTurnCount ??= this.Groups.Select(group => group.TurnIndex).Distinct().Count(turnIndex => turnIndex is not null && turnIndex > 0);
359386

360387
/// <summary>
361388
/// Gets the number of user turns that have at least one non-excluded group.
362389
/// </summary>
363-
public int IncludedTurnCount => this.Groups.Where(group => !group.IsExcluded && group.TurnIndex is not null && group.TurnIndex > 0).Select(group => group.TurnIndex).Distinct().Count();
390+
public int IncludedTurnCount => _cachedIncludedTurnCount ??= this.Groups.Where(group => !group.IsExcluded && group.TurnIndex is not null && group.TurnIndex > 0).Select(group => group.TurnIndex).Distinct().Count();
364391

365392
/// <summary>
366393
/// Gets the total number of groups across all included (non-excluded) groups that are not <see cref="CompactionGroupKind.System"/>.
367394
/// </summary>
368-
public int IncludedNonSystemGroupCount => this.Groups.Count(group => !group.IsExcluded && group.Kind != CompactionGroupKind.System);
395+
public int IncludedNonSystemGroupCount => _cachedIncludedNonSystemGroupCount ??= this.Groups.Count(group => !group.IsExcluded && group.Kind != CompactionGroupKind.System);
369396

370397
/// <summary>
371398
/// Gets the total number of original messages (that are not summaries).
372399
/// </summary>
373-
public int RawMessageCount => this.Groups.Where(group => group.Kind != CompactionGroupKind.Summary).Sum(group => group.MessageCount);
400+
public int RawMessageCount => _cachedRawMessageCount ??= this.Groups.Where(group => group.Kind != CompactionGroupKind.Summary).Sum(group => group.MessageCount);
374401

375402
/// <summary>
376403
/// Returns all groups that belong to the specified user turn.
@@ -379,6 +406,37 @@ public IEnumerable<ChatMessage> GetIncludedMessages() =>
379406
/// <returns>The groups belonging to the turn, in order.</returns>
380407
public IEnumerable<CompactionMessageGroup> GetTurnGroups(int turnIndex) => this.Groups.Where(group => group.TurnIndex == turnIndex);
381408

409+
private void InvalidateCache()
410+
{
411+
_cachedTotalMessageCount = null;
412+
_cachedTotalByteCount = null;
413+
_cachedTotalTokenCount = null;
414+
_cachedIncludedGroupCount = null;
415+
_cachedIncludedMessageCount = null;
416+
_cachedIncludedByteCount = null;
417+
_cachedIncludedTokenCount = null;
418+
_cachedTotalTurnCount = null;
419+
_cachedIncludedTurnCount = null;
420+
_cachedIncludedNonSystemGroupCount = null;
421+
_cachedRawMessageCount = null;
422+
}
423+
424+
private void RegisterGroup(CompactionMessageGroup group)
425+
{
426+
// Each group is owned by exactly one index, so assignment rather than
427+
// += is intentional — no need to chain callbacks.
428+
group.ExclusionChanged = this.InvalidateCache;
429+
}
430+
431+
// Adds the group to the list and registers it for cache invalidation.
432+
// Callers that add many groups in a loop (e.g. AppendFromMessages) call
433+
// InvalidateCache() once at the end rather than per-group for efficiency.
434+
private void AddAndRegisterGroup(CompactionMessageGroup group)
435+
{
436+
this.Groups.Add(group);
437+
this.RegisterGroup(group);
438+
}
439+
382440
/// <summary>
383441
/// Computes the UTF-8 byte count for a set of messages across all content types.
384442
/// </summary>

0 commit comments

Comments
 (0)