Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions src/xAI.Protocol/ChatClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using Grpc.Core;
using Grpc.Net.Client;

namespace xAI.Protocol;

partial class Chat
{
partial class ChatClient
{
readonly object? options;

internal ChatClient(ChannelBase channel, object options) : this(channel)
=> this.options = options;

internal object? Options => options;
}
}
17 changes: 17 additions & 0 deletions src/xAI.Protocol/ImageClient.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
using Grpc.Core;
using Grpc.Net.Client;

namespace xAI.Protocol;

partial class Image
{
partial class ImageClient
{
readonly object? options;

internal ImageClient(ChannelBase channel, object options) : this(channel)
=> this.options = options;

internal object? Options => options;
}
}
8 changes: 5 additions & 3 deletions src/xAI.Protocol/xAI.Protocol.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,17 @@
<PackageReference Include="Grpc.Net.ClientFactory" Version="2.76.0" />
<PackageReference Include="Grpc.Tools" Version="2.78.0" PrivateAssets="all" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection.Abstractions" Version="10.0.4" />
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration" Version="10.0.4" />
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="10.0.4" />
<PackageReference Include="Microsoft.Extensions.Http" Version="10.0.4" />
</ItemGroup>

<ItemGroup>
<None Include="..\..\osmfeula.txt" Link="osmfeula.txt" PackagePath="OSMFEULA.txt" />
<Protobuf Include="*.proto" GrpcServices="Client" />
<Protobuf Include="google\rpc\status.proto" GrpcServices="Client" />
<InternalsVisibleTo Include="xAI"/>
<InternalsVisibleTo Include="xAI.Tests"/>
</ItemGroup>

<Target Name="FixProto" BeforeTargets="Protobuf_BeforeCompile">
Expand Down
42 changes: 41 additions & 1 deletion src/xAI.Tests/ChatClientTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,10 @@
using System.Text.Json.Nodes;
using Azure;
using Devlooped.Extensions.AI;
using Google.Protobuf;
using Grpc.Core;
using Grpc.Core.Interceptors;
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
using Moq;
using OpenAI;
Expand Down Expand Up @@ -786,5 +789,42 @@ public async Task GrokSendsUriContentAsImageUrl()
Assert.Equal(imageUri.ToString(), imageContent.ImageUrl);
}

[Fact]
public async Task GrokPreservesEndUserIdFromClientOptions()
{
GetCompletionsRequest? capturedRequest = null;
var invoker = new Mock<CallInvoker>();
invoker.Setup(x => x.AsyncUnaryCall(
It.IsAny<Method<GetCompletionsRequest, GetChatCompletionResponse>>(),
It.IsAny<string>(),
It.IsAny<CallOptions>(),
It.IsAny<GetCompletionsRequest>()))
.Callback<Method<GetCompletionsRequest, GetChatCompletionResponse>, string, CallOptions, GetCompletionsRequest>(
(_, _, _, req) => capturedRequest = req)
.Returns(CallHelpers.CreateAsyncUnaryCall(new GetChatCompletionResponse
{
Outputs =
{
new CompletionOutput
{
Message = new CompletionMessage { Content = "Hello!" }
}
}
}));

var client = new GrokClient(new TestGrpcChannel(invoker.Object), new GrokClientOptions { EndUserId = "kzu" });
var chat = client.GetChatClient();
var grok = chat.AsIChatClient("grok");
await grok.GetResponseAsync("Hi");

Assert.NotNull(capturedRequest);
Assert.Equal("kzu", capturedRequest.User);
}

class TestGrpcChannel(CallInvoker invoker) : ChannelBase("test")
{
public override CallInvoker CreateCallInvoker() => invoker;
}

record Response(DateOnly Today, string Release, decimal Price);
}
}
30 changes: 30 additions & 0 deletions src/xAI.Tests/ImageGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -372,4 +372,34 @@ public void GetService_ReturnsImageGeneratorMetadata()
Assert.Equal("xai", metadata.ProviderName);
Assert.Equal("grok-imagine-image", metadata.DefaultModelId);
}

[Fact]
public async Task GrokPreservesEndUserIdFromClientOptions()
{
GenerateImageRequest? capturedRequest = null;
var invoker = new Mock<CallInvoker>();
invoker.Setup(x => x.AsyncUnaryCall(
It.IsAny<Method<GenerateImageRequest, ImageResponse>>(),
It.IsAny<string>(),
It.IsAny<CallOptions>(),
It.IsAny<GenerateImageRequest>()))
.Callback<Method<GenerateImageRequest, ImageResponse>, string, CallOptions, GenerateImageRequest>(
(_, _, _, req) => capturedRequest = req)
.Returns(CallHelpers.CreateAsyncUnaryCall(new ImageResponse
{
}));

var client = new GrokClient(new TestGrpcChannel(invoker.Object), new GrokClientOptions { EndUserId = "kzu" });
var images = client.GetImageClient();
var grok = images.AsIImageGenerator("grok");
await grok.GenerateAsync(new ImageGenerationRequest { Prompt = "Generate a lion." });

Assert.NotNull(capturedRequest);
Assert.Equal("kzu", capturedRequest.User);
}

class TestGrpcChannel(CallInvoker invoker) : ChannelBase("test")
{
public override CallInvoker CreateCallInvoker() => invoker;
}
}
1 change: 1 addition & 0 deletions src/xAI.Tests/xAI.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
<Nullable>enable</Nullable>
<IsPackable>false</IsPackable>
<NoWarn>MEAI001;xAI001;$(NoWarn)</NoWarn>
<LangVersion>latest</LangVersion>
</PropertyGroup>

<ItemGroup>
Expand Down
6 changes: 3 additions & 3 deletions src/xAI/GrokChatClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@ class GrokChatClient : IChatClient
readonly string defaultModelId;
readonly GrokClientOptions clientOptions;

internal GrokChatClient(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId)
internal GrokChatClient(ChannelBase channel, GrokClientOptions clientOptions, string defaultModelId)
: this(new ChatClient(channel), clientOptions, defaultModelId)
{ }

/// <summary>
/// Test constructor.
/// </summary>
internal GrokChatClient(ChatClient client, string defaultModelId)
: this(client, new(), defaultModelId)
: this(client, client.Options as GrokClientOptions ?? new(), defaultModelId)
{ }

GrokChatClient(ChatClient client, GrokClientOptions clientOptions, string defaultModelId)
internal GrokChatClient(ChatClient client, GrokClientOptions clientOptions, string defaultModelId)
{
this.client = client;
this.clientOptions = clientOptions;
Expand Down
18 changes: 11 additions & 7 deletions src/xAI/GrokClient.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.Concurrent;
using System.Diagnostics;
using System.Net.Http.Headers;
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.Http;
using Polly;
Expand All @@ -14,11 +15,14 @@ namespace xAI;
/// <param name="options">The options used to configure the client.</param>
public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisposable
{
static readonly ConcurrentDictionary<(Uri, string), GrpcChannel> channels = [];
static readonly ConcurrentDictionary<(Uri, string), ChannelBase> channels = [];

/// <summary>Initializes a new instance of the <see cref="GrokClient"/> class with default options.</summary>
public GrokClient(string apiKey) : this(apiKey, new GrokClientOptions()) { }

internal GrokClient(ChannelBase channel, GrokClientOptions options) : this("", options)
=> channels[(options.Endpoint, "")] = channel;

/// <summary>Gets the API key used for authentication.</summary>
public string ApiKey { get; } = apiKey;

Expand All @@ -32,7 +36,7 @@ public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisp
public Auth.AuthClient GetAuthClient() => new(Channel);

/// <summary>Gets a new instance of <see cref="Chat.ChatClient"/> that reuses the client configuration details provided to the <see cref="GrokClient"/> instance.</summary>
public Chat.ChatClient GetChatClient() => new(Channel);
public Chat.ChatClient GetChatClient() => new(Channel, Options);

/// <summary>Gets a new instance of <see cref="Documents.DocumentsClient"/> that reuses the client configuration details provided to the <see cref="GrokClient"/> instance.</summary>
public Documents.DocumentsClient GetDocumentsClient() => new(Channel);
Expand All @@ -41,15 +45,15 @@ public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisp
public Embedder.EmbedderClient GetEmbedderClient() => new(Channel);

/// <summary>Gets a new instance of <see cref="Image.ImageClient"/> that reuses the client configuration details provided to the <see cref="GrokClient"/> instance.</summary>
public Image.ImageClient GetImageClient() => new(Channel);
public Image.ImageClient GetImageClient() => new(Channel, Options);

/// <summary>Gets a new instance of <see cref="Models.ModelsClient"/> that reuses the client configuration details provided to the <see cref="GrokClient"/> instance.</summary>
public Models.ModelsClient GetModelsClient() => new(Channel);

/// <summary>Gets a new instance of <see cref="Tokenize.TokenizeClient"/> that reuses the client configuration details provided to the <see cref="GrokClient"/> instance.</summary>
public Tokenize.TokenizeClient GetTokenizeClient() => new(Channel);

internal GrpcChannel Channel => channels.GetOrAdd((Endpoint, ApiKey), key =>
internal ChannelBase Channel => channels.GetOrAdd((Endpoint, ApiKey), key =>
{
var inner = Options.ChannelOptions?.HttpHandler;
if (inner == null)
Expand All @@ -59,9 +63,9 @@ public sealed class GrokClient(string apiKey, GrokClientOptions options) : IDisp
var retryPolicy = HttpPolicyExtensions
.HandleTransientHttpError()
.Or<Grpc.Core.RpcException>(ex =>
ex.StatusCode is Grpc.Core.StatusCode.Unavailable or
Grpc.Core.StatusCode.DeadlineExceeded or
Grpc.Core.StatusCode.Internal &&
ex.StatusCode is StatusCode.Unavailable or
StatusCode.DeadlineExceeded or
StatusCode.Internal &&
ex.Status.Detail?.Contains("504") == true ||
ex.Status.Detail?.Contains("INTERNAL_ERROR") == true)
.WaitAndRetryAsync(
Expand Down
12 changes: 9 additions & 3 deletions src/xAI/GrokImageGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using Grpc.Core;
using Grpc.Net.Client;
using Microsoft.Extensions.AI;
using xAI.Protocol;
Expand Down Expand Up @@ -27,22 +28,24 @@ sealed class GrokImageGenerator : IImageGenerator

readonly ImageGeneratorMetadata metadata;
readonly ImageClient imageClient;
readonly GrokClientOptions clientOptions;
readonly string defaultModelId;

internal GrokImageGenerator(GrpcChannel channel, GrokClientOptions clientOptions, string defaultModelId)
: this(new ImageClient(channel), clientOptions, defaultModelId)
internal GrokImageGenerator(ChannelBase channel, GrokClientOptions options, string defaultModelId)
: this(new ImageClient(channel, options), options, defaultModelId)
{ }

/// <summary>
/// Test constructor.
/// </summary>
internal GrokImageGenerator(ImageClient imageClient, string defaultModelId)
: this(imageClient, new(), defaultModelId)
: this(imageClient, imageClient.Options as GrokClientOptions ?? new(), defaultModelId)
{ }

GrokImageGenerator(ImageClient imageClient, GrokClientOptions clientOptions, string defaultModelId)
{
this.imageClient = imageClient;
this.clientOptions = clientOptions;
this.defaultModelId = defaultModelId;
metadata = new ImageGeneratorMetadata("xai", clientOptions.Endpoint, defaultModelId);
}
Expand All @@ -59,6 +62,9 @@ public async Task<ImageGenerationResponse> GenerateAsync(
Model = options?.ModelId ?? defaultModelId,
};

if (clientOptions.EndUserId is { } user)
protocolRequest.User = clientOptions.EndUserId;

if (options?.Count is { } count)
protocolRequest.N = count;

Expand Down
Loading