|
| 1 | +using System.Collections.Generic; |
| 2 | +using System.Linq; |
| 3 | +using Microsoft.CodeAnalysis; |
| 4 | +using Microsoft.CodeAnalysis.CSharp; |
| 5 | +using Microsoft.CodeAnalysis.CSharp.Syntax; |
| 6 | +using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory; |
| 7 | + |
| 8 | +namespace StructId; |
| 9 | + |
| 10 | +public static class CodeTemplate |
| 11 | +{ |
| 12 | + public static SyntaxNode Parse(string template) |
| 13 | + { |
| 14 | + var tree = CSharpSyntaxTree.ParseText(template, |
| 15 | + CSharpParseOptions.Default.WithLanguageVersion(LanguageVersion.Latest)); |
| 16 | + |
| 17 | + return tree.GetRoot(); |
| 18 | + } |
| 19 | + |
| 20 | + public static string Apply(string template, string structIdType, string valueType) |
| 21 | + { |
| 22 | + var targetNamespace = structIdType.Contains('.') ? structIdType.Substring(0, structIdType.LastIndexOf('.')) : null; |
| 23 | + structIdType = structIdType.Contains('.') ? structIdType.Substring(structIdType.LastIndexOf('.') + 1) : structIdType; |
| 24 | + |
| 25 | + return ApplyImpl(Parse(template), structIdType, valueType, targetNamespace).ToFullString(); |
| 26 | + } |
| 27 | + |
| 28 | + public static SyntaxNode Apply(this SyntaxNode node, INamedTypeSymbol structId) |
| 29 | + { |
| 30 | + var root = node.SyntaxTree.GetCompilationUnitRoot(); |
| 31 | + if (root == null) |
| 32 | + return node; |
| 33 | + |
| 34 | + // determine namespace of the IStructId/IStructId<T> interface implemented by structId |
| 35 | + var iface = structId.Interfaces.FirstOrDefault(x => x.Name == "IStructId"); |
| 36 | + if (iface == null) |
| 37 | + return root; |
| 38 | + |
| 39 | + var tid = iface.TypeArguments.FirstOrDefault()?.ToFullName() ?? "string"; |
| 40 | + var corens = iface.ContainingNamespace.ToFullName(); |
| 41 | + var targetNamespace = structId.ContainingNamespace != null && !structId.ContainingNamespace.IsGlobalNamespace ? |
| 42 | + structId.ContainingNamespace.ToDisplayString() : null; |
| 43 | + |
| 44 | + return ApplyImpl(root, structId.Name, tid, targetNamespace, corens); |
| 45 | + } |
| 46 | + |
| 47 | + static SyntaxNode ApplyImpl(this SyntaxNode node, string structIdType, string valueType, string? targetNamespace = default, string coreNamespace = "StructId") |
| 48 | + { |
| 49 | + var root = node.SyntaxTree.GetCompilationUnitRoot(); |
| 50 | + if (root == null) |
| 51 | + return node; |
| 52 | + |
| 53 | + // If we got a ns, move all members after a file-scoped namespace declaration |
| 54 | + if (targetNamespace != null) |
| 55 | + { |
| 56 | + var members = root.Members; |
| 57 | + var fsns = FileScopedNamespaceDeclaration(ParseName(targetNamespace) |
| 58 | + .WithLeadingTrivia(node.GetLeadingTrivia()) |
| 59 | + .WithLeadingTrivia(Whitespace(" "))) |
| 60 | + .WithLeadingTrivia(LineFeed) |
| 61 | + .WithTrailingTrivia(LineFeed, LineFeed) |
| 62 | + .WithMembers(members); |
| 63 | + |
| 64 | + root = root.WithMembers(SingletonList<MemberDeclarationSyntax>(fsns)); |
| 65 | + } |
| 66 | + |
| 67 | + var usings = root.DescendantNodes().OfType<UsingDirectiveSyntax>().ToList(); |
| 68 | + // There should be NO namespace declared in the template itself, since we enforce file-local |
| 69 | + usings.Add(UsingDirective(ParseName(coreNamespace)).NormalizeWhitespace()); |
| 70 | + |
| 71 | + // deduplicate usings just in case |
| 72 | + var unique = new HashSet<string>(); |
| 73 | + root = root.ReplaceNodes(usings, (old, _) => |
| 74 | + { |
| 75 | + // replace 'StructId' > StructIdNamespace |
| 76 | + if (old.Name?.ToString() == "StructId") |
| 77 | + { |
| 78 | + unique.Add(coreNamespace); |
| 79 | + return old.WithName(ParseName(coreNamespace)); |
| 80 | + } |
| 81 | + |
| 82 | + if (unique.Add(old.Name?.ToString() ?? "")) |
| 83 | + return old; |
| 84 | + |
| 85 | + return null!; |
| 86 | + }); |
| 87 | + |
| 88 | + node = new TemplateRewriter(structIdType, valueType).Visit(root)!; |
| 89 | + |
| 90 | + return node; |
| 91 | + } |
| 92 | + |
| 93 | + class TemplateRewriter(string tself, string tid) : CSharpSyntaxRewriter |
| 94 | + { |
| 95 | + public override SyntaxNode? VisitRecordDeclaration(RecordDeclarationSyntax node) |
| 96 | + { |
| 97 | + // remove file-local records that aren't annotated with [TStructId] |
| 98 | + if (node.Modifiers.Any(x => x.IsKind(SyntaxKind.FileKeyword)) && |
| 99 | + !node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsStructIdTemplate()))) |
| 100 | + return null; |
| 101 | + |
| 102 | + // If the record has the [TStructId] attribute, remove parameter list |
| 103 | + if (node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsStructIdTemplate())) && |
| 104 | + node.ParameterList is { } parameters) |
| 105 | + { |
| 106 | + // Check if the open paren trivia contains the text '🙏' and remove it |
| 107 | + if (parameters.OpenParenToken.GetAllTrivia().Any(x => x.ToString().Contains("🙏"))) |
| 108 | + node = node.WithParameterList(parameters |
| 109 | + .WithOpenParenToken(parameters.OpenParenToken.WithoutTrivia())); |
| 110 | + else |
| 111 | + node = node.WithParameterList(null); |
| 112 | + } |
| 113 | + |
| 114 | + var visited = (RecordDeclarationSyntax)base.VisitRecordDeclaration(node)!; |
| 115 | + |
| 116 | + // remove file modifier from type declarations |
| 117 | + if (visited.Modifiers.FirstOrDefault(x => x.IsKind(SyntaxKind.FileKeyword)) is { } file) |
| 118 | + // Preserve trivia, i.e. newline from original file modifier |
| 119 | + return visited |
| 120 | + .WithLeadingTrivia(file.LeadingTrivia) |
| 121 | + .WithModifiers(visited.Modifiers.Remove(file)); |
| 122 | + |
| 123 | + return visited; |
| 124 | + } |
| 125 | + |
| 126 | + public override SyntaxNode? VisitStructDeclaration(StructDeclarationSyntax node) |
| 127 | + { |
| 128 | + // remove file-local structs that aren't annotated with [TStructId] |
| 129 | + if (node.Modifiers.Any(x => x.IsKind(SyntaxKind.FileKeyword)) && |
| 130 | + !node.AttributeLists.Any(list => list.Attributes.Any(a => a.IsStructIdTemplate()))) |
| 131 | + return null; |
| 132 | + |
| 133 | + return base.VisitStructDeclaration(node); |
| 134 | + } |
| 135 | + |
| 136 | + public override SyntaxNode? VisitAttributeList(AttributeListSyntax node) |
| 137 | + { |
| 138 | + node = (AttributeListSyntax)base.VisitAttributeList(node)!; |
| 139 | + if (node.Attributes.Count == 0) |
| 140 | + return null; |
| 141 | + |
| 142 | + return node; |
| 143 | + } |
| 144 | + |
| 145 | + public override SyntaxNode? VisitAttribute(AttributeSyntax node) |
| 146 | + { |
| 147 | + if (node.IsStructIdTemplate()) |
| 148 | + return null; |
| 149 | + |
| 150 | + return base.VisitAttribute(node); |
| 151 | + } |
| 152 | + |
| 153 | + // rewrite references to the original type with the target type |
| 154 | + public override SyntaxNode? VisitIdentifierName(IdentifierNameSyntax node) |
| 155 | + { |
| 156 | + if (node.Identifier.Text == "TSelf") |
| 157 | + return IdentifierName(tself) |
| 158 | + .WithLeadingTrivia(node.Identifier.LeadingTrivia) |
| 159 | + .WithTrailingTrivia(node.Identifier.TrailingTrivia); |
| 160 | + else if (node.Identifier.Text == "TId") |
| 161 | + return IdentifierName(tid) |
| 162 | + .WithLeadingTrivia(node.Identifier.LeadingTrivia) |
| 163 | + .WithTrailingTrivia(node.Identifier.TrailingTrivia); |
| 164 | + |
| 165 | + return base.VisitIdentifierName(node); |
| 166 | + } |
| 167 | + |
| 168 | + public override SyntaxToken VisitToken(SyntaxToken token) |
| 169 | + { |
| 170 | + // if token is an identifier token, rewrite it |
| 171 | + if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text == "TSelf") |
| 172 | + return Identifier(tself) |
| 173 | + .WithLeadingTrivia(token.LeadingTrivia) |
| 174 | + .WithTrailingTrivia(token.TrailingTrivia); |
| 175 | + else if (token.IsKind(SyntaxKind.IdentifierToken) && token.Text == "TId") |
| 176 | + return Identifier(tid) |
| 177 | + .WithLeadingTrivia(token.LeadingTrivia) |
| 178 | + .WithTrailingTrivia(token.TrailingTrivia); |
| 179 | + |
| 180 | + return base.VisitToken(token); |
| 181 | + } |
| 182 | + } |
| 183 | +} |
0 commit comments