using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Text;

namespace ExampleGenerator.Unity.Ui
{
    /// <summary>
    /// Names of the marker attributes emitted by the generator, plus a small
    /// symbol helper shared by the generator and its syntax receiver.
    /// </summary>
    public static class Helpers
    {
        public const string UiElementAttribute = "UiElementAttribute";
        public const string AtUiComponentAttribute = "AtUiComponentAttribute";

        /// <summary>
        /// Walks the inheritance chain starting at <paramref name="baseType"/> (inclusive)
        /// and reports whether any type on it has the simple name <paramref name="targetType"/>.
        /// </summary>
        /// <param name="baseType">First type to inspect; may be null.</param>
        /// <param name="targetType">Simple (unqualified) type name to look for.</param>
        /// <returns>True when a type with that name is found on the chain; otherwise false.</returns>
        internal static bool IsDerivedFrom( INamedTypeSymbol baseType , string targetType )
        {
            for ( var current = baseType ; current != null ; current = current.BaseType )
            {
                // Comparison is by simple name only; the namespace is not considered.
                if ( current.Name == targetType )
                {
                    return true;
                }
            }

            return false;
        }
    }

    /// <summary>
    /// Source generator that injects the two marker attributes and, for classes deriving
    /// from a type named "AtVisualElement", emits partial-class parts containing a
    /// <c>QueryElements</c> override (for members marked with the element attribute) and a
    /// <c>UxmlPath</c> override (for classes marked with the component attribute).
    /// </summary>
    [Generator]
    public class UiBackingClassGenerator : ISourceGenerator
    {
        private const string BaseElementTypeName = "AtVisualElement";

        private static readonly string AtUiComponentAttributeText = $@"// <auto-generated/>
using System;

[AttributeUsage(AttributeTargets.Class, Inherited = true, AllowMultiple = false)]
internal class {Helpers.AtUiComponentAttribute} : Attribute
{{
    public {Helpers.AtUiComponentAttribute}(string uxmlPath)
    {{
    }}
}}
";

        private static readonly string UiElementAttributeText = $@"// <auto-generated/>
using System;

[AttributeUsage(AttributeTargets.Property | AttributeTargets.Field, Inherited = true, AllowMultiple = false)]
internal class {Helpers.UiElementAttribute} : Attribute
{{
    public {Helpers.UiElementAttribute}(string name=null)
    {{
    }}
}}
";

        #region Implementation of ISourceGenerator

        /// <summary>
        /// Registers the marker attribute sources for post-initialization and installs
        /// the <see cref="SyntaxReceiver"/> that collects candidate classes and members.
        /// </summary>
        public void Initialize( GeneratorInitializationContext context )
        {
            context.RegisterForPostInitialization( i =>
            {
                i.AddSource( $"{Helpers.AtUiComponentAttribute}_g.cs" , SourceText.From( AtUiComponentAttributeText , Encoding.UTF8 ) );
                i.AddSource( $"{Helpers.UiElementAttribute}_g.cs" , SourceText.From( UiElementAttributeText , Encoding.UTF8 ) );
            } );

            context.RegisterForSyntaxNotifications( () => new SyntaxReceiver() );
        }

        /// <summary>
        /// Emits one source file per class that has annotated members and one per class
        /// carrying the component attribute. Classes not deriving from "AtVisualElement"
        /// are skipped.
        /// </summary>
        public void Execute( GeneratorExecutionContext context )
        {
            if ( !( context.SyntaxContextReceiver is SyntaxReceiver receiver ) )
            {
                return;
            }

            var atUiComponentAttributeSymbol = context.Compilation.GetTypeByMetadataName( Helpers.AtUiComponentAttribute );
            var uiElementAttributeSymbol = context.Compilation.GetTypeByMetadataName( Helpers.UiElementAttribute );

            var membersByClass = receiver.Fields
                .GroupBy<ISymbol , INamedTypeSymbol>( f => f.ContainingType , SymbolEqualityComparer.Default );

            foreach ( var group in membersByClass )
            {
                var classSymbol = group.Key;
                if ( classSymbol is null || !Helpers.IsDerivedFrom( classSymbol , BaseElementTypeName ) )
                {
                    continue;
                }

                var classSource = ProcessClassUiElement( classSymbol , group , uiElementAttributeSymbol );
                if ( classSource == null )
                {
                    continue;
                }

                context.AddSource( BuildHintName( classSymbol , "ui_query" ) , SourceText.From( classSource , Encoding.UTF8 ) );
            }

            // A partial class may be visited once per declaration; de-duplicate so that
            // AddSource is never called twice with the same hint name.
            foreach ( var classSymbol in receiver.Classes.Distinct<INamedTypeSymbol>( SymbolEqualityComparer.Default ) )
            {
                if ( classSymbol is null || !Helpers.IsDerivedFrom( classSymbol , BaseElementTypeName ) )
                {
                    continue;
                }

                var classSource = ProcessClassUiComponent( classSymbol , atUiComponentAttributeSymbol );
                if ( classSource == null )
                {
                    continue;
                }

                context.AddSource( BuildHintName( classSymbol , "at_ui_component" ) , SourceText.From( classSource , Encoding.UTF8 ) );
            }
        }

        #endregion

        /// <summary>
        /// Builds the partial-class part that overrides <c>UxmlPath</c> for a class carrying
        /// the component attribute.
        /// </summary>
        /// <returns>The generated source, or null when no non-blank path is available.</returns>
        private static string ProcessClassUiComponent( INamedTypeSymbol classSymbol , INamedTypeSymbol atUiComponentAttributeSymbol )
        {
            var uiComponentAttributeData = GetUiComponentAttributeData( classSymbol , atUiComponentAttributeSymbol );
            var uxmlPath = uiComponentAttributeData?.UxmlPath;
            if ( string.IsNullOrWhiteSpace( uxmlPath ) )
            {
                // Nothing to override; do not emit an empty partial class.
                return null;
            }

            var source = new StringBuilder();
            AppendClassFrameStart( classSymbol , source );

            // Example output:
            //     protected override string UxmlPath => "Ingame/Inventory/Inventory";
            source.AppendLine( $"    protected override string UxmlPath => {ToStringLiteral( uxmlPath )};" );

            AppendClassFrameEnd( classSymbol , source );
            return source.ToString();
        }

        /// <summary>
        /// Builds the partial-class part that overrides <c>QueryElements</c> and assigns every
        /// annotated member from a <c>this.Q&lt;T&gt;(name)</c> lookup.
        /// </summary>
        /// <param name="classSymbol">Class that declares the members.</param>
        /// <param name="fields">Field or property symbols collected for that class.</param>
        /// <param name="uiElementAttributeSymbol">Symbol of the element attribute; may be null.</param>
        /// <returns>The generated source, or null when no member carries the attribute.</returns>
        private static string ProcessClassUiElement( INamedTypeSymbol classSymbol , IEnumerable<ISymbol> fields , INamedTypeSymbol uiElementAttributeSymbol )
        {
            // Resolve the attribute data once per member and keep it next to the symbol.
            var queries = new List<KeyValuePair<ISymbol , string>>();
            foreach ( var member in fields )
            {
                var data = GetUiElementAttributeData( member , uiElementAttributeSymbol );
                if ( data.HasValue )
                {
                    queries.Add( new KeyValuePair<ISymbol , string>( member , data.Value.Name ) );
                }
            }

            if ( queries.Count == 0 )
            {
                return null;
            }

            var source = new StringBuilder();
            AppendClassFrameStart( classSymbol , source );

            source.AppendLine( "    protected override void QueryElements()" );
            source.AppendLine( "    {" );
            foreach ( var query in queries )
            {
                var memberSymbol = query.Key;
                source.AppendLine( $"        {memberSymbol.Name} = this.Q<{GetQualifyingTypeNameFromSymbol( memberSymbol )}>({ToStringLiteral( query.Value )});" );
            }
            source.AppendLine( "    }" );

            AppendClassFrameEnd( classSymbol , source );
            return source.ToString();
        }

        /// <summary>
        /// Writes the file header, the optional namespace opening and the partial class opening.
        /// </summary>
        private static void AppendClassFrameStart( INamedTypeSymbol classSymbol , StringBuilder source )
        {
            source.AppendLine( "// <auto-generated/>" );
            source.AppendLine( "using UnityEngine.UIElements;" );

            // Types in the global namespace must not be wrapped in a namespace declaration.
            if ( HasNamedNamespace( classSymbol ) )
            {
                source.AppendLine( $"namespace {classSymbol.ContainingNamespace.ToDisplayString()}" );
                source.AppendLine( "{" );
            }

            // No accessibility modifier: the part then agrees with whatever the
            // user-written declaration specifies.
            source.AppendLine( $"partial class {classSymbol.Name}" );
            source.AppendLine( "{" );
        }

        /// <summary>
        /// Closes what <see cref="AppendClassFrameStart"/> opened for the same class.
        /// </summary>
        private static void AppendClassFrameEnd( INamedTypeSymbol classSymbol , StringBuilder source )
        {
            source.AppendLine( "}" );
            if ( HasNamedNamespace( classSymbol ) )
            {
                source.AppendLine( "}" );
            }
        }

        /// <summary>True when the class lives in a namespace other than the global one.</summary>
        private static bool HasNamedNamespace( INamedTypeSymbol classSymbol )
        {
            var containingNamespace = classSymbol.ContainingNamespace;
            return containingNamespace != null && !containingNamespace.IsGlobalNamespace;
        }

        /// <summary>
        /// Builds a hint name that includes the namespace, so classes sharing a simple
        /// name in different namespaces do not collide.
        /// </summary>
        private static string BuildHintName( INamedTypeSymbol classSymbol , string suffix )
        {
            var qualifiedName = HasNamedNamespace( classSymbol )
                ? $"{classSymbol.ContainingNamespace.ToDisplayString()}.{classSymbol.Name}"
                : classSymbol.Name;

            return $"{qualifiedName}_{suffix}_g.cs";
        }

        /// <summary>Renders <paramref name="value"/> as a quoted, escaped C# string literal.</summary>
        private static string ToStringLiteral( string value )
        {
            return Microsoft.CodeAnalysis.CSharp.SymbolDisplay.FormatLiteral( value , true );
        }

        /// <summary>Renders attribute arguments for use in an exception message.</summary>
        private static string DescribeArguments( IEnumerable<TypedConstant> arguments )
        {
            // TypedConstant.Value is not usable for array-valued constants, so those are summarized.
            return string.Join( ", " , arguments.Select( a => a.Kind == TypedConstantKind.Array ? "[array]" : Convert.ToString( a.Value ) ) );
        }

        private static string GetQualifyingTypeName( ITypeSymbol type )
        {
            return type.ToDisplayString( SymbolDisplayFormat.FullyQualifiedFormat );
        }

        private static string GetQualifyingTypeNameFromSymbol( ISymbol symbol )
        {
            return GetQualifyingTypeName( GetTypeFromSymbol( symbol ) );
        }

        /// <summary>
        /// Reads the component attribute applied to <paramref name="classSymbol"/>.
        /// </summary>
        /// <returns>The attribute data, or null when either symbol is null or the attribute is absent.</returns>
        /// <exception cref="NotImplementedException">The attribute does not have exactly one constructor argument.</exception>
        private static AtUiComponentAttributeData? GetUiComponentAttributeData( INamedTypeSymbol classSymbol , INamedTypeSymbol atUiComponentAttributeSymbol )
        {
            if ( classSymbol is null || atUiComponentAttributeSymbol is null )
            {
                return null;
            }

            var attr = GetSingleAttributeData( classSymbol , atUiComponentAttributeSymbol );
            if ( attr == null )
            {
                return null;
            }

            var args = attr.ConstructorArguments;
            if ( args.Length != 1 )
            {
                throw new NotImplementedException(
                    $"Attribute had a different parameter amount than expected: expected 1 got {args.Length} {attr}: args: {DescribeArguments( args )}" );
            }

            return new AtUiComponentAttributeData
            {
                UxmlPath = args[0].Value as string ,
            };
        }

        /// <summary>
        /// Reads the element attribute applied to a field or property. The element name is the
        /// attribute's string argument, falling back to the member's own name when it is absent or null.
        /// </summary>
        /// <returns>The attribute data, or null when the member does not carry the attribute.</returns>
        /// <exception cref="NotImplementedException">The attribute has more than one constructor argument.</exception>
        private static UiElementAttributeData? GetUiElementAttributeData( ISymbol fieldSymbol , INamedTypeSymbol uiElementAttributeSymbol )
        {
            var attr = GetSingleAttributeData( fieldSymbol , uiElementAttributeSymbol );
            if ( attr == null )
            {
                return null;
            }

            var args = attr.ConstructorArguments;
            if ( args.Length > 1 )
            {
                throw new NotImplementedException(
                    $"Attribute had more parameters than expected: expected 1 got {args.Length} {attr}: args: {DescribeArguments( args )}" );
            }

            var explicitName = args.Length == 1 ? args[0].Value as string : null;

            return new UiElementAttributeData
            {
                Name = explicitName ?? fieldSymbol.Name ,
            };
        }

        /// <summary>
        /// Returns the single attribute on <paramref name="fieldSymbol"/> whose class equals
        /// <paramref name="attributeSymbol"/>, or null when there is none.
        /// </summary>
        private static AttributeData GetSingleAttributeData( ISymbol fieldSymbol , INamedTypeSymbol attributeSymbol )
        {
            return fieldSymbol.GetAttributes()
                .SingleOrDefault( ad => ad?.AttributeClass?.Equals( attributeSymbol , SymbolEqualityComparer.Default ) ?? false );
        }

        /// <summary>Returns the declared type of a field or property symbol.</summary>
        /// <exception cref="InvalidCastException">The symbol is neither a field nor a property.</exception>
        private static ITypeSymbol GetTypeFromSymbol( ISymbol symbol )
        {
            switch ( symbol )
            {
                case IFieldSymbol fieldSymbol:
                    return fieldSymbol.Type;
                case IPropertySymbol propertySymbol:
                    return propertySymbol.Type;
                default:
                    throw new InvalidCastException( $"symbol was not property or field: {symbol}" );
            }
        }

        private struct UiElementAttributeData
        {
            public string Name;
        }

        private struct AtUiComponentAttributeData
        {
            public string UxmlPath;
        }
    }

    /// <summary>
    /// Collects, while the compiler visits syntax nodes, the classes carrying the component
    /// attribute and the fields/properties carrying the element attribute. Only declarations
    /// whose class has a base type named "AtVisualElement" somewhere up the chain are kept.
    /// </summary>
    public class SyntaxReceiver : ISyntaxContextReceiver
    {
        private const string BaseElementTypeName = "AtVisualElement";

        /// <summary>Field and property symbols marked with the element attribute.</summary>
        public List<ISymbol> Fields { get; } = new List<ISymbol>();

        /// <summary>Class symbols marked with the component attribute.</summary>
        public List<INamedTypeSymbol> Classes { get; } = new List<INamedTypeSymbol>();

        #region Implementation of ISyntaxContextReceiver

        /// <summary>
        /// Inspects class, field and property declarations that have at least one attribute list
        /// and records the matching symbols.
        /// </summary>
        public void OnVisitSyntaxNode( GeneratorSyntaxContext context )
        {
            switch ( context.Node )
            {
                case ClassDeclarationSyntax classDeclarationSyntax when classDeclarationSyntax.AttributeLists.Count > 0:
                {
                    if ( context.SemanticModel.GetDeclaredSymbol( classDeclarationSyntax ) is INamedTypeSymbol namedTypeSymbol
                         && Helpers.IsDerivedFrom( namedTypeSymbol.BaseType , BaseElementTypeName )
                         && HasAttributeNamed( namedTypeSymbol , Helpers.AtUiComponentAttribute ) )
                    {
                        Classes.Add( namedTypeSymbol );
                    }

                    break;
                }

                case FieldDeclarationSyntax fieldDeclarationSyntax when fieldDeclarationSyntax.AttributeLists.Count > 0:
                {
                    // One field declaration can introduce several variables; each is its own symbol.
                    foreach ( var variable in fieldDeclarationSyntax.Declaration.Variables )
                    {
                        if ( context.SemanticModel.GetDeclaredSymbol( variable ) is IFieldSymbol symbol
                             && IsAnnotatedUiMember( symbol ) )
                        {
                            Fields.Add( symbol );
                        }
                    }

                    break;
                }

                case PropertyDeclarationSyntax propertyDeclarationSyntax when propertyDeclarationSyntax.AttributeLists.Count > 0:
                {
                    if ( context.SemanticModel.GetDeclaredSymbol( propertyDeclarationSyntax ) is IPropertySymbol symbol
                         && IsAnnotatedUiMember( symbol ) )
                    {
                        Fields.Add( symbol );
                    }

                    break;
                }
            }
        }

        #endregion

        /// <summary>
        /// True when the member's containing class has a base named "AtVisualElement" and the
        /// member itself carries the element attribute.
        /// </summary>
        private static bool IsAnnotatedUiMember( ISymbol member )
        {
            return Helpers.IsDerivedFrom( member.ContainingType.BaseType , BaseElementTypeName )
                   && HasAttributeNamed( member , Helpers.UiElementAttribute );
        }

        /// <summary>True when any attribute on the symbol has the given display name.</summary>
        private static bool HasAttributeNamed( ISymbol symbol , string attributeDisplayName )
        {
            return symbol.GetAttributes()
                .Any( ad => ad.AttributeClass?.ToDisplayString() == attributeDisplayName );
        }
    }
}