Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions src/StructId.Analyzer/ConstructorGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,13 @@ void GenerateCode(SourceProductionContext context, INamedTypeSymbol symbol)
$$"""

[System.CodeDom.Compiler.GeneratedCode("StructId", "{{ThisAssembly.Info.InformationalVersion}}")]
partial {{kind}} {{symbol.Name}}({{type}} Value);
partial {{kind}} {{symbol.Name}}({{type}} Value)
{
public static implicit operator {{type}}({{symbol.Name}} id) => id.Value;
public static explicit operator {{symbol.Name}}({{type}} value) => new(value);
}
""");

context.AddSource($"{symbol.ToFileName()}.ctor.cs", output.ToString());
context.AddSource($"{symbol.ToFileName()}.cs", output.ToString());
}
}
11 changes: 11 additions & 0 deletions src/StructId.Analyzer/EntityFrameworkSelector.sbn
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ public static class StructIdDbContextOptionsBuilderExtensions
foreach (var converter in baseConverters)
yield return converter;

modelClrType = Unwrap(modelClrType);
providerClrType = Unwrap(providerClrType);

{{~ for id in Ids ~}}
if (modelClrType == typeof({{ id.TSelf }}))
yield return converters.GetOrAdd((modelClrType, providerClrType), key => new ValueConverterInfo(
Expand All @@ -44,5 +47,13 @@ public static class StructIdDbContextOptionsBuilderExtensions

{{~ end ~}}
}

static Type Unwrap(Type? type)
{
if (type is null)
return null;

return Nullable.GetUnderlyingType(type) ?? type;
}
}
}
20 changes: 18 additions & 2 deletions src/StructId.Analyzer/NewableGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.CodeAnalysis;
using System;
using Microsoft.CodeAnalysis;

namespace StructId;

Expand All @@ -7,4 +8,19 @@ public class NewableGenerator() : TemplateGenerator(
"System.Object",
ThisAssembly.Resources.Templates.Newable.Text,
ThisAssembly.Resources.Templates.NewableT.Text,
ReferenceCheck.TypeExists);
ReferenceCheck.TypeExists)
{
protected override IncrementalValuesProvider<TemplateArgs> OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TemplateArgs> source)
{
var args = base.OnInitialize(context, source);

context.RegisterSourceOutput(
args.Where(x => x.ValueType.ToFullName() == "System.Guid"),
GenerateGuidCode);

return args;
}

void GenerateGuidCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate(
context, args, $"{args.StructId.ToFileName()}.Guid.cs", ThisAssembly.Resources.Templates.NewableGuid.Text);
}
24 changes: 13 additions & 11 deletions src/StructId.Analyzer/TemplateGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Diagnostics;
using System.Linq;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -67,17 +66,18 @@ public virtual void Initialize(IncrementalGeneratorInitializationContext context

protected virtual IncrementalValuesProvider<TemplateArgs> OnInitialize(IncrementalGeneratorInitializationContext context, IncrementalValuesProvider<TemplateArgs> source) => source;

void GenerateCode(SourceProductionContext context, TemplateArgs args)
void GenerateCode(SourceProductionContext context, TemplateArgs args) => AddFromTemplate(
context, args, $"{args.StructId.ToFileName()}.cs",
args.ValueType.Equals(args.StringType, SymbolEqualityComparer.Default) ? stringTemplate : typeTemplate);

protected static void AddFromTemplate(SourceProductionContext context, TemplateArgs args, string hintName, string template)
{
var ns = args.StructId.ContainingNamespace.Equals(args.StructId.ContainingModule.GlobalNamespace, SymbolEqualityComparer.Default)
? null
: args.StructId.ContainingNamespace.ToDisplayString();

var template = args.ValueType.Equals(args.StringType, SymbolEqualityComparer.Default)
? stringTemplate : typeTemplate;

// replace tokens in the template
template = template
var replaced = template
// Adjust to current target namespace
.Replace("namespace StructId;", $"namespace {args.TargetNamespace};")
.Replace("using StructId;", $"using {args.TargetNamespace};")
Expand All @@ -87,19 +87,21 @@ void GenerateCode(SourceProductionContext context, TemplateArgs args)
.Replace("TId", args.ValueType.ToDisplayString(SymbolDisplayFormat.FullyQualifiedFormat));

// parse template into a C# compilation unit
var parseable = CSharpSyntaxTree.ParseText(template).GetCompilationUnitRoot();
var syntax = CSharpSyntaxTree.ParseText(replaced).GetCompilationUnitRoot();

// if we got a ns, move all members after a file-scoped namespace declaration
if (ns != null)
{
var members = parseable.Members;
var members = syntax.Members;
var fsns = FileScopedNamespaceDeclaration(ParseName(ns).WithLeadingTrivia(Whitespace(" ")))
.WithLeadingTrivia(LineFeed)
.WithTrailingTrivia(LineFeed)
.WithMembers(members);
parseable = parseable.WithMembers(SingletonList<MemberDeclarationSyntax>(fsns));
syntax = syntax.WithMembers(SingletonList<MemberDeclarationSyntax>(fsns));
}

context.AddSource($"{args.StructId.ToFileName()}.cs", SourceText.From(parseable.ToFullString(), Encoding.UTF8));
var output = syntax.ToFullString();

context.AddSource(hintName, SourceText.From(output, Encoding.UTF8));
}
}
62 changes: 49 additions & 13 deletions src/StructId.FunctionalTests/Functional.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Dapper;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;

Expand All @@ -14,16 +15,10 @@ public record Product(ProductId Id, string Name);
public record Wallet(WalletId Id, string Alias);
public record User(UserId Id, string Name, Wallet Wallet);

partial record struct ProductId
{
public static implicit operator Guid(ProductId id) => id.Value;
public static explicit operator ProductId(Guid value) => new(value);
}

public class FunctionalTests
public class FunctionalTests(ITestOutputHelper output)
{
[Fact]
public void Test()
public void EqualityTest()
{
var guid = Guid.NewGuid();
var id1 = new ProductId(guid);
Expand All @@ -33,6 +28,17 @@ public void Test()
Assert.True(id1 == id2);
}

[Fact]
public void ImplicitAndExplicitCast()
{
var guid = Guid.NewGuid();
var id = new ProductId(guid);
Guid guid2 = id;
var id2 = (ProductId)guid2;
Assert.Equal(guid, guid2);
Assert.Equal(id, id2);
}

[Fact]
public void Newtonsoft()
{
Expand Down Expand Up @@ -61,20 +67,32 @@ public void Newtonsoft()
public void EntityFramework()
{
var options = new DbContextOptionsBuilder<Context>()
.UseSqlite("Data Source=ef.db")
.UseStructId()
.UseSqlite("Data Source=ef.db")
// Uncomment to see full SQL being run
// .EnableSensitiveDataLogging()
// .UseLoggerFactory(new LoggerFactory(output))
.Options;

using var context = new Context(options);

var id = ProductId.New();
var product = new Product(new ProductId(id), "Product");

// Seed data
var productId = Guid.NewGuid();
var product = new Product(new ProductId(productId), "Product");
context.Products.Add(new Product(ProductId.New(), "Product1"));
context.Products.Add(product);
context.Products.Add(new Product(ProductId.New(), "Product2"));

context.SaveChanges();

var product2 = context.Products.First(x => productId == product.Id);
var product2 = context.Products.Where(x => x.Id == id).FirstOrDefault();
Assert.Equal(product, product2);

Guid guid = id;

var product3 = context.Products.FirstOrDefault(x => guid == x.Id);
Assert.Equal(product, product3);
}

[Fact]
Expand All @@ -88,7 +106,11 @@ public void Dapper()
// Seed data
var productId = Guid.NewGuid();
var product = new Product(new ProductId(productId), "Product");

connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", new Product(ProductId.New(), "Product1"));
connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", product);
connection.Execute("INSERT INTO Products (Id, Name) VALUES (@Id, @Name)", new Product(ProductId.New(), "Product2"));

var product2 = connection.QueryFirst<Product>("SELECT * FROM Products WHERE Id = @Id", new { Id = productId });
Assert.Equal(product, product2);
}
Expand All @@ -97,6 +119,20 @@ public class Context : DbContext
{
public Context(DbContextOptions<Context> options) : base(options) { }
public DbSet<Product> Products { get; set; } = null!;
protected override void OnModelCreating(ModelBuilder model) => model.Entity<Product>().HasKey(e => e.Id);
}

class LoggerFactory(ITestOutputHelper output) : ILoggerFactory
{
public void AddProvider(ILoggerProvider provider) => throw new NotImplementedException();
public ILogger CreateLogger(string categoryName) => new Logger(output);
public void Dispose() { }

class Logger(ITestOutputHelper output) : ILogger
{
public IDisposable? BeginScope<TState>(TState state) where TState : notnull => null;
public bool IsEnabled(LogLevel logLevel) => true;
public void Log<TState>(LogLevel logLevel, EventId eventId, TState state, Exception? exception, Func<TState, Exception?, string> formatter)
=> output.WriteLine(formatter(state, exception));
}
}
}
7 changes: 1 addition & 6 deletions src/StructId/Templates/EntityFramework.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ public partial class EntityFrameworkValueConverter : ValueConverter<TSelf, TId>
public EntityFrameworkValueConverter() : this(null) { }

public EntityFrameworkValueConverter(ConverterMappingHints? mappingHints = null)
: base(
id => id.Value,
value => TSelf.New(value),
mappingHints
)
{ }
: base(id => id.Value, value => TSelf.New(value), mappingHints) { }
}
}
2 changes: 0 additions & 2 deletions src/StructId/Templates/JsonConverter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using System.Text.Json.Serialization;
using StructId;

#if NET7_0_OR_GREATER
[JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter<Self>))]
#endif
readonly partial record struct Self
{
}
2 changes: 0 additions & 2 deletions src/StructId/Templates/JsonConverterT.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
using System.Text.Json.Serialization;
using StructId;

#if NET7_0_OR_GREATER
[JsonConverter(typeof(StructIdConverters.SystemTextJsonConverter<TSelf, TId>))]
#endif
readonly partial record struct TSelf
{
}
11 changes: 11 additions & 0 deletions src/StructId/Templates/NewableGuid.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// <auto-generated />

using System;

readonly partial record struct TSelf
{
/// <summary>
/// Creates a new instance of <typeparamref name="TSelf"/> with a <see cref="Guid.NewGuid"/>.
/// </summary>
public static TSelf New() => new(Guid.NewGuid());
}
1 change: 1 addition & 0 deletions src/StructId/Templates/TSelf.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ readonly partial record struct Self(string Value) : IStructId

readonly partial record struct TSelf(TId Value) : IStructId<TId>
{
public TSelf(Guid _) : this(default(TId)) { }
}