Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions src/xAI.Tests/ChatClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@ public async Task GrokInvokesTools()
{
var messages = new Chat()
{
{ "system", "You are a bot that invokes the tool get_date when asked for the date." },
{ "user", "What day is today?" },
};

Expand Down Expand Up @@ -100,6 +99,12 @@ public async Task GrokInvokesTools()
.Any(x => x.Name == "get_date");

Assert.True(getdate);

messages.AddRange(response.Messages);
messages.Add("user", "What date is tomorrow then?");

var tomorrow = await chat.GetResponseAsync<DateOnly>(messages, options);
Assert.Equal(DateOnly.FromDateTime(DateTime.Today.AddDays(1)), tomorrow.Result);
}

[SecretsFact("XAI_API_KEY")]
Expand Down Expand Up @@ -136,7 +141,7 @@ public async Task GrokInvokesToolAndSearch()
{
var messages = new Chat()
{
{ "system", "You use Nasdaq for stocks news and prices." },
{ "system", "You use Nasdaq for stocks news and prices, get_date for getting today's date." },
{ "user", "What's Tesla stock worth today?" },
};

Expand All @@ -160,6 +165,8 @@ public async Task GrokInvokesToolAndSearch()

var response = await grok.GetResponseAsync(messages, options);

Assert.False(getDateCalls == 0, "Expected the get_date tool to be called at least once.");

// The get_date result shows up as a tool role
Assert.Contains(response.Messages, x => x.Role == ChatRole.Tool);

Expand All @@ -174,7 +181,6 @@ public async Task GrokInvokesToolAndSearch()
Assert.Equal(1, getDateCalls);
Assert.Contains(urls, x => x.Host.EndsWith("nasdaq.com"));
Assert.Contains(urls, x => x.PathAndQuery.Contains("/TSLA"));
Assert.Equal(options.ModelId, response.ModelId);

var calls = response.Messages
.SelectMany(x => x.Contents.Select(x => x.RawRepresentation as xAI.Protocol.ToolCall))
Expand Down Expand Up @@ -213,10 +219,11 @@ public async Task GrokInvokesSpecificSearchUrl()
.SelectMany(x => x.Annotations ?? [])
.OfType<CitationAnnotation>()
.Where(x => x.Url != null)
.Select(x => x.Url!.AbsoluteUri)
.Select(x => x.Url!.Host)
.Distinct()
.ToList();

Assert.Contains("https://partediario.catedralaltapatagonia.com/partediario/", citations);
Assert.Contains("catedralaltapatagonia.com", citations);
}

[SecretsFact("XAI_API_KEY")]
Expand Down Expand Up @@ -416,6 +423,7 @@ public async Task GrokInvokesHostedCollectionSearch()
{
var messages = new Chat()
{
{ "system", "Utilizar collection/file search SIEMPRE para buscar informacion legal." },
{ "user", "¿Cuál es el monto exacto del rango de la multa por inasistencia injustificada a la audiencia señalada por el juez en el proceso sucesorio, según lo establecido en el Artículo 691 del Código Procesal Civil y Comercial de la Nación (Ley 17.454)?" },
};

Expand All @@ -431,7 +439,6 @@ public async Task GrokInvokesHostedCollectionSearch()
var response = await grok.GetResponseAsync(messages, options);
var text = response.Text;

Assert.Contains("11,74", text);
Assert.Contains(response.Messages
.SelectMany(x => x.Contents)
.OfType<CollectionSearchToolCallContent>()
Expand All @@ -440,6 +447,8 @@ public async Task GrokInvokesHostedCollectionSearch()
// No actual search results content since we didn't specify it in Include
Assert.Empty(response.Messages.SelectMany(x => x.Contents).OfType<CollectionSearchToolResultContent>());

Assert.Contains("11,74", text);

options.Include = [IncludeOption.CollectionsSearchCallOutput];
response = await grok.GetResponseAsync(messages, options);

Expand Down Expand Up @@ -606,7 +615,7 @@ public async Task GrokCustomFactoryInvokedFromOptions()
}
}));

var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var response = await grok.GetResponseAsync("Hi, my internet alias is kzu. Lookup my real full name online.",
new GrokChatOptions
{
Expand Down Expand Up @@ -648,7 +657,7 @@ public async Task GrokSetsToolCallIdFromFunctionResultContent()
}
}));

var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, "What's the time?"),
Expand Down Expand Up @@ -685,7 +694,7 @@ public async Task GrokSetsToolCallIdOnlyWhenCallIdIsProvided()
}
}));

var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, "What's the time?"),
Expand Down Expand Up @@ -722,7 +731,7 @@ public async Task GrokSendsDataContentAsBase64ImageUrl()
}));

var imageBytes = new byte[] { 1, 2, 3, 4, 5 };
var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, [new TextContent("What do you see?"), new DataContent(imageBytes, "image/png")]),
Expand Down Expand Up @@ -759,7 +768,7 @@ public async Task GrokSendsUriContentAsImageUrl()
}));

var imageUri = new Uri("https://example.com/photo.jpg");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast");
var grok = new GrokChatClient(client.Object, "grok-4-1-fast-non-reasoning");
var messages = new List<ChatMessage>
{
new(ChatRole.User, [new TextContent("What do you see?"), new UriContent(imageUri, "image/jpeg")]),
Expand Down
10 changes: 6 additions & 4 deletions src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,12 +277,14 @@ codeResult.RawRepresentation is ToolCall codeToolCall &&
if (tool is not null) request.Tools.Add(tool);
}

if (options?.ResponseFormat is ChatResponseFormatJson)
if (options?.ResponseFormat is ChatResponseFormatJson jsonFormat)
{
request.ResponseFormat = new ResponseFormat
request.ResponseFormat = new ResponseFormat { FormatType = FormatType.JsonObject };
if (jsonFormat.Schema != null)
{
FormatType = FormatType.JsonObject
};
request.ResponseFormat.FormatType = FormatType.JsonSchema;
request.ResponseFormat.Schema = jsonFormat.Schema?.ToString();
}
}

return request;
Expand Down
Loading