diff --git a/src/libraries/Microsoft.PowerFx.Core/Functions/UserDefinedFunction.cs b/src/libraries/Microsoft.PowerFx.Core/Functions/UserDefinedFunction.cs index 31d436703c..939db68a88 100644 --- a/src/libraries/Microsoft.PowerFx.Core/Functions/UserDefinedFunction.cs +++ b/src/libraries/Microsoft.PowerFx.Core/Functions/UserDefinedFunction.cs @@ -8,6 +8,7 @@ using Microsoft.CodeAnalysis; using Microsoft.PowerFx.Core.App; using Microsoft.PowerFx.Core.App.Controls; +using Microsoft.PowerFx.Core.App.ErrorContainers; using Microsoft.PowerFx.Core.Binding; using Microsoft.PowerFx.Core.Binding.BindInfo; using Microsoft.PowerFx.Core.Entities; @@ -57,6 +58,8 @@ public override bool IsServerDelegatable(CallNode callNode, TexlBinding binding) public override bool SupportsParamCoercion => true; + public override bool HasPreciseErrors => true; + private const int MaxParameterCount = 30; public TexlNode UdfBody { get; } @@ -77,6 +80,27 @@ public override bool TryGetDataSource(CallNode callNode, TexlBinding binding, ou public bool HasDelegationWarning => _binding?.ErrorContainer.GetErrors().Any(error => error.MessageKey.Contains("SuggestRemoteExecutionHint")) ?? false; + public override bool CheckTypes(CheckTypesContext context, TexlNode[] args, DType[] argTypes, IErrorContainer errors, out DType returnType, out Dictionary nodeToCoercedTypeMap) + { + if (!base.CheckTypes(context, args, argTypes, errors, out returnType, out nodeToCoercedTypeMap)) + { + return false; + } + + for (int i = 0; i < argTypes.Length; i++) + { + if ((argTypes[i].IsTableNonObjNull || argTypes[i].IsRecordNonObjNull) && + !ParamTypes[i].Accepts(argTypes[i], exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules, true) && + !argTypes[i].CoercesTo(ParamTypes[i], true, false, context.Features, true)) + { + errors.EnsureError(DocumentErrorSeverity.Severe, args[i], TexlStrings.ErrBadSchema_ExpectedType, ParamTypes[i].GetKindString()); + return false; + } + } + + return true; + } + /// /// Initializes a new instance of the class. /// @@ -167,15 +191,27 @@ public void CheckTypesOnDeclaration(CheckTypesContext context, DType actualBodyR Contracts.AssertValue(actualBodyReturnType); Contracts.AssertValue(binding); - if (!ReturnType.Accepts(actualBodyReturnType, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules)) + if (!ReturnType.Accepts( + actualBodyReturnType, + exact: true, + useLegacyDateTimeAccepts: false, + usePowerFxV1CompatibilityRules: context.Features.PowerFxV1CompatibilityRules, + restrictiveAggregateTypes: true)) { - if (actualBodyReturnType.CoercesTo(ReturnType, true, false, context.Features)) + if (actualBodyReturnType.CoercesTo(ReturnType, true, false, context.Features, restrictiveAggregateTypes: true)) { _binding.SetCoercedType(binding.Top, ReturnType); } else { var node = UdfBody is VariadicOpNode variadicOpNode ? variadicOpNode.Children.Last() : UdfBody; + + if ((ReturnType.IsTable && actualBodyReturnType.IsTable) || (ReturnType.IsRecord && actualBodyReturnType.IsRecord)) + { + binding.ErrorContainer.EnsureError(DocumentErrorSeverity.Severe, node, TexlStrings.ErrUDF_ReturnTypeSchemaDoesNotMatch, ReturnType.GetKindString()); + return; + } + binding.ErrorContainer.EnsureError(DocumentErrorSeverity.Severe, node, TexlStrings.ErrUDF_ReturnTypeDoesNotMatch, ReturnType.GetKindString(), actualBodyReturnType.GetKindString()); } } diff --git a/src/libraries/Microsoft.PowerFx.Core/Localization/Strings.cs b/src/libraries/Microsoft.PowerFx.Core/Localization/Strings.cs index 33610571dc..b3eecd25c9 100644 --- a/src/libraries/Microsoft.PowerFx.Core/Localization/Strings.cs +++ b/src/libraries/Microsoft.PowerFx.Core/Localization/Strings.cs @@ -783,6 +783,7 @@ internal static class TexlStrings public static ErrorResourceKey ErrUDF_DuplicateParameter = new ErrorResourceKey("ErrUDF_DuplicateParameter"); public static ErrorResourceKey ErrUDF_UnknownType = new ErrorResourceKey("ErrUDF_UnknownType"); public static ErrorResourceKey ErrUDF_ReturnTypeDoesNotMatch = new ErrorResourceKey("ErrUDF_ReturnTypeDoesNotMatch"); + public static ErrorResourceKey ErrUDF_ReturnTypeSchemaDoesNotMatch = new ErrorResourceKey("ErrUDF_ReturnTypeSchemaDoesNotMatch"); public static ErrorResourceKey ErrUDF_TooManyParameters = new ErrorResourceKey("ErrUDF_TooManyParameters"); public static ErrorResourceKey ErrUDF_MissingReturnType = new ErrorResourceKey("ErrUDF_MissingReturnType"); public static ErrorResourceKey ErrUDF_MissingParamType = new ErrorResourceKey("ErrUDF_MissingParamType"); diff --git a/src/libraries/Microsoft.PowerFx.Core/Types/DType.cs b/src/libraries/Microsoft.PowerFx.Core/Types/DType.cs index c33312a036..5b9cad8e3f 100644 --- a/src/libraries/Microsoft.PowerFx.Core/Types/DType.cs +++ b/src/libraries/Microsoft.PowerFx.Core/Types/DType.cs @@ -1851,12 +1851,13 @@ private bool AcceptsEntityType(DType type, bool usePowerFxV1CompatibilityRules) /// Legacy rules for accepting date/time types. /// Use PFx v1 compatibility rules if enabled (less /// permissive Accepts relationships). + /// Flag to restrict using aggregate types with more fields than expected. /// /// True if accepts , false otherwise. /// - public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules) + public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules, bool restrictiveAggregateTypes = false) { - return Accepts(type, out _, out _, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules); + return Accepts(type, out _, out _, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules, restrictiveAggregateTypes); } /// @@ -1885,10 +1886,11 @@ public bool Accepts(DType type, bool exact, bool useLegacyDateTimeAccepts, bool /// Legacy rules for accepting date/time types. /// Use PFx v1 compatibility rules if enabled (less /// permissive Accepts relationships). + /// Flag to restrict using aggregate types with more fields than expected. /// /// True if accepts , false otherwise. /// - public virtual bool Accepts(DType type, out KeyValuePair schemaDifference, out DType schemaDifferenceType, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules) + public virtual bool Accepts(DType type, out KeyValuePair schemaDifference, out DType schemaDifferenceType, bool exact, bool useLegacyDateTimeAccepts, bool usePowerFxV1CompatibilityRules, bool restrictiveAggregateTypes = false) { AssertValid(); type.AssertValid(); @@ -1941,7 +1943,7 @@ bool DefaultReturnValue(DType targetType) => if (Kind == type.Kind) { - return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules); + return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes); } accepts = type.Kind == DKind.Unknown || type.Kind == DKind.Deferred; @@ -1955,7 +1957,7 @@ bool DefaultReturnValue(DType targetType) => if (Kind == type.Kind || type.IsExpandEntity) { - return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules); + return TreeAccepts(this, TypeTree, type.TypeTree, out schemaDifference, out schemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes); } accepts = (IsMultiSelectOptionSet() && TypeTree.GetPairs().First().Value.OptionSetInfo == type.OptionSetInfo) || type.Kind == DKind.Unknown || type.Kind == DKind.Deferred; @@ -2175,7 +2177,7 @@ bool DefaultReturnValue(DType targetType) => } // Implements Accepts for Record and Table types. - private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree treeSrc, out KeyValuePair schemaDifference, out DType treeSrcSchemaDifferenceType, bool exact = true, bool useLegacyDateTimeAccepts = false, bool usePowerFxV1CompatibilityRules = false) + private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree treeSrc, out KeyValuePair schemaDifference, out DType treeSrcSchemaDifferenceType, bool exact = true, bool useLegacyDateTimeAccepts = false, bool usePowerFxV1CompatibilityRules = false, bool restrictiveAggregateTypes = false) { treeDst.AssertValid(); treeSrc.AssertValid(); @@ -2215,7 +2217,7 @@ private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree tre return false; } - if (!pairDst.Value.Accepts(type, out var recursiveSchemaDifference, out var recursiveSchemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules)) + if (!pairDst.Value.Accepts(type, out var recursiveSchemaDifference, out var recursiveSchemaDifferenceType, exact, useLegacyDateTimeAccepts, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes)) { if (!TryGetDisplayNameForColumn(parentType, pairDst.Key, out var colName)) { @@ -2237,6 +2239,17 @@ private static bool TreeAccepts(DType parentType, TypeTree treeDst, TypeTree tre } } + if (restrictiveAggregateTypes) + { + foreach (var pairSrc in treeSrc) + { + if (!treeDst.Contains(pairSrc.Key)) + { + return false; + } + } + } + return true; } @@ -3141,17 +3154,17 @@ public bool ContainsControlType(DPath path) (n.Type.IsAggregate && n.Type.ContainsControlType(DPath.Root))); } - public bool CoercesTo(DType typeDest, bool aggregateCoercion, bool isTopLevelCoercion, Features features) + public bool CoercesTo(DType typeDest, bool aggregateCoercion, bool isTopLevelCoercion, Features features, bool restrictiveAggregateTypes = false) { - return CoercesTo(typeDest, out _, aggregateCoercion, isTopLevelCoercion, features); + return CoercesTo(typeDest, out _, aggregateCoercion, isTopLevelCoercion, features, restrictiveAggregateTypes); } - public bool CoercesTo(DType typeDest, out bool isSafe, bool aggregateCoercion, bool isTopLevelCoercion, Features features) + public bool CoercesTo(DType typeDest, out bool isSafe, bool aggregateCoercion, bool isTopLevelCoercion, Features features, bool restrictiveAggregateTypes = false) { - return CoercesTo(typeDest, out isSafe, out _, out _, out _, aggregateCoercion, isTopLevelCoercion, features); + return CoercesTo(typeDest, out isSafe, out _, out _, out _, aggregateCoercion, isTopLevelCoercion, features, restrictiveAggregateTypes); } - public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coercionType, out KeyValuePair schemaDifference, out DType schemaDifferenceType, Features features, bool aggregateCoercion = true) + public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coercionType, out KeyValuePair schemaDifference, out DType schemaDifferenceType, Features features, bool aggregateCoercion = true, bool restrictiveAggregateTypes = false) { Contracts.Assert(IsAggregate); @@ -3196,7 +3209,8 @@ public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coerci out schemaDifferenceType, aggregateCoercion: true, isTopLevelCoercion: false, - features); + features, + restrictiveAggregateTypes); } if (Kind != typeDest.Kind) @@ -3231,7 +3245,8 @@ public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coerci out var fieldSchemaDifferenceType, aggregateCoercion, isTopLevelCoercion: false, - features); + features, + restrictiveAggregateTypes); // This is the attempted coercion type. If we fail, we need to know this for error handling coercionType = coercionType.Add(typedName.Name, fieldCoercionType); @@ -3259,6 +3274,17 @@ public bool AggregateCoercesTo(DType typeDest, out bool isSafe, out DType coerci isSafe &= fieldIsSafe; } + if (restrictiveAggregateTypes) + { + foreach (var typedName in GetNames(DPath.Root)) + { + if (!typeDest.TryGetType(typedName.Name, out _)) + { + return false; + } + } + } + return isValid; } @@ -3273,7 +3299,8 @@ public virtual bool CoercesTo( out DType schemaDifferenceType, bool aggregateCoercion, bool isTopLevelCoercion, - Features features) + Features features, + bool restrictiveAggregateTypes = false) { AssertValid(); Contracts.Assert(typeDest.IsValid); @@ -3290,7 +3317,7 @@ public virtual bool CoercesTo( return false; } - if (typeDest.Accepts(this, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules)) + if (typeDest.Accepts(this, exact: true, useLegacyDateTimeAccepts: false, usePowerFxV1CompatibilityRules: usePowerFxV1CompatibilityRules, restrictiveAggregateTypes)) { coercionType = typeDest; return true; @@ -3335,7 +3362,8 @@ public virtual bool CoercesTo( out schemaDifference, out schemaDifferenceType, features, - aggregateCoercion); + aggregateCoercion, + restrictiveAggregateTypes); } var subtypeCoerces = SubtypeCoercesTo( diff --git a/src/strings/PowerFxResources.en-US.resx b/src/strings/PowerFxResources.en-US.resx index 998cd665d6..e8f01ecbdc 100644 --- a/src/strings/PowerFxResources.en-US.resx +++ b/src/strings/PowerFxResources.en-US.resx @@ -4221,6 +4221,10 @@ The stated function return type '{0}' does not match the return type of the function body '{1}'. This error message shows up when expected return type does not match with actual return type. The arguments '{0}' and '{1}' will be replaced with data types. For example, "The stated function return type 'Number' does not match the return type of the function body 'Table'" + + The schema of stated function return type '{0}' does not match the schema of return type of the function body. + This error message shows up when expected return type schema does not match with schema of actual return type. The arguments '{0}' will be replaced with aggregate data types. For example, "The schema of stated function return type 'Table' does not match the schema of return type of the function body." + Function {0} has too many parameters. User-defined functions support up to {1} parameters. This error message shows up when a user tries to define a function with too many parameters. {0} - the name of the user-defined function, {1} - the max number of parameters allowed. diff --git a/src/tests/Microsoft.PowerFx.Core.Tests.Shared/UserDefinedTypeTests.cs b/src/tests/Microsoft.PowerFx.Core.Tests.Shared/UserDefinedTypeTests.cs index 89e951ca6e..880e9ac56f 100644 --- a/src/tests/Microsoft.PowerFx.Core.Tests.Shared/UserDefinedTypeTests.cs +++ b/src/tests/Microsoft.PowerFx.Core.Tests.Shared/UserDefinedTypeTests.cs @@ -156,6 +156,20 @@ public void TestRecordOfErrors(string typeDefinition, string expectedMessageKey) Assert.Contains(errors, e => e.MessageKey.Contains(expectedMessageKey)); } + [Theory] + [InlineData("f():T = {x: 5, y: 5}; T := Type({x: Number});", "ErrUDF_ReturnTypeSchemaDoesNotMatch")] + [InlineData("f(x:T):Number = x.n; T := Type({n: Number}); g(): Number = f({n: 5, m: 5});", "ErrBadSchema_ExpectedType")] + [InlineData("f():T = [{x: 5, y: 5}]; T := Type([{x: Number}]);", "ErrUDF_ReturnTypeSchemaDoesNotMatch")] + [InlineData("f(x:T):T = x; T := Type([{n: Number}]); g(): T = f([{n: 5, m: 5}]);", "ErrBadSchema_ExpectedType")] + public void TestAggregateTypeErrors(string typeDefinition, string expectedMessageKey) + { + var checkResult = new DefinitionsCheckResult() + .SetText(typeDefinition) + .SetBindingInfo(_primitiveTypes); + var errors = checkResult.ApplyErrors(); + Assert.Contains(errors, e => e.MessageKey.Contains(expectedMessageKey)); + } + [Theory] [InlineData("T := Type({ x: 5+5, y: -5 });", 2)] [InlineData("T := Type(Type(Number));", 1)] diff --git a/src/tests/Microsoft.PowerFx.Interpreter.Tests.Shared/RecalcEngineTests.cs b/src/tests/Microsoft.PowerFx.Interpreter.Tests.Shared/RecalcEngineTests.cs index 5efafa90d1..cdfdbf17a2 100644 --- a/src/tests/Microsoft.PowerFx.Interpreter.Tests.Shared/RecalcEngineTests.cs +++ b/src/tests/Microsoft.PowerFx.Interpreter.Tests.Shared/RecalcEngineTests.cs @@ -1865,24 +1865,12 @@ protected override bool TryGetField(FormulaType fieldType, string fieldName, out true, 42.0)] - // Functions accept record with more/less fields - [InlineData( - "People := Type([{Name: Text, Age: Number}]); countMinors(p: People): Number = CountRows(Filter(p, Age < 18));", - "countMinors([{Name: \"Bob\", Age: 21, Title: \"Engineer\"}, {Name: \"Alice\", Age: 25, Title: \"Manager\"}])", - true, - 0.0)] + // Functions accept record with less fields [InlineData( "Employee := Type({Name: Text, Age: Number, Title: Text}); getAge(e: Employee): Number = e.Age;", "getAge({Name: \"Bob\", Age: 21})", true, 21.0)] - [InlineData( - @"Employee := Type({Name: Text, Age: Number, Title: Text}); Employees := Type([Employee]); EmployeeNames := Type([{Name: Text}]); - getNames(e: Employees):EmployeeNames = ShowColumns(e, Name); - getNamesCount(e: EmployeeNames):Number = CountRows(getNames(e));", - "getNamesCount([{Name: \"Jim\", Age:25}, {Name: \"Tony\", Age:42}])", - true, - 2.0)] [InlineData( @"Employee := Type({Name: Text, Age: Number, Title: Text}); getAge(e: Employee): Number = e.Age; @@ -1949,7 +1937,58 @@ protected override bool TryGetField(FormulaType fieldType, string fieldName, out "f():TestEntity = Entity; g(e: TestEntity):Number = 1;", "g(f())", true, - 1.0)] + 1.0)] + + // Aggregate types with more than expected fields are not allowed in UDF args and return types + // Records + [InlineData( + "f():T = {x: 5, y: 5}; T := Type({x: Number});", + "", + false)] + [InlineData( + "f():T2 = {x: 5, y: 5}; T1 := Type([{x: Number}]); T2 := Type(RecordOf(T1));", + "", + false)] + [InlineData( + "g(x:T):Number = x.n; T := Type({n: Number});", + "g({x: 5, y: 5})", + false)] + + // Nested Records + [InlineData( + "f():T = {a: 5, b: {c: {d: 5, e:42}}}; T := Type({a: Number, b: {c: {d: Number}}});", + "", + false)] + [InlineData( + "g(x:T):Number = x.b.c.d; T := Type({a: Number, b: {c: {d: Number}}});", + "g({a: 5, b: {c: {d: 5, e:42}}})", + false)] + + // Tables + [InlineData( + "f():T = [{x: 5, y: 5}]; T := Type([{x: Number}]);", + "", + false)] + [InlineData( + "People := Type([{Name: Text, Age: Number}]); countMinors(p: People): Number = CountRows(Filter(p, Age < 18));", + "countMinors([{Name: \"Bob\", Age: 21, Title: \"Engineer\"}, {Name: \"Alice\", Age: 25, Title: \"Manager\"}])", + false)] + [InlineData( + @"Employee := Type({Name: Text, Age: Number, Title: Text}); Employees := Type([Employee]); EmployeeNames := Type([{Name: Text}]); + getNames(e: Employees):EmployeeNames = ShowColumns(e, Name); + getNamesCount(e: EmployeeNames):Number = CountRows(getNames(e));", + "getNamesCount([{Name: \"Jim\", Age:25}, {Name: \"Tony\", Age:42}])", + false)] + + // Nested Tables + [InlineData( + "f():T = {a: 5, b: [{c: {d: 5, e:42}}]}; T := Type([{a: Number, b: [{c: {d: Number}}]}]);", + "", + false)] + [InlineData( + "g(x:T):Number = First(First(x).b).c.d; T := Type([{a: Number, b: [{c: {d: Number}}]}]);", + "g({a: 5, b: [{c: {d: 5, e:42}}]})", + false)] public void UserDefinedTypeTest(string userDefinitions, string evalExpression, bool isValid, double expectedResult = 0) { var config = new PowerFxConfig(); @@ -1970,7 +2009,11 @@ public void UserDefinedTypeTest(string userDefinitions, string evalExpression, b } else { - Assert.Throws(() => recalcEngine.AddUserDefinitions(userDefinitions, CultureInfo.InvariantCulture)); + Assert.ThrowsAny(() => + { + recalcEngine.AddUserDefinitions(userDefinitions, CultureInfo.InvariantCulture); + recalcEngine.Eval(evalExpression, options: parserOptions); + }); } }