Skip to content

Commit 829f8ff

Browse files
committed
Use RawRepresentation for more faithful native multi-turn
Rather than converting back and forth when the same messages we returned are passed back to continue a conversation, we leverage the message's RawRepresentation to bring back directly the original gRPC message. This also requires doing a more 1:1 conversion of completion messages into individual response messages, rather than having a single response message with multiple contents. This would align better with the official xAI docs on response too.
1 parent f977db8 commit 829f8ff

2 files changed

Lines changed: 104 additions & 47 deletions

File tree

src/xAI/GrokChatClient.cs

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,11 @@ public async Task<ChatResponse> GetResponseAsync(IEnumerable<ChatMessage> messag
4545
ResponseId = response.Id,
4646
ModelId = response.Model,
4747
CreatedAt = response.Created?.ToDateTimeOffset(),
48-
FinishReason = lastOutput != null ? MapFinishReason(lastOutput.FinishReason) : null,
49-
Usage = MapToUsage(response.Usage),
48+
FinishReason = lastOutput != null ? lastOutput.FinishReason.Convert() : null,
49+
Usage = response.Usage.Convert(),
5050
};
5151

52-
var citations = response.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();
52+
var citations = response.Citations?.Distinct().Select(x => x.FromCitationUrl()).ToList<AIAnnotation>();
5353

5454
((List<ChatMessage>)result.Messages).AddRange(response.Outputs.AsChatMessages(citations));
5555

@@ -73,12 +73,12 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
7373
// Use positional arguments for ChatResponseUpdate
7474
var update = new ChatResponseUpdate
7575
{
76-
Role = MapRole(output.Delta.Role),
76+
Role = output.Delta.Role.Convert(),
7777
ResponseId = chunk.Id,
7878
ModelId = chunk.Model,
7979
CreatedAt = chunk.Created?.ToDateTimeOffset(),
8080
RawRepresentation = chunk,
81-
FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? MapFinishReason(output.FinishReason) : null,
81+
FinishReason = output.FinishReason != FinishReason.ReasonInvalid ? output.FinishReason.Convert() : null,
8282
};
8383

8484
var citations = chunk.Citations?.Distinct().Select(MapCitation).ToList<AIAnnotation>();
@@ -101,7 +101,7 @@ async IAsyncEnumerable<ChatResponseUpdate> CompleteChatStreamingCore(IEnumerable
101101
text is not null)
102102
update.Contents.Add(new TextContent(text));
103103

104-
if (MapToUsage(chunk.Usage) is { } usage)
104+
if (chunk.Usage.Convert() is { } usage)
105105
update.Contents.Add(new UsageContent(usage) { RawRepresentation = chunk.Usage });
106106

107107
yield return update;
@@ -149,10 +149,27 @@ GetCompletionsRequest MapToRequest(IEnumerable<ChatMessage> messages, ChatOption
149149

150150
foreach (var message in messages)
151151
{
152-
var gmsg = new Message { Role = MapRole(message.Role) };
152+
if (message.RawRepresentation is Message input)
153+
{
154+
request.Messages.Add(input);
155+
continue;
156+
}
157+
else if (message.RawRepresentation is CompletionMessage completion)
158+
{
159+
request.Messages.Add(completion.AsMessage());
160+
continue;
161+
}
162+
163+
var gmsg = new Message { Role = message.Role.Convert() };
153164

154165
foreach (var content in message.Contents)
155166
{
167+
if (content.RawRepresentation is CompletionMessage completion)
168+
{
169+
request.Messages.Add(completion.AsMessage());
170+
continue;
171+
}
172+
156173
if (content is TextContent textContent && !string.IsNullOrEmpty(textContent.Text))
157174
{
158175
gmsg.Content.Add(new Content { Text = textContent.Text });
@@ -271,41 +288,6 @@ codeResult.RawRepresentation is ToolCall codeToolCall &&
271288
return request;
272289
}
273290

274-
static MessageRole MapRole(ChatRole role) => role switch
275-
{
276-
_ when role == ChatRole.System => MessageRole.RoleSystem,
277-
_ when role == ChatRole.User => MessageRole.RoleUser,
278-
_ when role == ChatRole.Assistant => MessageRole.RoleAssistant,
279-
_ when role == ChatRole.Tool => MessageRole.RoleTool,
280-
_ => MessageRole.RoleUser
281-
};
282-
283-
static ChatRole MapRole(MessageRole role) => role switch
284-
{
285-
MessageRole.RoleSystem => ChatRole.System,
286-
MessageRole.RoleUser => ChatRole.User,
287-
MessageRole.RoleAssistant => ChatRole.Assistant,
288-
MessageRole.RoleTool => ChatRole.Tool,
289-
_ => ChatRole.Assistant
290-
};
291-
292-
static ChatFinishReason? MapFinishReason(FinishReason finishReason) => finishReason switch
293-
{
294-
FinishReason.ReasonStop => ChatFinishReason.Stop,
295-
FinishReason.ReasonMaxLen => ChatFinishReason.Length,
296-
FinishReason.ReasonToolCalls => ChatFinishReason.ToolCalls,
297-
FinishReason.ReasonMaxContext => ChatFinishReason.Length,
298-
FinishReason.ReasonTimeLimit => ChatFinishReason.Length,
299-
_ => null
300-
};
301-
302-
static UsageDetails? MapToUsage(SamplingUsage usage) => usage == null ? null : new()
303-
{
304-
InputTokenCount = usage.PromptTokens,
305-
OutputTokenCount = usage.CompletionTokens,
306-
TotalTokenCount = usage.TotalTokens
307-
};
308-
309291
/// <inheritdoc />
310292
public object? GetService(Type serviceType, object? serviceKey = null) => serviceType switch
311293
{

src/xAI/GrokProtocolExtensions.cs

Lines changed: 80 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -161,11 +161,13 @@ grokSearch.City is not null ||
161161

162162
static IEnumerable<ChatMessage> ToChatMessages(IEnumerable<CompletionMessage> messages, List<AIAnnotation>? citations = default)
163163
{
164-
ChatMessage? message = null;
165-
166164
foreach (var completion in messages)
167165
{
168-
message ??= new(ChatRole.Assistant, (string?)null);
166+
ChatMessage message = new(ChatRole.Assistant, (string?)null)
167+
{
168+
RawRepresentation = completion
169+
};
170+
169171
var annotations = citations;
170172
if (completion.Citations.Count > 0)
171173
{
@@ -214,11 +216,10 @@ static IEnumerable<ChatMessage> ToChatMessages(IEnumerable<CompletionMessage> me
214216
// RawRepresentation = completion
215217
// });
216218
//}
217-
}
218219

219-
if (message is not null)
220220
yield return message;
221221
}
222+
}
222223

223224
internal static IEnumerable<AIContent> AsContents(this IEnumerable<ToolCall> toolCalls, string? content = default, List<AIAnnotation>? annotations = default)
224225
{
@@ -356,6 +357,80 @@ static IEnumerable<CitationAnnotation> AsCitations(CollectionSearchItem item)
356357
_ => [new CitationAnnotation { RawRepresentation = citation }]
357358
};
358359

360+
internal static Message AsMessage(this CompletionMessage completion)
361+
{
362+
var message = new Message
363+
{
364+
Role = completion.Role,
365+
EncryptedContent = completion.EncryptedContent,
366+
ReasoningContent = completion.ReasoningContent,
367+
};
368+
369+
if (!string.IsNullOrEmpty(completion.Content))
370+
message.Content.Add(new Content { Text = completion.Content });
371+
372+
message.ToolCalls.AddRange(completion.ToolCalls);
373+
374+
return message;
375+
}
376+
377+
internal static MessageRole Convert(this ChatRole role) => role switch
378+
{
379+
_ when role == ChatRole.System => MessageRole.RoleSystem,
380+
_ when role == ChatRole.User => MessageRole.RoleUser,
381+
_ when role == ChatRole.Assistant => MessageRole.RoleAssistant,
382+
_ when role == ChatRole.Tool => MessageRole.RoleTool,
383+
_ => MessageRole.RoleUser
384+
};
385+
386+
internal static ChatRole Convert(this MessageRole role) => role switch
387+
{
388+
MessageRole.RoleSystem => ChatRole.System,
389+
MessageRole.RoleUser => ChatRole.User,
390+
MessageRole.RoleAssistant => ChatRole.Assistant,
391+
MessageRole.RoleTool => ChatRole.Tool,
392+
_ => ChatRole.Assistant
393+
};
394+
395+
internal static ChatFinishReason? Convert(this FinishReason finishReason) => finishReason switch
396+
{
397+
FinishReason.ReasonStop => ChatFinishReason.Stop,
398+
FinishReason.ReasonMaxLen => ChatFinishReason.Length,
399+
FinishReason.ReasonToolCalls => ChatFinishReason.ToolCalls,
400+
FinishReason.ReasonMaxContext => ChatFinishReason.Length,
401+
FinishReason.ReasonTimeLimit => ChatFinishReason.Length,
402+
_ => null
403+
};
404+
405+
internal static UsageDetails? Convert(this SamplingUsage usage) => usage == null ? null : new()
406+
{
407+
InputTokenCount = usage.PromptTokens,
408+
OutputTokenCount = usage.CompletionTokens,
409+
TotalTokenCount = usage.TotalTokens
410+
};
411+
412+
internal static CitationAnnotation FromCitationUrl(this string citationUrl)
413+
{
414+
var url = new Uri(citationUrl);
415+
if (url.Scheme != "collections")
416+
return new CitationAnnotation { Url = url };
417+
418+
// Special-case collection citations so we get better metadata
419+
var collection = url.Host;
420+
var file = url.AbsolutePath[7..];
421+
422+
return new CitationAnnotation
423+
{
424+
AdditionalProperties = new AdditionalPropertiesDictionary
425+
{
426+
{ "collection_id", collection }
427+
},
428+
FileId = file,
429+
ToolName = "collections_search",
430+
Url = new Uri($"collections://{collection}/files/{file}"),
431+
};
432+
}
433+
359434
[JsonSourceGenerationOptions(JsonSerializerDefaults.Web,
360435
UseStringEnumConverter = true,
361436
UnmappedMemberHandling = JsonUnmappedMemberHandling.Skip,

0 commit comments

Comments
 (0)