Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Reflection;
using System.Text.Json.Nodes;
using System.Threading;

#pragma warning disable S1067 // Expressions should not be too complex

Expand All @@ -23,6 +25,18 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
/// </summary>
public Func<AIJsonSchemaCreateContext, JsonNode, JsonNode>? TransformSchemaNode { get; init; }

/// <summary>
/// Gets a callback that is invoked for every parameter in the <see cref="MethodBase"/> provided to
/// <see cref="AIJsonUtilities.CreateFunctionJsonSchema"/> in order to determine whether it should
/// be included in the generated schema.
/// </summary>
/// <remarks>
/// By default, when <see cref="IncludeParameter"/> is <see langword="null"/>, all parameters other
/// than those of type <see cref="CancellationToken"/> are included in the generated schema.
/// The delegate is not invoked for <see cref="CancellationToken"/> parameters.
/// </remarks>
public Func<ParameterInfo, bool>? IncludeParameter { get; init; }

/// <summary>
/// Gets a value indicating whether to include the type keyword in inferred schemas for .NET enums.
/// </summary>
Expand All @@ -44,19 +58,24 @@ public sealed class AIJsonSchemaCreateOptions : IEquatable<AIJsonSchemaCreateOpt
public bool RequireAllProperties { get; init; } = true;

/// <inheritdoc/>
public bool Equals(AIJsonSchemaCreateOptions? other)
{
return other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;
}
public bool Equals(AIJsonSchemaCreateOptions? other) =>
other is not null &&
TransformSchemaNode == other.TransformSchemaNode &&
IncludeParameter == other.IncludeParameter &&
IncludeTypeInEnumSchemas == other.IncludeTypeInEnumSchemas &&
DisallowAdditionalProperties == other.DisallowAdditionalProperties &&
IncludeSchemaKeyword == other.IncludeSchemaKeyword &&
RequireAllProperties == other.RequireAllProperties;

/// <inheritdoc />
public override bool Equals(object? obj) => obj is AIJsonSchemaCreateOptions other && Equals(other);

/// <inheritdoc />
public override int GetHashCode() => (TransformSchemaNode, IncludeTypeInEnumSchemas, DisallowAdditionalProperties, IncludeSchemaKeyword, RequireAllProperties).GetHashCode();
public override int GetHashCode() =>
(TransformSchemaNode,
IncludeParameter,
IncludeTypeInEnumSchemas,
DisallowAdditionalProperties,
IncludeSchemaKeyword,
RequireAllProperties).GetHashCode();
}
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,14 @@ public static JsonElement CreateFunctionJsonSchema(
continue;
}

if (inferenceOptions.IncludeParameter is { } includeParameter &&
!includeParameter(parameter))
{
// Skip parameters that should not be included in the schema.
// By default, all parameters are included.
continue;
}

JsonNode parameterSchema = CreateJsonSchemaCore(
type: parameter.ParameterType,
parameterName: parameter.Name,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Reflection;
Expand All @@ -10,6 +11,7 @@
using System.Text.Json.Nodes;
using System.Text.Json.Serialization;
using System.Text.Json.Serialization.Metadata;
using System.Threading;
using Microsoft.Extensions.AI.JsonSchemaExporter;
using Xunit;

Expand Down Expand Up @@ -86,6 +88,12 @@ public static void AIJsonSchemaCreateOptions_UsesStructuralEquality()
property.SetValue(options2, transformer);
break;

case null when property.PropertyType == typeof(Func<ParameterInfo, bool>):
Func<ParameterInfo, bool> includeParameter = static (parameter) => true;
property.SetValue(options1, includeParameter);
property.SetValue(options2, includeParameter);
break;

default:
Assert.Fail($"Unexpected property type: {property.PropertyType}");
break;
Expand Down Expand Up @@ -443,6 +451,31 @@ public static void HashData_Idempotent()
}
}

[Fact]
public static void CreateFunctionJsonSchema_InvokesIncludeParameterCallbackForEveryParameter()
{
Delegate method = (int first, string second, bool third, CancellationToken fourth, DateTime fifth) => { };

List<string?> names = [];
JsonElement schema = AIJsonUtilities.CreateFunctionJsonSchema(method.Method, inferenceOptions: new()
{
IncludeParameter = p =>
{
names.Add(p.Name);
return p.Name is "first" or "fifth";
}
});

Assert.Equal(["first", "second", "third", "fifth"], names);

string schemaString = schema.ToString();
Assert.Contains("first", schemaString);
Assert.DoesNotContain("second", schemaString);
Assert.DoesNotContain("third", schemaString);
Assert.DoesNotContain("fourth", schemaString);
Assert.Contains("fifth", schemaString);
}

private class DerivedAIContent : AIContent
{
public int DerivedValue { get; set; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Reflection;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Xunit;
Expand Down Expand Up @@ -172,4 +173,29 @@ public void AIFunctionFactoryOptions_DefaultValues()
Assert.Null(options.SerializerOptions);
Assert.Null(options.JsonSchemaCreateOptions);
}

[Fact]
public async Task AIFunctionFactoryOptions_SupportsSkippingParameters()
{
AIFunction func = AIFunctionFactory.Create(
(string firstParameter, int secondParameter) => firstParameter + secondParameter,
new()
{
JsonSchemaCreateOptions = new()
{
IncludeParameter = p => p.Name != "firstParameter",
}
});

Assert.DoesNotContain("firstParameter", func.JsonSchema.ToString());
Assert.Contains("secondParameter", func.JsonSchema.ToString());

JsonElement? result = (JsonElement?)await func.InvokeAsync(new Dictionary<string, object?>
{
["firstParameter"] = "test",
["secondParameter"] = 42
});
Assert.NotNull(result);
Assert.Contains("test42", result.ToString());
}
}
Loading