diff --git a/Extensions/Xtensive.Orm.BulkOperations.Tests/ContainsTest.cs b/Extensions/Xtensive.Orm.BulkOperations.Tests/ContainsTest.cs index 846f628e59..af0817a24a 100644 --- a/Extensions/Xtensive.Orm.BulkOperations.Tests/ContainsTest.cs +++ b/Extensions/Xtensive.Orm.BulkOperations.Tests/ContainsTest.cs @@ -109,5 +109,36 @@ public void Test4() Assert.That(session.Query.All().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1)); } } + + [Test] + public void TestTvp() + { + using (var session = Domain.OpenSession()) + using (var tx = session.OpenTransaction()) { + var updatedRows = session.Query.All() + .Where(t => t.Id.In(IncludeAlgorithm.TableValuedParameter, tagIds)) + .Set(t => t.ProjectedValueAdjustment, 2) + .Update(); + Assert.That(updatedRows, Is.EqualTo(100)); + Assert.That(session.Query.All().Count(t => t.ProjectedValueAdjustment == 2 && t.Id <= 200), Is.EqualTo(100)); + Assert.That(session.Query.All().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1)); + } + } + + [Test] + public void TestManyIds() + { + using (var session = Domain.OpenSession()) + using (var tx = session.OpenTransaction()) { + var ids = tagIds.Concat(Enumerable.Range(4000, 5000).Select(o => (long)o)); + var updatedRows = session.Query.All() + .Where(t => t.Id.In(ids)) + .Set(t => t.ProjectedValueAdjustment, 2) + .Update(); + Assert.That(updatedRows, Is.EqualTo(100)); + Assert.That(session.Query.All().Count(t => t.ProjectedValueAdjustment == 2 && t.Id <= 200), Is.EqualTo(100)); + Assert.That(session.Query.All().Count(t => t.ProjectedValueAdjustment == -1 && t.Id > 700), Is.EqualTo(1)); + } + } } } diff --git a/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs b/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs index fe7795f5e3..716d285b4d 100644 --- a/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs +++ b/Extensions/Xtensive.Orm.BulkOperations/Internals/QueryOperation.cs @@ -9,6 +9,7 @@ using Xtensive.Core; using Xtensive.Orm.Linq; using Xtensive.Orm.Model; +using Xtensive.Orm.Providers; using Xtensive.Reflection; using Xtensive.Sql; using Xtensive.Sql.Dml; @@ -19,6 +20,7 @@ internal abstract class QueryOperation : Operation where T : class, IEntity { private static readonly ConstantExpression ComplexConditionConstant = Expression.Constant(IncludeAlgorithm.ComplexCondition); + private static readonly ConstantExpression AutoConditionConstant = Expression.Constant(IncludeAlgorithm.Auto); protected IQueryable query; @@ -54,13 +56,19 @@ protected override int ExecuteInternal() if (algorithm == IncludeAlgorithm.Auto) { var arguments = ex.Arguments.ToList(); - arguments[1] = ComplexConditionConstant; - ex = Expression.Call(methodInfo, arguments); + + if (!CanUseTvp(ex.Arguments[0].Type)) { + arguments[1] = ComplexConditionConstant; + ex = Expression.Call(methodInfo, arguments); + } } } else { var arguments = ex.Arguments.ToList(); - arguments.Insert(1, ComplexConditionConstant); + var conditionConstant = CanUseTvp(ex.Arguments[0].Type) + ? AutoConditionConstant + : ComplexConditionConstant; + arguments.Insert(1, AutoConditionConstant); ex = Expression.Call(WellKnownMembers.InMethod.MakeGenericMethod(methodInfo.GetGenericArguments()), arguments); } } @@ -73,6 +81,10 @@ protected override int ExecuteInternal() #region Non-public methods + private bool CanUseTvp(Type fieldType) => + (fieldType == typeof(long) || fieldType == typeof(int) || fieldType == typeof(string)) + && DomainHandler.Handlers.ProviderInfo.Supports(ProviderFeatures.TableValuedParameters); + protected abstract SqlTableRef GetStatementTable(SqlStatement statement); protected abstract SqlExpression GetStatementWhere(SqlStatement statement);