diff --git a/src/StructId.Analyzer/ConstructorGenerator.cs b/src/StructId.Analyzer/ConstructorGenerator.cs index d134f81..f348b38 100644 --- a/src/StructId.Analyzer/ConstructorGenerator.cs +++ b/src/StructId.Analyzer/ConstructorGenerator.cs @@ -42,9 +42,13 @@ void GenerateCode(SourceProductionContext context, INamedTypeSymbol symbol) $$""" [System.CodeDom.Compiler.GeneratedCode("StructId", "{{ThisAssembly.Info.InformationalVersion}}")] - partial {{kind}} {{symbol.Name}}({{type}} Value); + partial {{kind}} {{symbol.Name}}({{type}} Value) + { + public static implicit operator {{type}}({{symbol.Name}} id) => id.Value; + public static explicit operator {{symbol.Name}}({{type}} value) => new(value); + } """); - context.AddSource($"{symbol.ToFileName()}.ctor.cs", output.ToString()); + context.AddSource($"{symbol.ToFileName()}.cs", output.ToString()); } } diff --git a/src/StructId.Analyzer/EntityFrameworkSelector.sbn b/src/StructId.Analyzer/EntityFrameworkSelector.sbn index 5208593..70fdd03 100644 --- a/src/StructId.Analyzer/EntityFrameworkSelector.sbn +++ b/src/StructId.Analyzer/EntityFrameworkSelector.sbn @@ -36,6 +36,9 @@ public static class StructIdDbContextOptionsBuilderExtensions foreach (var converter in baseConverters) yield return converter; + modelClrType = Unwrap(modelClrType); + providerClrType = Unwrap(providerClrType); + {{~ for id in Ids ~}} if (modelClrType == typeof({{ id.TSelf }})) yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo( @@ -44,5 +47,13 @@ public static class StructIdDbContextOptionsBuilderExtensions {{~ end ~}} } + + static Type Unwrap(Type? type) + { + if (type is null) + return null; + + return Nullable.GetUnderlyingType(type) ?? type; + } } } \ No newline at end of file diff --git a/src/StructId.Analyzer/NewableGenerator.cs b/src/StructId.Analyzer/NewableGenerator.cs index 5164be0..21d58fc 100644 --- a/src/StructId.Analyzer/NewableGenerator.cs +++ b/src/StructId.Analyzer/NewableGenerator.cs @@ -1,4 +1,5 @@ -using Microsoft.CodeAnalysis; +using System; +using Microsoft.CodeAnalysis; namespace StructId; @@ -7,4 +8,19 @@ public class NewableGenerator() : TemplateGenerator( "System.Object", ThisAssembly.Resources.Templates.Newable.Text, ThisAssembly.Resources.Templates.NewableT.Text, - ReferenceCheck.TypeExists); \ No newline at end of file + ReferenceCheck.TypeExists) +{ + protected override IncrementalValuesProvider OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider source) + { + var args = base.OnInitialize(context, source); + + context.RegisterSourceOutput( + args.Where(x => x.ValueType.ToFullName() == "System.Guid"), + GenerateGuidCode); + + return args; + } + + void GenerateGuidCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate( + context, args, $"{args.StructId.ToFileName()}.Guid.cs", ThisAssembly.Resources.Templates.NewableGuid.Text); +} \ No newline at end of file diff --git a/src/StructId.Analyzer/TemplateGenerator.cs b/src/StructId.Analyzer/TemplateGenerator.cs index 02b6afb..340e184 100644 --- a/src/StructId.Analyzer/TemplateGenerator.cs +++ b/src/StructId.Analyzer/TemplateGenerator.cs @@ -1,5 +1,4 @@ -using System.Diagnostics; -using System.Linq; +using System.Linq; using System.Text; using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; @@ -67,17 +66,18 @@ public virtual void Initialize(IncrementalGeneratorInitializationContext context protected virtual IncrementalValuesProvider OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider source) => source; - void GenerateCode(SourceProductionContext context, TemplateArgs args) + void GenerateCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate( + context, args, $"{args.StructId.ToFileName()}.cs", + args.ValueType.Equals(args.StringType, SymbolEqualityComparer.Default) ? stringTemplate : typeTemplate); + + protected static void AddFromTemplate(SourceProductionContext context, TemplateArgs args, string hintName, string template) { var ns = args.StructId.ContainingNamespace.Equals(args.StructId.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default) ? null : args.StructId.ContainingNamespace.ToDisplayString(); - var template = args.ValueType.Equals(args.StringType, SymbolEqualityComparer.Default) - ? stringTemplate : typeTemplate; - // replace tokens in the template - template = template + var replaced = template // Adjust to current target namespace .Replace("namespace StructId;", $"namespace {args.TargetNamespace};") .Replace("using StructId;", $"using {args.TargetNamespace};") @@ -87,19 +87,21 @@ void GenerateCode(SourceProductionContext context, TemplateArgs args) .Replace("TId", args.ValueType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat)); // parse template into a C# compilation unit - var parseable = CSharpSyntaxTree.ParseText(template).GetCompilationUnitRoot(); + var syntax = CSharpSyntaxTree.ParseText(replaced).GetCompilationUnitRoot(); // if we got a ns, move all members after a file-scoped namespace declaration if (ns != null) { - var members = parseable.Members; + var members = syntax.Members; var fsns = FileScopedNamespaceDeclaration(ParseName(ns).WithLeadingTrivia(Whitespace(" "))) .WithLeadingTrivia(LineFeed) .WithTrailingTrivia(LineFeed) .WithMembers(members); - parseable = parseable.WithMembers(SingletonList(fsns)); + syntax = syntax.WithMembers(SingletonList(fsns)); } - context.AddSource($"{args.StructId.ToFileName()}.cs", SourceText.From(parseable.ToFullString(), Encoding.UTF8)); + var output = syntax.ToFullString(); + + context.AddSource(hintName, SourceText.From(output, Encoding.UTF8)); } } diff --git a/src/StructId.FunctionalTests/Functional.cs b/src/StructId.FunctionalTests/Functional.cs index cf13d95..86cf219 100644 --- a/src/StructId.FunctionalTests/Functional.cs +++ b/src/StructId.FunctionalTests/Functional.cs @@ -1,6 +1,7 @@ using Dapper; using Microsoft.Data.Sqlite; using Microsoft.EntityFrameworkCore; +using Microsoft.Extensions.Logging; using Newtonsoft.Json; using Newtonsoft.Json.Linq; @@ -14,16 +15,10 @@ public record Product(ProductId Id, string Name); public record Wallet(WalletId Id, string Alias); public record User(UserId Id, string Name, Wallet Wallet); -partial record struct ProductId -{ - public static implicit operator Guid(ProductId id) => id.Value; - public static explicit operator ProductId(Guid value) => new(value); -} - -public class FunctionalTests +public class FunctionalTests(ITestOutputHelper output) { [Fact] - public void Test() + public void EqualityTest() { var guid = Guid.NewGuid(); var id1 = new ProductId(guid); @@ -33,6 +28,17 @@ public void Test() Assert.True(id1 == id2); } + [Fact] + public void ImplicitAndExplicitCast() + { + var guid = Guid.NewGuid(); + var id = new ProductId(guid); + Guid guid2 = id; + var id2 = (ProductId)guid2; + Assert.Equal(guid, guid2); + Assert.Equal(id, id2); + } + [Fact] public void Newtonsoft() { @@ -61,20 +67,32 @@ public void Newtonsoft() public void EntityFramework() { var options = new DbContextOptionsBuilder() - .UseSqlite("Data Source=ef.db") .UseStructId() + .UseSqlite("Data Source=ef.db") + // Uncomment to see full SQL being run + // .EnableSensitiveDataLogging() + // .UseLoggerFactory(new LoggerFactory(output)) .Options; using var context = new Context(options); + var id = ProductId.New(); + var product = new Product(new ProductId(id), "Product"); + // Seed data - var productId = Guid.NewGuid(); - var product = new Product(new ProductId(productId), "Product"); + context.Products.Add(new Product(ProductId.New(), "Product1")); context.Products.Add(product); + context.Products.Add(new Product(ProductId.New(), "Product2")); + context.SaveChanges(); - var product2 = context.Products.First(x => productId == product.Id); + var product2 = context.Products.Where(x => x.Id == id).FirstOrDefault(); Assert.Equal(product, product2); + + Guid guid = id; + + var product3 = context.Products.FirstOrDefault(x => guid == x.Id); + Assert.Equal(product, product3); } [Fact] @@ -88,7 +106,11 @@ public void Dapper() // Seed data var productId = Guid.NewGuid(); var product = new Product(new ProductId(productId), "Product"); + + connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", new Product(ProductId.New(), "Product1")); connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", product); + connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", new Product(ProductId.New(), "Product2")); + var product2 = connection.QueryFirst("SELECT * FROM Products WHERE Id = @Id", new { Id = productId }); Assert.Equal(product, product2); } @@ -97,6 +119,20 @@ public class Context : DbContext { public Context(DbContextOptions options) : base(options) { } public DbSet Products { get; set; } = null!; - protected override void OnModelCreating(ModelBuilder model) => model.Entity().HasKey(e => e.Id); + } + + class LoggerFactory(ITestOutputHelper output) : ILoggerFactory + { + public void AddProvider(ILoggerProvider provider) => throw new NotImplementedException(); + public ILogger CreateLogger(string categoryName) => new Logger(output); + public void Dispose() { } + + class Logger(ITestOutputHelper output) : ILogger + { + public IDisposable? BeginScope(TState state) where TState : notnull => null; + public bool IsEnabled(LogLevel logLevel) => true; + public void Log(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func formatter) + => output.WriteLine(formatter(state, exception)); + } } } \ No newline at end of file diff --git a/src/StructId/Templates/EntityFramework.cs b/src/StructId/Templates/EntityFramework.cs index c5e7f8a..c1b5b9f 100644 --- a/src/StructId/Templates/EntityFramework.cs +++ b/src/StructId/Templates/EntityFramework.cs @@ -14,11 +14,6 @@ public partial class EntityFrameworkValueConverter : ValueConverter public EntityFrameworkValueConverter() : this(null) { } public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null) - : base( - id => id.Value, - value => TSelf.New(value), - mappingHints - ) - { } + : base(id => id.Value, value => TSelf.New(value), mappingHints) { } } } diff --git a/src/StructId/Templates/JsonConverter.cs b/src/StructId/Templates/JsonConverter.cs index 858eba6..eca65fb 100644 --- a/src/StructId/Templates/JsonConverter.cs +++ b/src/StructId/Templates/JsonConverter.cs @@ -4,9 +4,7 @@ using System.Text.Json.Serialization; using StructId; -#if NET7_0_OR_GREATER [JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter))] -#endif readonly partial record struct Self { } \ No newline at end of file diff --git a/src/StructId/Templates/JsonConverterT.cs b/src/StructId/Templates/JsonConverterT.cs index 74c13f8..a1fb892 100644 --- a/src/StructId/Templates/JsonConverterT.cs +++ b/src/StructId/Templates/JsonConverterT.cs @@ -4,9 +4,7 @@ using System.Text.Json.Serialization; using StructId; -#if NET7_0_OR_GREATER [JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter))] -#endif readonly partial record struct TSelf { } \ No newline at end of file diff --git a/src/StructId/Templates/NewableGuid.cs b/src/StructId/Templates/NewableGuid.cs new file mode 100644 index 0000000..bf4123e --- /dev/null +++ b/src/StructId/Templates/NewableGuid.cs @@ -0,0 +1,11 @@ +// + +using System; + +readonly partial record struct TSelf +{ + /// + /// Creates a new instance of with a . + /// + public static TSelf New() => new(Guid.NewGuid()); +} \ No newline at end of file diff --git a/src/StructId/Templates/TSelf.cs b/src/StructId/Templates/TSelf.cs index 68721d7..0a8be46 100644 --- a/src/StructId/Templates/TSelf.cs +++ b/src/StructId/Templates/TSelf.cs @@ -27,4 +27,5 @@ readonly partial record struct Self(string Value) : IStructId readonly partial record struct TSelf(TId Value) : IStructId { + public TSelf(Guid _) : this(default(TId)) { } }