Skip to content

Commit

Permalink
Modify test to avoid simplification
Browse files Browse the repository at this point in the history
  • Loading branch information
roji committed Dec 18, 2024
1 parent 6ba1bab commit b7c0e4a
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 15 deletions.
35 changes: 31 additions & 4 deletions src/EFCore.Cosmos/Query/Internal/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -685,10 +685,37 @@ public virtual SqlExpression Condition(SqlExpression test, SqlExpression ifTrue,
{
var typeMapping = ExpressionExtensions.InferTypeMapping(ifTrue, ifFalse);

return new SqlConditionalExpression(
ApplyTypeMapping(test, _boolTypeMapping),
ApplyTypeMapping(ifTrue, typeMapping),
ApplyTypeMapping(ifFalse, typeMapping));
test = ApplyTypeMapping(test, _boolTypeMapping);
ifTrue = ApplyTypeMapping(ifTrue, typeMapping);
ifFalse = ApplyTypeMapping(ifFalse, typeMapping);

// Simplify:
// a == b ? b : a -> a
// a != b ? a : b -> a
if (test is SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary)
{
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal ? (ifTrue, ifFalse) : (ifFalse, ifTrue);

// a == b ? b : a -> a
if (left.Equals(ifNotEqual) && right.Equals(ifEqual))
{
return left;
}

// b == a ? b : a -> a
if (right.Equals(ifNotEqual) && left.Equals(ifEqual))
{
return right;
}
}

return new SqlConditionalExpression(test, ifTrue, ifFalse);
}

/// <summary>
Expand Down
9 changes: 6 additions & 3 deletions src/EFCore.Relational/Query/SqlExpressionFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -835,13 +835,16 @@ public virtual SqlExpression Case(
&& typeMappedWhenClauses is
[
{
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
Test: SqlBinaryExpression
{
OperatorType: ExpressionType.Equal or ExpressionType.NotEqual,
Left: var left,
Right: var right
} binary,
Result: var result
}
])
{
var (left, right) = (binary.Left, binary.Right);

// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasoning below
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal
? (result, elseResult ?? Constant(null, result.Type, result.TypeMapping))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@ public override Task Conditional_simplifiable_equality(bool async)
{
await base.Conditional_simplifiable_equality(a);

// TODO: Simplify this away, as per #35327 for relational
AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] = 9) ? 9 : c["Int"]) > 1)
WHERE (c["Int"] > 1)
""");
});

Expand All @@ -34,12 +33,11 @@ public override Task Conditional_simplifiable_inequality(bool async)
{
await base.Conditional_simplifiable_inequality(a);

// TODO: Simplify this away, as per #35327 for relational
AssertSql(
"""
SELECT VALUE c
FROM root c
WHERE (((c["Int"] != 8) ? c["Int"] : 8) > 1)
WHERE (c["Int"] > 1)
""");
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -753,7 +753,7 @@ public virtual Task Where_equal_with_conditional(bool async)
ss => ss.Set<NullSemanticsEntity1>().Where(
e => (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)
: e.NullableStringC)
== e.NullableStringC).Select(e => e.Id));

[ConditionalTheory]
Expand All @@ -765,7 +765,7 @@ public virtual Task Where_not_equal_with_conditional(bool async)
e => e.NullableStringC
!= (e.NullableStringA == e.NullableStringB
? e.NullableStringA
: e.NullableStringB)).Select(e => e.Id));
: e.NullableStringC)).Select(e => e.Id));

[ConditionalTheory]
[MemberData(nameof(IsAsyncData))]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2205,7 +2205,13 @@ public override async Task Where_equal_with_conditional(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE [e].[NullableStringB] = [e].[NullableStringC] OR ([e].[NullableStringB] IS NULL AND [e].[NullableStringC] IS NULL)
WHERE CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringC]
END = [e].[NullableStringC] OR (CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringC]
END IS NULL AND [e].[NullableStringC] IS NULL)
""");
}

Expand All @@ -2217,7 +2223,16 @@ public override async Task Where_not_equal_with_conditional(bool async)
"""
SELECT [e].[Id]
FROM [Entities1] AS [e]
WHERE ([e].[NullableStringC] <> [e].[NullableStringB] OR [e].[NullableStringC] IS NULL OR [e].[NullableStringB] IS NULL) AND ([e].[NullableStringC] IS NOT NULL OR [e].[NullableStringB] IS NOT NULL)
WHERE ([e].[NullableStringC] <> CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringC]
END OR [e].[NullableStringC] IS NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringC]
END IS NULL) AND ([e].[NullableStringC] IS NOT NULL OR CASE
WHEN [e].[NullableStringA] = [e].[NullableStringB] OR ([e].[NullableStringA] IS NULL AND [e].[NullableStringB] IS NULL) THEN [e].[NullableStringA]
ELSE [e].[NullableStringC]
END IS NOT NULL)
""");
}

Expand Down

0 comments on commit b7c0e4a

Please sign in to comment.