Skip to content

Commit

Permalink
Change ExecuteUpdate to accept non-expression lambda
Browse files Browse the repository at this point in the history
Closes #32018
  • Loading branch information
roji committed Dec 5, 2024
1 parent 76d5bef commit f378a22
Show file tree
Hide file tree
Showing 38 changed files with 931 additions and 383 deletions.
8 changes: 4 additions & 4 deletions src/EFCore.Design/Query/Internal/CSharpToLinqTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1245,10 +1245,10 @@ private sealed class FakeFieldInfo(
public bool IsNonNullableReferenceType { get; } = isNonNullableReferenceType;

public override object[] GetCustomAttributes(bool inherit)
=> Array.Empty<object>();
=> [];

public override object[] GetCustomAttributes(Type attributeType, bool inherit)
=> Array.Empty<object>();
=> [];

public override bool IsDefined(Type attributeType, bool inherit)
=> false;
Expand Down Expand Up @@ -1289,10 +1289,10 @@ public override RuntimeFieldHandle FieldHandle
private sealed class FakeConstructorInfo(Type type, ParameterInfo[] parameters) : ConstructorInfo
{
public override object[] GetCustomAttributes(bool inherit)
=> Array.Empty<object>();
=> [];

public override object[] GetCustomAttributes(Type attributeType, bool inherit)
=> Array.Empty<object>();
=> [];

public override bool IsDefined(Type attributeType, bool inherit)
=> false;
Expand Down
127 changes: 110 additions & 17 deletions src/EFCore.Design/Query/Internal/PrecompiledQueryCodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -734,18 +734,52 @@ void ProcessCapturedVariables()

for (var i = 1; i < parameters.Length; i++)
{
var parameter = parameters[i];
var (parameterName, parameterType) = (parameters[i].Name!, parameters[i].ParameterType);

if (parameter.ParameterType == typeof(CancellationToken))
if (parameterType == typeof(CancellationToken))
{
continue;
}

if (_funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i) is not ExpressionTreeFuncletizer.PathNode
evaluatableRootPaths)
ExpressionTreeFuncletizer.PathNode? evaluatableRootPaths;

// ExecuteUpdate requires really special handling: the function accepts a Func<SetPropertyCalls...> argument, but
// we need to run funcletization on the setter lambdas added via that Func<>.
if (operatorMethodCall.Method is
{
Name: nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate)
or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync),
IsGenericMethod: true
}
&& operatorMethodCall.Method.DeclaringType == typeof(EntityFrameworkQueryableExtensions))
{
// There are no captured variables in this lambda argument - skip the argument
continue;
// First, statically convert the Func<SetPropertyCalls...> to a NewArrayExpression which represents all the
// setters; since that's an expression, we can run the funcletizer on it.
var settersExpression = ProcessExecuteUpdate(operatorMethodCall);
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(settersExpression);

if (evaluatableRootPaths is null)
{
// There are no captured variables in this lambda argument - skip the argument
continue;
}

// If there were captured variables, generate code to evaluate and build the same NewArrayExpression at runtime,
// and then fall through to the normal logic, generating variable extractors against that NewArrayExpression
// (local var) instead of against the method argument.
code.AppendLine(
$"var setters = {parameterName}(new SetPropertyCalls<{sourceElementTypeName}>()).BuildSettersExpression();");
parameterName = "setters";
parameterType = typeof(NewArrayExpression);
}
else
{
evaluatableRootPaths = _funcletizer.CalculatePathsToEvaluatableRoots(operatorMethodCall, i);
if (evaluatableRootPaths is null)
{
// There are no captured variables in this lambda argument - skip the argument
continue;
}
}

// We have a lambda argument with captured variables. Use the information returned by the funcletizer to generate code
Expand All @@ -756,11 +790,11 @@ void ProcessCapturedVariables()
declaredQueryContextVariable = true;
}

if (!parameter.ParameterType.IsSubclassOf(typeof(Expression)))
if (!parameterType.IsSubclassOf(typeof(Expression)))
{
// Special case: this is a non-lambda argument (Skip/Take/FromSql).
// Simply add the argument directly as a parameter
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameter.Name});""");
code.AppendLine($"""queryContext.AddParameter("{evaluatableRootPaths.ParameterName}", {parameterName});""");
continue;
}

Expand All @@ -769,7 +803,7 @@ void ProcessCapturedVariables()
// Lambda argument. Recurse through evaluatable path trees.
foreach (var child in evaluatableRootPaths.Children!)
{
GenerateCapturedVariableExtractors(parameter.Name!, parameter.ParameterType, child);
GenerateCapturedVariableExtractors(parameterName, parameterType, child);

void GenerateCapturedVariableExtractors(
string currentIdentifier,
Expand All @@ -786,12 +820,13 @@ void GenerateCapturedVariableExtractors(

var variableName = capturedVariablesPathTree.ExpressionType.Name;
variableName = char.ToLower(variableName[0]) + variableName[1..^"Expression".Length] + ++variableCounter;
code.AppendLine(
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");

if (capturedVariablesPathTree.Children?.Count > 0)
{
// This is an intermediate node which has captured variables in the children. Continue recursing down.
code.AppendLine(
$"var {variableName} = ({capturedVariablesPathTree.ExpressionType.Name}){roslynPathSegment};");

foreach (var child in capturedVariablesPathTree.Children)
{
GenerateCapturedVariableExtractors(variableName, capturedVariablesPathTree.ExpressionType, child);
Expand All @@ -816,7 +851,7 @@ void GenerateCapturedVariableExtractors(
{
code
.Append('"').Append(capturedVariablesPathTree.ParameterName!).AppendLine("\",")
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({variableName}, typeof(object)))")
.AppendLine($"Expression.Lambda<Func<object?>>(Expression.Convert({roslynPathSegment}, typeof(object)))")
.AppendLine(".Compile(preferInterpretation: true)")
.AppendLine(".Invoke());");
}
Expand Down Expand Up @@ -1073,15 +1108,23 @@ or nameof(EntityFrameworkQueryableExtensions.ToListAsync)
QueryableMethods.GetSumWithSelector(
method.GetParameters()[1].ParameterType.GenericTypeArguments[0].GenericTypeArguments[1])),

// ExecuteDelete/Update behave just like other scalar-returning operators
// ExecuteDelete behaves just like other scalar-returning operators
nameof(EntityFrameworkQueryableExtensions.ExecuteDeleteAsync) when method.DeclaringType
== typeof(EntityFrameworkQueryableExtensions)
=> RewriteToSync(
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteDelete))),
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync) when method.DeclaringType
== typeof(EntityFrameworkQueryableExtensions)
=> RewriteToSync(
typeof(EntityFrameworkQueryableExtensions).GetMethod(nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate))),

// ExecuteUpdate is special; it accepts a non-expression-tree argument (Func<SetPropertyCalls, SetPropertyCalls>),
// evaluates it immediately, and injects a different MethodCall node into the expression tree with the resulting setter
// expressions.
// When statically analyzing ExecuteUpdate, we have to manually perform the same thing.
nameof(EntityFrameworkQueryableExtensions.ExecuteUpdate) or nameof(EntityFrameworkQueryableExtensions.ExecuteUpdateAsync)
when method.DeclaringType == typeof(EntityFrameworkQueryableExtensions)
=> Expression.Call(
EntityFrameworkQueryableExtensions.ExecuteUpdateMethodInfo.MakeGenericMethod(
terminatingOperator.Arguments[0].Type.GetSequenceType()),
penultimateOperator,
ProcessExecuteUpdate(terminatingOperator)),

// In the regular case (sync terminating operator which needs to stay in the query tree), simply compose the terminating
// operator over the penultimate and return that.
Expand Down Expand Up @@ -1116,6 +1159,56 @@ MethodCallExpression RewriteToSync(MethodInfo? syncMethod)
}
}

// Accepts an expression tree representing a series of SetProperty() calls, parses them and passes them through the SetPropertyCalls
// builder; returns the resulting NewArrayExpression representing all the setters.
private static NewArrayExpression ProcessExecuteUpdate(MethodCallExpression executeUpdateCall)
{
var setPropertyCalls = Activator.CreateInstance<SetPropertyCalls>();
var settersLambda = (LambdaExpression)executeUpdateCall.Arguments[1];
var settersParameter = settersLambda.Parameters.Single();
var expression = settersLambda.Body;

while (expression != settersParameter)
{
if (expression is MethodCallExpression
{
Method:
{
IsGenericMethod: true,
Name: nameof(SetPropertyCalls<int>.SetProperty),
DeclaringType.IsGenericType: true,
},
Arguments:
[
UnaryExpression { NodeType: ExpressionType.Quote, Operand: LambdaExpression propertySelector },
Expression valueSelector
]
} methodCallExpression
&& methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>))
{
if (valueSelector is UnaryExpression
{
NodeType: ExpressionType.Quote,
Operand: LambdaExpression unwrappedValueSelector
})
{
setPropertyCalls.SetProperty(propertySelector, unwrappedValueSelector);
}
else
{
setPropertyCalls.SetProperty(propertySelector, valueSelector);
}

expression = methodCallExpression.Object;
continue;
}

throw new InvalidOperationException(RelationalStrings.InvalidArgumentToExecuteUpdate);
}

return setPropertyCalls.BuildSettersExpression();
}

/// <summary>
/// Contains information on a failure to precompile a specific query in the user's source code.
/// Includes information about the query, its location, and the exception that occured.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,25 +16,22 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor
typeof(RelationalSqlTranslatingExpressionVisitor).GetTypeInfo().GetDeclaredMethod(nameof(ParameterValueExtractor))!;

/// <inheritdoc />
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, LambdaExpression setPropertyCalls)
protected override UpdateExpression? TranslateExecuteUpdate(ShapedQueryExpression source, IReadOnlyList<ExecuteUpdateSetter> setters)
{
if (setters.Count == 0)
{
throw new UnreachableException("Empty setters list");
}

// Our source may have IncludeExpressions because of owned entities or auto-include; unwrap these, as they're meaningless for
// ExecuteUpdate's lambdas. Note that we don't currently support updates across tables.
source = source.UpdateShaperExpression(new IncludePruner().Visit(source.ShaperExpression));

var setters = new List<(LambdaExpression PropertySelector, Expression ValueExpression)>();
PopulateSetPropertyCalls(setPropertyCalls.Body, setters, setPropertyCalls.Parameters[0]);
if (TranslationErrorDetails != null)
{
return null;
}

if (setters.Count == 0)
{
AddTranslationErrorDetails(RelationalStrings.NoSetPropertyInvocation);
return null;
}

// Translate the setters: the left (property) selectors get translated to ColumnExpressions, the right (value) selectors to
// arbitrary SqlExpressions.
// Note that if the query isn't natively supported, we'll do a pushdown (see PushdownWithPkInnerJoinPredicate below); if that
Expand Down Expand Up @@ -67,42 +64,9 @@ public partial class RelationalQueryableMethodTranslatingExpressionVisitor

return PushdownWithPkInnerJoinPredicate();

void PopulateSetPropertyCalls(
Expression expression,
List<(LambdaExpression, Expression)> list,
ParameterExpression parameter)
{
switch (expression)
{
case ParameterExpression p
when parameter == p:
break;

case MethodCallExpression
{
Method:
{
IsGenericMethod: true,
Name: nameof(SetPropertyCalls<int>.SetProperty),
DeclaringType.IsGenericType: true
}
} methodCallExpression
when methodCallExpression.Method.DeclaringType.GetGenericTypeDefinition() == typeof(SetPropertyCalls<>):
list.Add(((LambdaExpression)methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));

PopulateSetPropertyCalls(methodCallExpression.Object!, list, parameter);

break;

default:
AddTranslationErrorDetails(RelationalStrings.InvalidArgumentToExecuteUpdate);
break;
}
}

bool TranslateSetters(
ShapedQueryExpression source,
List<(LambdaExpression PropertySelector, Expression ValueExpression)> setters,
IReadOnlyList<ExecuteUpdateSetter> setters,
[NotNullWhen(true)] out List<ColumnValueSetter>? translatedSetters,
[NotNullWhen(true)] out TableExpressionBase? targetTable)
{
Expand Down Expand Up @@ -464,7 +428,7 @@ SqlParameterExpression parameter
var inner = source;
var outerParameter = Expression.Parameter(entityType.ClrType);
var outerKeySelector = Expression.Lambda(outerParameter.CreateKeyValuesExpression(pk.Properties), outerParameter);
var firstPropertyLambdaExpression = setters[0].Item1;
var firstPropertyLambdaExpression = setters[0].PropertySelector;
var entitySource = GetEntitySource(RelationalDependencies.Model, firstPropertyLambdaExpression.Body);
var innerKeySelector = Expression.Lambda(
entitySource.CreateKeyValuesExpression(pk.Properties), firstPropertyLambdaExpression.Parameters);
Expand All @@ -481,6 +445,7 @@ SqlParameterExpression parameter

var propertyReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Outer");
var valueReplacement = AccessField(transparentIdentifierType, transparentIdentifierParameter, "Inner");
var rewrittenSetters = new ExecuteUpdateSetter[setters.Count];
for (var i = 0; i < setters.Count; i++)
{
var (propertyExpression, valueExpression) = setters[i];
Expand All @@ -499,14 +464,14 @@ SqlParameterExpression parameter
transparentIdentifierParameter)
: valueExpression;

setters[i] = (propertyExpression, valueExpression);
rewrittenSetters[i] = new(propertyExpression, valueExpression);
}

tableExpression = (TableExpression)outerSelectExpression.Tables[0];

// Re-translate the property selectors to get column expressions pointing to the new outer select expression (the original one
// has been pushed down into a subquery).
if (!TranslateSetters(outer, setters, out var translatedSetters, out _))
if (!TranslateSetters(outer, rewrittenSetters, out var translatedSetters, out _))
{
return null;
}
Expand Down
7 changes: 6 additions & 1 deletion src/EFCore.Relational/Query/SqlNullabilityProcessor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,12 @@ protected virtual SqlExpression VisitSqlParameter(
bool allowOptimizedExpansion,
out bool nullable)
{
var parameterValue = ParameterValues[sqlParameterExpression.Name];
if (!ParameterValues.TryGetValue(sqlParameterExpression.Name, out var parameterValue))
{
throw new UnreachableException(
$"Encountered SqlParameter with name '{sqlParameterExpression.Name}', but such a parameter does not exist.");
}

nullable = parameterValue == null;

if (nullable)
Expand Down
Loading

0 comments on commit f378a22

Please sign in to comment.