diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs index d798745..1b9c989 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperAnalyzer.cs @@ -192,7 +192,7 @@ private void ValidateDapperMethod(in OperationAnalysisContext ctx, IOperation sq var parseState = new ParseState(ctx); bool aotEnabled = IsEnabled(in parseState, invoke, Types.DapperAotAttribute, out var aotAttribExists); if (!aotEnabled) flags |= OperationFlags.DoNotGenerate; - var location = SharedParseArgsAndFlags(parseState, invoke, ref flags, out var sql, out var argExpression, onDiagnostic, out _, exitFirstFailure: false); + var location = SharedParseArgsAndFlags(parseState, invoke, ref flags, out var sql, out var argExpression, onDiagnostic, out _, exitFirstFailure: false, out var viaCommandDefinition); // report our AOT readiness if (aotEnabled) @@ -410,7 +410,7 @@ private void ValidateSql(in OperationAnalysisContext ctx, IOperation sqlSource, if (caseSensitive) flags |= SqlParseInputFlags.CaseSensitive; // can we get the SQL itself? - if (!TryGetConstantValueWithSyntax(sqlSource, out string? sql, out var sqlSyntax, out var stringSyntaxKind)) + if (!TryGetStringConstantValueWithSyntax(sqlSource, out string? sql, out var sqlSyntax, out var stringSyntaxKind)) { DiagnosticDescriptor? descriptor = stringSyntaxKind switch { @@ -503,25 +503,62 @@ StringSyntaxKind.ConcatenatedString or StringSyntaxKind.FormatString // we want a common understanding of the setup between the analyzer and generator internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationOperation op, ref OperationFlags flags, out string? sql, - out IOperation? argExpression, Action? reportDiagnostic, out ITypeSymbol? resultType, bool exitFirstFailure) + out IOperation? argExpression, Action? reportDiagnostic, out ITypeSymbol? resultType, bool exitFirstFailure, + out bool viaCommandDefinition) { var callLocation = op.GetMemberLocation(); argExpression = null; sql = null; bool? buffered = null; + viaCommandDefinition = false; - // check the args - foreach (var arg in op.Arguments) + // default is invocation, so simply take arguments + IEnumerable arguments = op.Arguments; + + // invocation can be packed into a CommandDefinition + if (op.Arguments is { Length: >= 2 }) { + if (op.Arguments[0].Parameter?.Name == "cnn" + && op.Arguments[1].Parameter?.Name == "command" && op.Arguments[1].Parameter?.Type.IsCommandDefinition() == true) + { + viaCommandDefinition = true; + + // by default buffered CommandDefinition constructor initializes `buffered` as true via CommandFlags + // https://github.com/DapperLib/Dapper/blob/5c7143f2e3585d4708294a3b0530a134e18ace86/Dapper/CommandDefinition.cs#L85 + buffered = true; + + // in-place creation of CommandDefinition like `Query(new CommandDefinition(...))` + if (op.Arguments[1].Value is IObjectCreationOperation { Arguments.IsDefaultOrEmpty: false } commandDefinitionCreation ) + { + arguments = commandDefinitionCreation.Arguments; + } + // ideally here we would want to parse other CommandDefinition cases (i.e. local variable). + // but it is complicated, so we can simply rely on passing CommandDefinition's members to the underlying query API + // ... + } + } + + // check the args. Names of the parameters are handling Dapper method parameters + CommandDefinition members + foreach (var arg in arguments) + { switch (arg.Parameter?.Name) { case "sql": - if (TryGetConstantValueWithSyntax(arg, out string? s, out _, out _)) + case "commandText": + if (TryGetStringConstantValueWithSyntax(arg, out string? s, out _, out _)) { sql = s; } break; + case "flags": + { + if (TryGetEnumConstantValueWithSyntax(arg, out int? value)) + { + buffered = (value & 1) != 0; // CommandFlags.Buffered = 1 + } + } + break; case "buffered": if (TryGetConstantValue(arg, out bool b)) { @@ -529,6 +566,7 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO } break; case "param": + case "parameters": if (arg.Value is not IDefaultValueOperation) { var expr = arg.Value; @@ -555,6 +593,7 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO case "length": case "returnNullIfFirstMissing": case "concreteType" when arg.Value is IDefaultValueOperation || (arg.ConstantValue.HasValue && arg.ConstantValue.Value is null): + case "cancellationToken": // nothing to do break; case "commandType": @@ -583,6 +622,17 @@ internal static Location SharedParseArgsAndFlags(in ParseState ctx, IInvocationO } } break; + case "command": + { + // case for CommandDefinition - we need to check that we detected it correctly before + // and if we did; then don't drop errors - we could not parse SQL / other flags in complex CommandDefinition usages, + // but we can optimistically pass CommandDefinition data to underlying query + if (!viaCommandDefinition) + { + goto default; + } + } + break; default: if (!flags.HasAny(OperationFlags.NotAotSupported | OperationFlags.DoNotGenerate)) { diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs index 58c3f80..d80807a 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.Single.cs @@ -19,7 +19,8 @@ static void WriteSingleImplementation( in CommandFactoryState factories, in RowReaderState readers, string? fixedSql, - AdditionalCommandState? additionalCommandState) + AdditionalCommandState? additionalCommandState, + bool viaCommandDefinition) { sb.Append("return "); if (flags.HasAll(OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered)) @@ -27,18 +28,29 @@ static void WriteSingleImplementation( sb.Append("global::Dapper.DapperAotExtensions.AsEnumerableAsync(").Indent(false).NewLine(); } // (DbConnection connection, DbTransaction? transaction, string sql, TArgs args, CommandType commandType, int timeout, CommandFactory? commandFactory) - sb.Append("global::Dapper.DapperAotExtensions.Command(cnn, ").Append(Forward(methodParameters, "transaction")).Append(", "); + sb.Append("global::Dapper.DapperAotExtensions.Command(cnn, "); + + if (viaCommandDefinition) sb.Append("command.Transaction, "); + else sb.Append(Forward(methodParameters, "transaction")).Append(", "); + if (fixedSql is not null) { sb.AppendVerbatimLiteral(fixedSql).Append(", "); } else { - sb.Append("sql, "); + if (viaCommandDefinition) sb.Append("command.CommandText, "); + else sb.Append("sql, "); } + if (commandTypeMode == 0) - { // not hard-coded - if (HasParam(methodParameters, "command")) + { + if (viaCommandDefinition) + { + sb.Append("command.CommandType ?? default"); + } + // not hard-coded + else if (HasParam(methodParameters, "command")) { sb.Append("command.GetValueOrDefault()"); } @@ -49,9 +61,14 @@ static void WriteSingleImplementation( } else { - sb.Append("global::System.Data.CommandType.").Append(commandTypeMode.ToString()); + if (viaCommandDefinition) sb.Append("command.CommandType ?? default"); + else sb.Append("global::System.Data.CommandType.").Append(commandTypeMode.ToString()); } - sb.Append(", ").Append(Forward(methodParameters, "commandTimeout")).Append(HasParam(methodParameters, "commandTimeout") ? ".GetValueOrDefault()" : "").Append(", "); + sb.Append(", "); + + if (viaCommandDefinition) sb.Append("command.CommandTimeout ?? default, "); + else sb.Append(Forward(methodParameters, "commandTimeout")).Append(HasParam(methodParameters, "commandTimeout") ? ".GetValueOrDefault()" : "").Append(", "); + if (flags.HasAny(OperationFlags.HasParameters)) { var index = factories.GetIndex(parameterType!, map, cache, additionalCommandState, out var subIndex); @@ -79,7 +96,7 @@ static void WriteSingleImplementation( OperationFlags.Unbuffered => "Unbuffered", _ => "" }).Append(isAsync ? "Async" : "").Append("("); - WriteTypedArg(sb, parameterType).Append(", "); + WriteTypedArg(sb, parameterType, viaCommandDefinition).Append(", "); if (!flags.HasAny(OperationFlags.SingleRow)) { switch (flags & (OperationFlags.Buffered | OperationFlags.Unbuffered)) @@ -107,7 +124,7 @@ static void WriteSingleImplementation( sb.Append("<").Append(resultType).Append(">"); } sb.Append("("); - WriteTypedArg(sb, parameterType); + WriteTypedArg(sb, parameterType, viaCommandDefinition); } else { @@ -124,9 +141,11 @@ static void WriteSingleImplementation( sb.Append(", rowCountHint: ((").Append(parameterType).Append(")param!).").Append(additionalCommandState.RowCountHintMemberName); } } - if (isAsync && HasParam(methodParameters, "cancellationToken")) + if (isAsync && (HasParam(methodParameters, "cancellationToken") || viaCommandDefinition)) { - sb.Append(", cancellationToken: ").Append(Forward(methodParameters, "cancellationToken")); + sb.Append(", cancellationToken: "); + if (viaCommandDefinition) sb.Append("command.CancellationToken"); + else sb.Append(Forward(methodParameters, "cancellationToken")); } if (flags.HasAll(OperationFlags.Async | OperationFlags.Query | OperationFlags.Buffered)) { @@ -153,10 +172,17 @@ static void WriteSingleImplementation( } sb.Append(";").NewLine(); - static CodeWriter WriteTypedArg(CodeWriter sb, ITypeSymbol? parameterType) + static CodeWriter WriteTypedArg(CodeWriter sb, ITypeSymbol? parameterType, bool viaCommandDefinition) { + if (viaCommandDefinition) + { + sb.Append("command.Parameters"); + return sb; + } + if (parameterType is null || parameterType.IsAnonymousType) { + sb.Append("param"); } else diff --git a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs index 595282a..2b05986 100644 --- a/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs +++ b/src/Dapper.AOT.Analyzers/CodeAnalysis/DapperInterceptorGenerator.cs @@ -104,14 +104,12 @@ internal bool PreFilter(SyntaxNode node, CancellationToken cancellationToken) return null; } - var location = DapperAnalyzer.SharedParseArgsAndFlags(ctx, op, ref flags, out var sql, out var argExpression, reportDiagnostic: null, out var resultType, exitFirstFailure: true); + var location = DapperAnalyzer.SharedParseArgsAndFlags(ctx, op, ref flags, out var sql, out var argExpression, reportDiagnostic: null, out var resultType, exitFirstFailure: true, out var viaCommandDefinition); if (flags.HasAny(OperationFlags.DoNotGenerate)) { return null; } - - // additional result-type checks // perform SQL inspection @@ -133,7 +131,7 @@ internal bool PreFilter(SyntaxNode node, CancellationToken cancellationToken) var additionalState = AdditionalCommandState.Parse(Inspection.GetSymbol(ctx, op), map, null); Debug.Assert(!flags.HasAny(OperationFlags.DoNotGenerate), "should have already exited"); - return new SuccessSourceState(location, op.TargetMethod, flags, sql, resultType, argExpression?.Type, parameterMap, additionalState); + return new SuccessSourceState(location, op.TargetMethod, flags, sql, resultType, argExpression?.Type, parameterMap, additionalState, viaCommandDefinition); } catch (Exception ex) { @@ -284,7 +282,7 @@ internal void Generate(in GenerateState ctx) foreach (var grp in ctx.Nodes.OfType().Where(x => !x.Flags.HasAny(OperationFlags.DoNotGenerate)).GroupBy(x => x.Group(), CommonComparer.Instance)) { // first, try to resolve the helper method that we're going to use for this - var (flags, method, parameterType, parameterMap, _, additionalCommandState) = grp.Key; + var (flags, method, parameterType, parameterMap, _, additionalCommandState, viaCommandDefinition) = grp.Key; const bool useUnsafe = false; int usageCount = 0; @@ -343,39 +341,52 @@ internal void Generate(in GenerateState ctx) var methodParameters = grp.Key.Method.Parameters; string? fixedSql = null; - if (HasParam(methodParameters, "sql")) + if (HasParam(methodParameters, "sql") || viaCommandDefinition) { if (flags.HasAny(OperationFlags.IncludeLocation)) { var origin = grp.Single(); fixedSql = origin.Sql; // expect exactly one SQL - sb.Append("global::System.Diagnostics.Debug.Assert(sql == ") + sb.Append("global::System.Diagnostics.Debug.Assert(") + .Append(viaCommandDefinition ? "command.CommandText" : "sql").Append(" == ") .AppendVerbatimLiteral(fixedSql).Append(");").NewLine(); var path = origin.Location.GetMappedLineSpan(); fixedSql = $"-- {path.Path}#{path.StartLinePosition.Line + 1}\r\n{fixedSql}"; } else { - sb.Append("global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(sql));").NewLine(); + sb.Append("global::System.Diagnostics.Debug.Assert(") + .Append("!string.IsNullOrWhiteSpace(").Append(viaCommandDefinition ? "command.CommandText" : "sql").Append(")") + .Append(");").NewLine(); } } - if (HasParam(methodParameters, "commandType")) + if (HasParam(methodParameters, "commandType") || viaCommandDefinition) { if (commandTypeMode != 0) { - sb.Append("global::System.Diagnostics.Debug.Assert((commandType ?? global::Dapper.DapperAotExtensions.GetCommandType(sql)) == global::System.Data.CommandType.") - .Append(commandTypeMode.ToString()).Append(");").NewLine(); + sb.Append("global::System.Diagnostics.Debug.Assert(") + .Append("(").Append(viaCommandDefinition ? "command.CommandType" : "commandType") + .Append(" ?? global::Dapper.DapperAotExtensions.GetCommandType(") + .Append(viaCommandDefinition ? "command.CommandText" : "sql") + .Append(")) == global::System.Data.CommandType.") + .Append(commandTypeMode.ToString()).Append(");").NewLine(); } } - if (flags.HasAny(OperationFlags.Buffered | OperationFlags.Unbuffered) && HasParam(methodParameters, "buffered")) + if (flags.HasAny(OperationFlags.Buffered | OperationFlags.Unbuffered) && HasParam(methodParameters, "buffered") || viaCommandDefinition) { - sb.Append("global::System.Diagnostics.Debug.Assert(buffered is ").Append((flags & OperationFlags.Buffered) != 0).Append(");").NewLine(); + sb + .Append("global::System.Diagnostics.Debug.Assert(") + .Append(viaCommandDefinition ? "command.Buffered is " : "buffered is ") + .Append((flags & OperationFlags.Buffered) != 0).Append(");").NewLine(); } - if (HasParam(methodParameters, "param")) + if (HasParam(methodParameters, "param") || viaCommandDefinition) { - sb.Append("global::System.Diagnostics.Debug.Assert(param is ").Append(flags.HasAny(OperationFlags.HasParameters) ? "not " : "").Append("null);").NewLine(); + sb + .Append("global::System.Diagnostics.Debug.Assert(") + .Append(viaCommandDefinition ? "command.Parameters" : "param") + .Append(" is ").Append(flags.HasAny(OperationFlags.HasParameters) ? "not " : "").Append("null);").NewLine(); } if (HasParam(methodParameters, "concreteType")) @@ -398,7 +409,7 @@ internal void Generate(in GenerateState ctx) } else if (!TryWriteMultiExecImplementation(sb, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, fixedSql, additionalCommandState)) { - WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql, additionalCommandState); + WriteSingleImplementation(sb, method, resultType, flags, commandTypeMode, parameterType, grp.Key.ParameterMap, grp.Key.UniqueLocation is not null, methodParameters, factories, readers, fixedSql, additionalCommandState, viaCommandDefinition); } sb.Outdent().NewLine().NewLine(); @@ -1513,10 +1524,12 @@ internal sealed class SuccessSourceState : SourceState public ITypeSymbol? ResultType { get; } public ITypeSymbol? ParameterType { get; } public AdditionalCommandState? AdditionalCommandState { get; } + public bool ViaCommandDefinition { get; } public SuccessSourceState(Location location, IMethodSymbol method, OperationFlags flags, string? sql, ITypeSymbol? resultType, ITypeSymbol? parameterType, string parameterMap, - AdditionalCommandState? additionalCommandState) : base(location) + AdditionalCommandState? additionalCommandState, + bool viaCommandDefinition) : base(location) { Flags = flags; Sql = sql; @@ -1525,27 +1538,29 @@ public SuccessSourceState(Location location, IMethodSymbol method, OperationFlag Method = method; ParameterMap = parameterMap; AdditionalCommandState = additionalCommandState; + ViaCommandDefinition = viaCommandDefinition; } - public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState) Group() - => new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location, AdditionalCommandState); + public (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState, bool ViaCommandDefinition) Group() + => new(Flags, Method, ParameterType, ParameterMap, (Flags & (OperationFlags.CacheCommand | OperationFlags.IncludeLocation)) == 0 ? null : Location, AdditionalCommandState, ViaCommandDefinition); } - private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState)> + private sealed class CommonComparer : LocationComparer, IEqualityComparer<(OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState, bool ViaCommandDefinition)> { public static readonly CommonComparer Instance = new(); private CommonComparer() { } public bool Equals( - (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState) x, - (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState) y) => x.Flags == y.Flags + (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState, bool ViaCommandDefinition) x, + (OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState, bool ViaCommandDefinition) y) => x.Flags == y.Flags && x.ParameterMap == y.ParameterMap && SymbolEqualityComparer.Default.Equals(x.Method, y.Method) && SymbolEqualityComparer.Default.Equals(x.ParameterType, y.ParameterType) && x.UniqueLocation == y.UniqueLocation - && Equals(x.AdditionalCommandState, y.AdditionalCommandState); + && Equals(x.AdditionalCommandState, y.AdditionalCommandState) + && x.ViaCommandDefinition == y.ViaCommandDefinition; - public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState) obj) + public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? ParameterType, string ParameterMap, Location? UniqueLocation, AdditionalCommandState? AdditionalCommandState, bool ViaCommandDefinition) obj) { var hash = (int)obj.Flags; hash *= -47; @@ -1567,6 +1582,8 @@ public int GetHashCode((OperationFlags Flags, IMethodSymbol Method, ITypeSymbol? { hash += obj.AdditionalCommandState.GetHashCode(); } + hash *= -47; + hash += obj.ViaCommandDefinition ? 1 : 0; return hash; } } diff --git a/src/Dapper.AOT.Analyzers/Internal/Inspection.cs b/src/Dapper.AOT.Analyzers/Internal/Inspection.cs index ed11497..8f29e28 100644 --- a/src/Dapper.AOT.Analyzers/Internal/Inspection.cs +++ b/src/Dapper.AOT.Analyzers/Internal/Inspection.cs @@ -193,6 +193,9 @@ public static ImmutableArray ParseQueryColumns(AttributeData attrib, Act return result; } + public static bool IsCommandDefinition(this ITypeSymbol? typeSymbol) + => typeSymbol.IsDapperType("CommandDefinition") && typeSymbol?.TypeKind == TypeKind.Struct; + public static bool IsSqlClient(ITypeSymbol? typeSymbol) => typeSymbol is { Name: "SqlCommand", @@ -1294,7 +1297,7 @@ public static bool IsDapperMethod(this IInvocationOperation operation, out Opera public static bool HasAll(this OperationFlags value, OperationFlags testFor) => (value & testFor) == testFor; public static bool TryGetConstantValue(IOperation op, out T? value) - => TryGetConstantValueWithSyntax(op, out value, out _, out _); + => TryGetStringConstantValueWithSyntax(op, out value, out _, out _); public static ITypeSymbol? GetResultType(this IInvocationOperation invocation, OperationFlags flags) { @@ -1313,7 +1316,35 @@ internal static bool CouldBeNullable(ITypeSymbol symbol) => symbol.IsValueType ? symbol.NullableAnnotation == NullableAnnotation.Annotated : symbol.NullableAnnotation != NullableAnnotation.NotAnnotated; - public static bool TryGetConstantValueWithSyntax(IOperation val, out T? value, out SyntaxNode? syntax, out StringSyntaxKind? syntaxKind) + public static bool TryGetEnumConstantValueWithSyntax(IOperation val, out T? value) + { + if (val.ConstantValue.HasValue) + { + value = (T?)val.ConstantValue.Value; + return true; + } + if (val is IArgumentOperation arg) + { + val = arg.Value; + } + // work through any implicit/explicit conversion steps + while (val is IConversionOperation conv) + { + val = conv.Operand; + } + + // type-level constants + if (val is IFieldReferenceOperation field && field.Field.HasConstantValue) + { + value = (T?)field.Field.ConstantValue; + return true; + } + + value = default!; + return false; + } + + public static bool TryGetStringConstantValueWithSyntax(IOperation val, out T? value, out SyntaxNode? syntax, out StringSyntaxKind? syntaxKind) { try { diff --git a/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs b/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs index 7d9e70a..d2581cc 100644 --- a/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs +++ b/src/Dapper.AOT.Analyzers/Internal/Roslyn/TypeSymbolExtensions.cs @@ -145,6 +145,19 @@ public static bool IsAsync(this ITypeSymbol? type, out ITypeSymbol? result) /// public static bool IsImmutableArray(this ITypeSymbol? typeSymbol) => IsStandardCollection(typeSymbol, "ImmutableArray", "Immutable", TypeKind.Struct); + public static bool IsDapperType(this ITypeSymbol? typeSymbol, string expectedName) + { + if (typeSymbol is null) return false; + return typeSymbol.Name == expectedName && typeSymbol.ContainingNamespace is + { + Name: "Dapper", + ContainingNamespace: + { + IsGlobalNamespace: true + } + }; + } + private static bool IsStandardCollection(ITypeSymbol? type, string name, string nsName = "Generic", TypeKind kind = TypeKind.Class) => type is INamedTypeSymbol named && named.Name == name diff --git a/test/Dapper.AOT.Test.Integration.Executables/Models/CommandDefinitionPoco.cs b/test/Dapper.AOT.Test.Integration.Executables/Models/CommandDefinitionPoco.cs new file mode 100644 index 0000000..cf73ed7 --- /dev/null +++ b/test/Dapper.AOT.Test.Integration.Executables/Models/CommandDefinitionPoco.cs @@ -0,0 +1,9 @@ +namespace Dapper.AOT.Test.Integration.Executables.Models; + +public class CommandDefinitionPoco +{ + public const string TableName = "commandDefinitionPoco"; + + public int Id { get; set; } + public string? Name { get; set; } +} diff --git a/test/Dapper.AOT.Test.Integration.Executables/UserCode/CommandDefinition/CommandDefinitionUsage.cs b/test/Dapper.AOT.Test.Integration.Executables/UserCode/CommandDefinition/CommandDefinitionUsage.cs new file mode 100644 index 0000000..99e10de --- /dev/null +++ b/test/Dapper.AOT.Test.Integration.Executables/UserCode/CommandDefinition/CommandDefinitionUsage.cs @@ -0,0 +1,26 @@ +using System.Data; +using System.Linq; +using System.Threading; +using Dapper; +using Dapper.AOT.Test.Integration.Executables.Models; + +namespace Dapper.AOT.Test.Integration.Executables.UserCode; + +[DapperAot] +public class CommandDefinitionUsage : IExecutable +{ + public CommandDefinitionPoco Execute(IDbConnection connection) + { + var results = connection.Query( + new CommandDefinition( + commandText: $"select * from {CommandDefinitionPoco.TableName} where name = @name", + parameters: new + { + name = "my-data" + }, + cancellationToken: CancellationToken.None) + ); + + return results.First(); + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test.Integration/CommandDefinitionTests.cs b/test/Dapper.AOT.Test.Integration/CommandDefinitionTests.cs new file mode 100644 index 0000000..4a3a224 --- /dev/null +++ b/test/Dapper.AOT.Test.Integration/CommandDefinitionTests.cs @@ -0,0 +1,44 @@ +using System; +using System.Data; +using Dapper.AOT.Test.Integration.Executables.Models; +using Dapper.AOT.Test.Integration.Executables.UserCode; +using Dapper.AOT.Test.Integration.Setup; + +namespace Dapper.AOT.Test.Integration; + +[Collection(SharedPostgresqlClient.Collection)] +public class CommandDefinitionTests : IntegrationTestsBase +{ + public CommandDefinitionTests(PostgresqlFixture fixture) : base(fixture) + { + } + + protected override void SetupDatabase(IDbConnection dbConnection) + { + base.SetupDatabase(dbConnection); + + dbConnection.Execute($""" + CREATE TABLE IF NOT EXISTS {CommandDefinitionPoco.TableName}( + id integer PRIMARY KEY, + name varchar(40) + ); + + TRUNCATE {CommandDefinitionPoco.TableName}; + + INSERT INTO {CommandDefinitionPoco.TableName} (id, name) + VALUES (1, 'my-data'), + (2, 'my-poco'), + (3, 'your-data'); + """); + } + + [Fact] + public void CommandDefinition_BasicUsage_InterceptsAndReturnsExpectedData() + { + var result = ExecuteInterceptedUserCode(DbConnection); + + Assert.NotNull(result); + Assert.True(result.Id.Equals(1)); + Assert.Equal("my-data", result.Name); + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/CommandDefinition.input.cs b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.input.cs new file mode 100644 index 0000000..b4ce5a7 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.input.cs @@ -0,0 +1,61 @@ +using Dapper; +using System.Data; +using System.Data.Common; +using Microsoft.Data.SqlClient; +using System.Threading; + +[module: DapperAot] + +public static class Foo +{ + static void Run(DbConnection connection) + { + var sql = "sp_crunch"; + + // 0: sql comes as local var + _ = connection.ExecuteScalar(new CommandDefinition( + sql, + new { X = 3 }, + commandType: CommandType.StoredProcedure, + flags: CommandFlags.Buffered, + cancellationToken: CancellationToken.None)); + + // 1: sql inline + _ = connection.ExecuteScalarAsync(new CommandDefinition( + "sp_crunch", + new { X = 3 }, + commandType: CommandType.StoredProcedure, + flags: CommandFlags.Buffered, + cancellationToken: CancellationToken.None)); + + // 2: very limited setupdap + _ = connection.ExecuteScalarAsync(new CommandDefinition( + "select * from table where X = @X", + new { X = 3 }) + ); + + // 3: no async + _ = connection.ExecuteScalar(new CommandDefinition( + "select * from table where X = @X", + new { X = 3 }) + ); + + // 4: command definition as local var + var local = new CommandDefinition( + "select * from table where X = @X", + new { X = 3 } + ); + _ = connection.ExecuteScalarAsync(local); + + // 5: query via command definition + var results = connection.Query( + new CommandDefinition( + commandText: "select * from table where name = @name", + parameters: new + { + name = "my-data" + }, + cancellationToken: CancellationToken.None) + ); + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.cs b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.cs new file mode 100644 index 0000000..c77a7f6 --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.cs @@ -0,0 +1,193 @@ +#nullable enable +#pragma warning disable IDE0078 // unnecessary suppression is necessary +#pragma warning disable CS9270 // SDK-dependent change to interceptors usage +namespace Dapper.AOT // interceptors must be in a known namespace +{ + file static class DapperGeneratedInterceptors + { + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 16, 23)] + internal static string? ExecuteScalar0(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Execute, TypedResult, HasParameters, StoredProcedure, Scalar, KnownParameters + // takes parameter: + // parameter map: X + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert((command.CommandType ?? global::Dapper.DapperAotExtensions.GetCommandType(command.CommandText)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(command.Buffered is false); + global::System.Diagnostics.Debug.Assert(command.Parameters is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, CommandFactory0.Instance).ExecuteScalar(command.Parameters); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 24, 24)] + internal static global::System.Threading.Tasks.Task ExecuteScalarAsync1(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Execute, Async, TypedResult, HasParameters, StoredProcedure, Scalar, KnownParameters + // takes parameter: + // parameter map: X + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert((command.CommandType ?? global::Dapper.DapperAotExtensions.GetCommandType(command.CommandText)) == global::System.Data.CommandType.StoredProcedure); + global::System.Diagnostics.Debug.Assert(command.Buffered is false); + global::System.Diagnostics.Debug.Assert(command.Parameters is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, CommandFactory0.Instance).ExecuteScalarAsync(command.Parameters, cancellationToken: command.CancellationToken); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 32, 24)] + internal static global::System.Threading.Tasks.Task ExecuteScalarAsync2(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Execute, Async, TypedResult, HasParameters, Text, Scalar, KnownParameters + // takes parameter: + // parameter map: X + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert((command.CommandType ?? global::Dapper.DapperAotExtensions.GetCommandType(command.CommandText)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(command.Buffered is false); + global::System.Diagnostics.Debug.Assert(command.Parameters is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, CommandFactory0.Instance).ExecuteScalarAsync(command.Parameters, cancellationToken: command.CancellationToken); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 38, 24)] + internal static string? ExecuteScalar3(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Execute, TypedResult, HasParameters, Text, Scalar, KnownParameters + // takes parameter: + // parameter map: X + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert((command.CommandType ?? global::Dapper.DapperAotExtensions.GetCommandType(command.CommandText)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(command.Buffered is false); + global::System.Diagnostics.Debug.Assert(command.Parameters is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, CommandFactory0.Instance).ExecuteScalar(command.Parameters); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 48, 24)] + internal static global::System.Threading.Tasks.Task ExecuteScalarAsync4(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Execute, Async, TypedResult, Scalar + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert(command.Buffered is false); + global::System.Diagnostics.Debug.Assert(command.Parameters is null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, DefaultCommandFactory).ExecuteScalarAsync(command.Parameters, cancellationToken: command.CancellationToken); + + } + + [global::System.Runtime.CompilerServices.InterceptsLocationAttribute("Interceptors\\CommandDefinition.input.cs", 51, 34)] + internal static global::System.Collections.Generic.IEnumerable Query5(this global::System.Data.IDbConnection cnn, global::Dapper.CommandDefinition command) + { + // Query, TypedResult, HasParameters, Buffered, Text, KnownParameters + // takes parameter: + // parameter map: name + // returns data: string + global::System.Diagnostics.Debug.Assert(!string.IsNullOrWhiteSpace(command.CommandText)); + global::System.Diagnostics.Debug.Assert((command.CommandType ?? global::Dapper.DapperAotExtensions.GetCommandType(command.CommandText)) == global::System.Data.CommandType.Text); + global::System.Diagnostics.Debug.Assert(command.Buffered is true); + global::System.Diagnostics.Debug.Assert(command.Parameters is not null); + + return global::Dapper.DapperAotExtensions.Command(cnn, command.Transaction, command.CommandText, command.CommandType ?? default, command.CommandTimeout ?? default, CommandFactory1.Instance).QueryBuffered(command.Parameters, global::Dapper.RowFactory.Inbuilt.Value()); + + } + + private class CommonCommandFactory : global::Dapper.CommandFactory + { + public override global::System.Data.Common.DbCommand GetCommand(global::System.Data.Common.DbConnection connection, string sql, global::System.Data.CommandType commandType, T args) + { + var cmd = base.GetCommand(connection, sql, commandType, args); + // apply special per-provider command initialization logic for OracleCommand + if (cmd is global::Oracle.ManagedDataAccess.Client.OracleCommand cmd0) + { + cmd0.BindByName = true; + cmd0.InitialLONGFetchSize = -1; + + } + return cmd; + } + + } + + private static readonly CommonCommandFactory DefaultCommandFactory = new(); + + private sealed class CommandFactory0 : CommonCommandFactory // + { + internal static readonly CommandFactory0 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { X = default(int) }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "X"; + p.DbType = global::System.Data.DbType.Int32; + p.Direction = global::System.Data.ParameterDirection.Input; + p.Value = AsValue(typed.X); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { X = default(int) }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.X); + + } + public override bool CanPrepare => true; + + } + + private sealed class CommandFactory1 : CommonCommandFactory // + { + internal static readonly CommandFactory1 Instance = new(); + public override void AddParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { name = default(string)! }); // expected shape + var ps = cmd.Parameters; + global::System.Data.Common.DbParameter p; + p = cmd.CreateParameter(); + p.ParameterName = "name"; + p.DbType = global::System.Data.DbType.String; + p.Direction = global::System.Data.ParameterDirection.Input; + SetValueWithDefaultSize(p, typed.name); + ps.Add(p); + + } + public override void UpdateParameters(in global::Dapper.UnifiedCommand cmd, object? args) + { + var typed = Cast(args, static () => new { name = default(string)! }); // expected shape + var ps = cmd.Parameters; + ps[0].Value = AsValue(typed.name); + + } + public override bool CanPrepare => true; + + } + + + } +} +namespace System.Runtime.CompilerServices +{ + // this type is needed by the compiler to implement interceptors - it doesn't need to + // come from the runtime itself, though + + [global::System.Diagnostics.Conditional("DEBUG")] // not needed post-build, so: evaporate + [global::System.AttributeUsage(global::System.AttributeTargets.Method, AllowMultiple = true)] + sealed file class InterceptsLocationAttribute : global::System.Attribute + { + public InterceptsLocationAttribute(string path, int lineNumber, int columnNumber) + { + _ = path; + _ = lineNumber; + _ = columnNumber; + } + } +} \ No newline at end of file diff --git a/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.txt b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.txt new file mode 100644 index 0000000..27f3fca --- /dev/null +++ b/test/Dapper.AOT.Test/Interceptors/CommandDefinition.output.txt @@ -0,0 +1,22 @@ +Input code has 1 diagnostics from 'Interceptors/CommandDefinition.input.cs': + +Hidden CS8019 Interceptors/CommandDefinition.input.cs L4 C1 +Unnecessary using directive. +Generator produced 1 diagnostics: + +Hidden DAP000 L1 C1 +Dapper.AOT handled 6 of 6 possible call-sites using 6 interceptors, 2 commands and 0 readers +Output code has 1 diagnostics from 'Interceptors/CommandDefinition.input.cs': + +Hidden CS8019 Interceptors/CommandDefinition.input.cs L4 C1 +Unnecessary using directive. +Output code has 3 diagnostics from 'Dapper.AOT.Analyzers/Dapper.CodeAnalysis.DapperInterceptorGenerator/Test.generated.cs': + +Warning CS8619 Dapper.AOT.Analyzers/Dapper.CodeAnalysis.DapperInterceptorGenerator/Test.generated.cs L36 C20 +Nullability of reference types in value of type 'Task' doesn't match target type 'Task'. + +Warning CS8619 Dapper.AOT.Analyzers/Dapper.CodeAnalysis.DapperInterceptorGenerator/Test.generated.cs L52 C20 +Nullability of reference types in value of type 'Task' doesn't match target type 'Task'. + +Warning CS8619 Dapper.AOT.Analyzers/Dapper.CodeAnalysis.DapperInterceptorGenerator/Test.generated.cs L81 C20 +Nullability of reference types in value of type 'Task' doesn't match target type 'Task'.