diff --git a/src/StructId.Analyzer/AnalysisExtensions.cs b/src/StructId.Analyzer/AnalysisExtensions.cs index f3d8159..a9fbf21 100644 --- a/src/StructId.Analyzer/AnalysisExtensions.cs +++ b/src/StructId.Analyzer/AnalysisExtensions.cs @@ -40,7 +40,8 @@ @this is INamedTypeSymbol namedActual && if (iface.Is(baseTypeOrInterface)) return true; - if (@this.BaseType?.Name.Equals("object", StringComparison.OrdinalIgnoreCase) == true) + if (@this.BaseType?.Name.Equals("object", StringComparison.OrdinalIgnoreCase) == true && + @this.BaseType?.Equals(baseTypeOrInterface, SymbolEqualityComparer.Default) != true) return false; return Is(@this.BaseType, baseTypeOrInterface); diff --git a/src/StructId.Analyzer/TemplatedGenerator.cs b/src/StructId.Analyzer/TemplatedGenerator.cs index 963fb06..85aa17c 100644 --- a/src/StructId.Analyzer/TemplatedGenerator.cs +++ b/src/StructId.Analyzer/TemplatedGenerator.cs @@ -1,4 +1,5 @@ -using System.Linq; +using System.Collections.Generic; +using System.Linq; using System.Text; using System.Text.RegularExpressions; using Microsoft.CodeAnalysis; @@ -12,15 +13,15 @@ namespace StructId; [Generator(LanguageNames.CSharp)] public class TemplatedGenerator : IIncrementalGenerator { - record KnownTypes(INamedTypeSymbol String, INamedTypeSymbol? IStructId, INamedTypeSymbol? TStructId, INamedTypeSymbol? TStructIdT); + record KnownTypes(string StructIdNamespace, INamedTypeSymbol String, INamedTypeSymbol? IStructId, INamedTypeSymbol? TStructId, INamedTypeSymbol? TStructIdT); record IdTemplate(INamedTypeSymbol StructId, Template Template); - record Template(INamedTypeSymbol TSelf, ITypeSymbol TId, AttributeData Attribute, bool IsGenericTId) + record Template(INamedTypeSymbol TSelf, ITypeSymbol TId, AttributeData Attribute, string StructIdNamespace, bool IsGenericTId) { public Regex NameExpr { get; } = new Regex($@"\b{TSelf.Name}\b", RegexOptions.Compiled | RegexOptions.Multiline); - public string Text { get; } = GetTemplateCode(TSelf, TId, Attribute); + public string Text { get; } = GetTemplateCode(TSelf, TId, Attribute, StructIdNamespace); - static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeData attribute) + static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeData attribute, string StructIdNamespace) { if (self.DeclaringSyntaxReferences[0].GetSyntax() is not TypeDeclarationSyntax declaration) return ""; @@ -50,18 +51,48 @@ static string GetTemplateCode(INamedTypeSymbol self, ITypeSymbol tid, AttributeD root = root.ReplaceNode(update, updated); } - return root.SyntaxTree.GetRoot().ToFullString().Trim(); + // replace usings/namespace from StructId > StructIdNamespace + var usings = root.DescendantNodes().OfType().ToList(); + var ns = root.DescendantNodes().OfType().FirstOrDefault(); + var nsname = ns?.Name.ToString(); + + if (nsname == "StructId") + root = root.ReplaceNode(ns!, ns!.WithName(ParseName(StructIdNamespace))); + else if (nsname != StructIdNamespace) + usings.Add(UsingDirective(ParseName(StructIdNamespace))); + + // deduplicate usings just in case + var unique = new HashSet(); + root = root.ReplaceNodes(usings, (old, _) => + { + // replace 'StructId' > StructIdNamespace + if (old.Name?.ToString() == "StructId") + { + unique.Add(StructIdNamespace); + return old.WithName(ParseName(StructIdNamespace)); + } + + if (unique.Add(old.Name?.ToString() ?? "")) + return old; + + return null!; + }); + + var code = root.SyntaxTree.GetRoot().NormalizeWhitespace().ToFullString().Trim(); + + return code; } } public void Initialize(IncrementalGeneratorInitializationContext context) { - var targetNamespace = context.AnalyzerConfigOptionsProvider + var structIdNamespace = context.AnalyzerConfigOptionsProvider .Select((x, _) => x.GlobalOptions.TryGetValue("build_property.StructIdNamespace", out var ns) ? ns : "StructId"); var known = context.CompilationProvider - .Combine(targetNamespace) + .Combine(structIdNamespace) .Select((x, _) => new KnownTypes( + x.Right, // get string known type x.Left.GetTypeByMetadataName("System.String")!, x.Left.GetTypeByMetadataName($"{x.Right}.IStructId`1"), @@ -91,14 +122,14 @@ public void Initialize(IncrementalGeneratorInitializationContext context) var (structId, known) = x; var attribute = structId.GetAttributes().FirstOrDefault(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructIdT)); if (attribute != null) - return new Template(structId, attribute.AttributeClass!.TypeArguments[0], attribute, true); + return new Template(structId, attribute.AttributeClass!.TypeArguments[0], attribute, known.StructIdNamespace, true); // If we don't have the generic attribute, infer the idType from the required // primary constructor Value parameter type var idType = structId.GetMembers().OfType().First(p => p.Name == "Value").Type; attribute = structId.GetAttributes().First(a => a.AttributeClass != null && a.AttributeClass.Is(known.TStructId)); - return new Template(structId, idType, attribute, false); + return new Template(structId, idType, attribute, known.StructIdNamespace, false); }) .Collect(); diff --git a/src/StructId.FunctionalTests/ObjectTemplate.cs b/src/StructId.FunctionalTests/ObjectTemplate.cs new file mode 100644 index 0000000..d376e44 --- /dev/null +++ b/src/StructId.FunctionalTests/ObjectTemplate.cs @@ -0,0 +1,7 @@ +using StructId; + +[TStructId] +file partial record struct ObjectTemplate(object Value) +{ + // applies to any struct id, whether struct or string +} \ No newline at end of file diff --git a/src/StructId.FunctionalTests/StructTemplate.cs b/src/StructId.FunctionalTests/StructTemplate.cs new file mode 100644 index 0000000..913eeb6 --- /dev/null +++ b/src/StructId.FunctionalTests/StructTemplate.cs @@ -0,0 +1,7 @@ +using StructId; + +[TStructId] +file partial record struct StructTemplate(ValueType Value) +{ + // applies to any ValueType-based struct id +} \ No newline at end of file