From 8cd4c5cc3ea2ad3f99df82be0ac7ada7a8138502 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 10 Dec 2024 14:26:10 -0500 Subject: [PATCH 1/2] Add support for reflectionpad 3d. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 390 ++++++++++++------ .../Transforms/AbstractInterpLibrary.cpp | 140 ++++--- .../Torch/Transforms/DecomposeComplexOps.cpp | 22 +- .../build_tools/abstract_interp_lib_gen.py | 55 ++- .../build_tools/torch_ods_gen.py | 1 + 6 files changed, 411 insertions(+), 221 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index ff1ffd7e2b62..7acf4a5ed948 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9997,6 +9997,30 @@ def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [ }]; } +def Torch_AtenReflectionPad3dOp : Torch_Op<"aten.reflection_pad3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchOptionalTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 1c2f7d6f2a11..8853f2f15414 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -7342,6 +7342,75 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } +Value reflectionPadLeftRight(Value input, ArrayRef unpaddedShape, + int64_t paddingLeft, int64_t paddingRight, + TensorType resultType, Location loc, + ConversionPatternRewriter &rewriter) { + + SmallVector resultTensors; + + auto inputType = dyn_cast(input.getType()); + auto inputRank = inputType.getRank(); + auto inputElemTy = inputType.getElementType(); + + // Use tosa.slice and tosa.reverse to get the reflection pads based on the + // padding size + if (paddingLeft > 0) { + SmallVector leftStartSlice(inputRank, 0); + SmallVector leftSizeSlice(unpaddedShape); + + leftStartSlice[inputRank - 1] = 1; + leftSizeSlice[inputRank - 1] = paddingLeft; + + SmallVector leftPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 1); + leftPadShape.push_back(paddingLeft); + + auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy); + + auto leftPadSlice = rewriter.create( + loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice), + rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + + auto leftPad = rewriter.create( + loc, leftPadType, leftPadSlice.getResult(), + static_cast(inputRank - 1)); + + resultTensors.push_back(leftPad.getResult()); + } + + resultTensors.push_back(input); + + if (paddingRight > 0) { + SmallVector rightStartSlice(inputRank, 0); + SmallVector rightSizeSlice(unpaddedShape); + + rightStartSlice[inputRank - 1] = + unpaddedShape[inputRank - 1] - paddingRight - 1; + rightSizeSlice[inputRank - 1] = paddingRight; + + SmallVector rightPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 1); + rightPadShape.push_back(paddingRight); + + auto rightPadType = RankedTensorType::get(rightPadShape, inputElemTy); + + auto rightPadSlice = rewriter.create( + loc, rightPadType, input, + rewriter.getDenseI64ArrayAttr(rightStartSlice), + rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + + auto rightPad = rewriter.create( + loc, rightPadType, rightPadSlice.getResult(), + static_cast(inputRank - 1)); + + resultTensors.push_back(rightPad.getResult()); + } + + return tosa::CreateOpAndInfer(rewriter, loc, resultType, + resultTensors, inputRank - 1); +} + // Legalization for aten.reflection_pad1d template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -7355,7 +7424,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto selfShape = selfType.getShape(); auto selfRank = selfType.getRank(); - auto selfElemTy = selfType.getElementType(); auto resultType = dyn_cast(typeConverter->convertType(op.getType())); @@ -7379,65 +7447,87 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } + auto result = + reflectionPadLeftRight(self, selfShape, paddingLeft, paddingRight, + resultType, op->getLoc(), rewriter); + + rewriter.replaceOp(op, result); + return success(); +} + +Value reflectionPadTopBottom(Value leftRightPaddedInput, + ArrayRef unpaddedShape, + int64_t paddingTop, int64_t paddingBottom, + TensorType resultType, Location loc, + ConversionPatternRewriter &rewriter) { SmallVector resultTensors; + auto resultShape = resultType.getShape(); - // Use tosa.slice and tosa.reverse to get the reflection pads based on the - // padding size - if (paddingLeft > 0) { - SmallVector leftStartSlice(selfRank, 0); - SmallVector leftSizeSlice(selfShape); + auto inputType = dyn_cast(leftRightPaddedInput.getType()); + auto inputRank = inputType.getRank(); + auto inputElemTy = inputType.getElementType(); - leftStartSlice[selfRank - 1] = 1; - leftSizeSlice[selfRank - 1] = paddingLeft; + if (paddingTop > 0) { + SmallVector topStartSlice(inputRank, 0); + SmallVector topSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - 1); + topSizeSlice.push_back(resultShape.back()); - SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); - leftPadShape.push_back(paddingLeft); + topStartSlice[inputRank - 2] = 1; + topSizeSlice[inputRank - 2] = paddingTop; - auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); + SmallVector topPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 2); + topPadShape.push_back(paddingTop); + topPadShape.push_back(resultShape.back()); - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); + auto topPadType = RankedTensorType::get(topPadShape, inputElemTy); - auto leftPad = rewriter.create( - op->getLoc(), leftPadType, leftPadSlice.getResult(), - static_cast(selfRank - 1)); + auto topPadSlice = rewriter.create( + loc, topPadType, leftRightPaddedInput, + rewriter.getDenseI64ArrayAttr(topStartSlice), + rewriter.getDenseI64ArrayAttr(topSizeSlice)); - resultTensors.push_back(leftPad.getResult()); + auto topPad = rewriter.create( + loc, topPadType, topPadSlice.getResult(), + static_cast(inputRank - 2)); + + resultTensors.push_back(topPad.getResult()); } - resultTensors.push_back(self); + resultTensors.push_back(leftRightPaddedInput); - if (paddingRight > 0) { - SmallVector rightStartSlice(selfRank, 0); - SmallVector rightSizeSlice(selfShape); + if (paddingBottom > 0) { + SmallVector bottomStartSlice(inputRank, 0); + SmallVector bottomSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - 1); + bottomSizeSlice.push_back(resultShape.back()); - rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; - rightSizeSlice[selfRank - 1] = paddingRight; + bottomStartSlice[inputRank - 2] = + unpaddedShape[inputRank - 2] - paddingBottom - 1; + bottomSizeSlice[inputRank - 2] = paddingBottom; - SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); - rightPadShape.push_back(paddingRight); + SmallVector bottomPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 2); + bottomPadShape.push_back(paddingBottom); + bottomPadShape.push_back(resultShape.back()); - auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); + auto bottomPadType = RankedTensorType::get(bottomPadShape, inputElemTy); - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); + auto bottomPadSlice = rewriter.create( + loc, bottomPadType, leftRightPaddedInput, + rewriter.getDenseI64ArrayAttr(bottomStartSlice), + rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); - auto rightPad = rewriter.create( - op->getLoc(), rightPadType, rightPadSlice.getResult(), - static_cast(selfRank - 1)); + auto bottomPad = rewriter.create( + loc, bottomPadType, bottomPadSlice.getResult(), + static_cast(inputRank - 2)); - resultTensors.push_back(rightPad.getResult()); + resultTensors.push_back(bottomPad.getResult()); } - auto result = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1); - - rewriter.replaceOp(op, result); - return success(); + return tosa::CreateOpAndInfer(rewriter, loc, resultType, + resultTensors, inputRank - 2); } // Legalization for aten.reflection_pad2d @@ -7483,129 +7573,167 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - // Use tosa.slice and tosa.reverse to get the reflection pads based on the - // padding size - SmallVector sideTensors; - - if (paddingLeft > 0) { - SmallVector leftStartSlice(selfRank, 0); - SmallVector leftSizeSlice(selfShape); - - leftStartSlice[selfRank - 1] = 1; - leftSizeSlice[selfRank - 1] = paddingLeft; - - SmallVector leftPadShape(selfShape.begin(), selfShape.end() - 1); - leftPadShape.push_back(paddingLeft); - - auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy); - - auto leftPadSlice = rewriter.create( - op->getLoc(), leftPadType, self, - rewriter.getDenseI64ArrayAttr(leftStartSlice), - rewriter.getDenseI64ArrayAttr(leftSizeSlice)); - - auto leftPad = rewriter.create( - op->getLoc(), leftPadType, leftPadSlice.getResult(), - static_cast(selfRank - 1)); - - sideTensors.push_back(leftPad.getResult()); - } - - sideTensors.push_back(self); - - if (paddingRight > 0) { - SmallVector rightStartSlice(selfRank, 0); - SmallVector rightSizeSlice(selfShape); - - rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1; - rightSizeSlice[selfRank - 1] = paddingRight; - - SmallVector rightPadShape(selfShape.begin(), selfShape.end() - 1); - rightPadShape.push_back(paddingRight); - - auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy); - - auto rightPadSlice = rewriter.create( - op->getLoc(), rightPadType, self, - rewriter.getDenseI64ArrayAttr(rightStartSlice), - rewriter.getDenseI64ArrayAttr(rightSizeSlice)); - - auto rightPad = rewriter.create( - op->getLoc(), rightPadType, rightPadSlice.getResult(), - static_cast(selfRank - 1)); - - sideTensors.push_back(rightPad.getResult()); - } - SmallVector selfSidePaddedShape(selfShape.begin(), selfShape.end() - 1); selfSidePaddedShape.push_back(resultShape.back()); - auto selfSidePadded = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), - RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors, - selfRank - 1); + auto selfSidePadded = reflectionPadLeftRight( + self, selfShape, paddingLeft, paddingRight, + RankedTensorType::get(selfSidePaddedShape, selfElemTy), op->getLoc(), + rewriter); + + auto result = + reflectionPadTopBottom(selfSidePadded, selfShape, paddingTop, + paddingBottom, resultType, op->getLoc(), rewriter); + + rewriter.replaceOp(op, result); + return success(); +} +Value reflectionPadFrontBack(Value lrtbPaddedInput, + ArrayRef unpaddedShape, + int64_t paddingFront, int64_t paddingBack, + TensorType resultType, Location loc, + ConversionPatternRewriter &rewriter) { SmallVector resultTensors; + auto resultShape = resultType.getShape(); - if (paddingTop > 0) { - SmallVector topStartSlice(selfRank, 0); - SmallVector topSizeSlice(selfShape.begin(), selfShape.end() - 1); + auto inputType = dyn_cast(lrtbPaddedInput.getType()); + auto inputRank = inputType.getRank(); + auto inputElemTy = inputType.getElementType(); + + if (paddingFront > 0) { + SmallVector topStartSlice(inputRank, 0); + SmallVector topSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - 1); topSizeSlice.push_back(resultShape.back()); - topStartSlice[selfRank - 2] = 1; - topSizeSlice[selfRank - 2] = paddingTop; + topStartSlice[inputRank - 2] = 1; + topSizeSlice[inputRank - 2] = paddingFront; - SmallVector topPadShape(selfShape.begin(), selfShape.end() - 2); - topPadShape.push_back(paddingTop); + SmallVector topPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 2); + topPadShape.push_back(paddingFront); topPadShape.push_back(resultShape.back()); - auto topPadType = RankedTensorType::get(topPadShape, selfElemTy); + auto topPadType = RankedTensorType::get(topPadShape, inputElemTy); auto topPadSlice = rewriter.create( - op->getLoc(), topPadType, selfSidePadded, + loc, topPadType, lrtbPaddedInput, rewriter.getDenseI64ArrayAttr(topStartSlice), rewriter.getDenseI64ArrayAttr(topSizeSlice)); auto topPad = rewriter.create( - op->getLoc(), topPadType, topPadSlice.getResult(), - static_cast(selfRank - 2)); + loc, topPadType, topPadSlice.getResult(), + static_cast(inputRank - 2)); resultTensors.push_back(topPad.getResult()); } - resultTensors.push_back(selfSidePadded.getResult()); + resultTensors.push_back(lrtbPaddedInput); - if (paddingBottom > 0) { - SmallVector bottomStartSlice(selfRank, 0); - SmallVector bottomSizeSlice(selfShape.begin(), - selfShape.end() - 1); + if (paddingBack > 0) { + SmallVector bottomStartSlice(inputRank, 0); + SmallVector bottomSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - 1); bottomSizeSlice.push_back(resultShape.back()); - bottomStartSlice[selfRank - 2] = - selfShape[selfRank - 2] - paddingBottom - 1; - bottomSizeSlice[selfRank - 2] = paddingBottom; + bottomStartSlice[inputRank - 2] = + unpaddedShape[inputRank - 2] - paddingBack - 1; + bottomSizeSlice[inputRank - 2] = paddingBack; - SmallVector bottomPadShape(selfShape.begin(), selfShape.end() - 2); - bottomPadShape.push_back(paddingBottom); + SmallVector bottomPadShape(unpaddedShape.begin(), + unpaddedShape.end() - 2); + bottomPadShape.push_back(paddingBack); bottomPadShape.push_back(resultShape.back()); - auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy); + auto bottomPadType = RankedTensorType::get(bottomPadShape, inputElemTy); auto bottomPadSlice = rewriter.create( - op->getLoc(), bottomPadType, selfSidePadded, + loc, bottomPadType, lrtbPaddedInput, rewriter.getDenseI64ArrayAttr(bottomStartSlice), rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); auto bottomPad = rewriter.create( - op->getLoc(), bottomPadType, bottomPadSlice.getResult(), - static_cast(selfRank - 2)); + loc, bottomPadType, bottomPadSlice.getResult(), + static_cast(inputRank - 2)); resultTensors.push_back(bottomPad.getResult()); } - auto result = tosa::CreateOpAndInfer( - rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2); + return tosa::CreateOpAndInfer(rewriter, loc, resultType, + resultTensors, inputRank - 2); +} + +// Legalization for aten.reflection_pad3d +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenReflectionPad3dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { + auto self = adaptor.getSelf(); + + auto selfType = dyn_cast(self.getType()); + if (!selfType) + return rewriter.notifyMatchFailure(op, "Only tensor types are supported"); + + auto selfShape = selfType.getShape(); + auto selfRank = selfType.getRank(); + auto selfElemTy = selfType.getElementType(); + + auto resultType = + dyn_cast(typeConverter->convertType(op.getType())); + auto resultShape = resultType.getShape(); + + SmallVector paddingList; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList))) + return rewriter.notifyMatchFailure( + op, "Non-const padding lists are not supported"); + + int64_t paddingLeft = paddingList[0]; + int64_t paddingRight = paddingList[1]; + int64_t paddingTop = paddingList[2]; + int64_t paddingBottom = paddingList[3]; + int64_t paddingFront = paddingList[4]; + int64_t paddingBack = paddingList[5]; + + if (paddingLeft >= selfShape[selfRank - 1] || + paddingRight >= selfShape[selfRank - 1] || + paddingTop >= selfShape[selfRank - 2] || + paddingBottom >= selfShape[selfRank - 2] || + paddingFront >= selfShape[selfRank - 3] || + paddingBack >= selfShape[selfRank - 3]) + return rewriter.notifyMatchFailure( + op, "Padding must be less than the corresponding input dimension"); + + // Identity case + if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 && + paddingBottom == 0 && paddingFront == 0 && paddingBack == 0) { + rewriter.replaceOp(op, self); + return success(); + } + + SmallVector self1dPaddedShape(selfShape.begin(), + selfShape.end() - 1); + self1dPaddedShape.push_back(resultShape.back()); + + auto self1dPadded = reflectionPadLeftRight( + self, selfShape, paddingLeft, paddingRight, + RankedTensorType::get(self1dPaddedShape, selfElemTy), op->getLoc(), + rewriter); + + SmallVector self2dPaddedShape(selfShape.begin(), + selfShape.end() - 2); + self2dPaddedShape.push_back(*(resultShape.end() - 1)); + self2dPaddedShape.push_back(resultShape.back()); + + auto self2dPadded = reflectionPadTopBottom( + self1dPadded, selfShape, paddingTop, paddingBottom, + RankedTensorType::get(self2dPaddedShape, selfElemTy), op->getLoc(), + rewriter); + + auto result = + reflectionPadFrontBack(self2dPadded, selfShape, paddingFront, paddingBack, + resultType, op->getLoc(), rewriter); rewriter.replaceOp(op, result); return success(); @@ -7798,11 +7926,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "Only constant int outer length value is supported"); - // Technically, I should calculate the output shape based on the dim and outer - // length values. However, that would just give the same result as me taking - // the result shape straight from resultType and applying tosa::ReshapeOp to - // the input. Therefore, I'm opting for the latter approach here, which is - // more simple and quicker. + // Technically, I should calculate the output shape based on the dim and + // outer length values. However, that would just give the same result as me + // taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter + // approach here, which is more simple and quicker. rewriter.replaceOpWithNewOp( op, resultType, self, rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 5fd05708961c..ae164e00ab2b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10599,14 +10599,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" " %true = torch.constant.bool true\n" -" %str = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" +" %str_0 = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" " %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" +" %str_1 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" " %int2 = torch.constant.int 2\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -10616,7 +10619,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" @@ -10626,18 +10629,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" " %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" " torch.prim.Loop %8, %true, init() {\n" -" ^bb0(%arg2: !torch.int):\n" -" %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" ^bb0(%arg3: !torch.int):\n" +" torch.prim.If %arg2 -> () {\n" +" %20 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg1, %20 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.neg.int %22 : !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %arg0, %23 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.lt.int %21, %24 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" %27 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.add.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %29 = torch.aten.__getitem__.t %arg1, %28 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.neg.int %30 : !torch.int -> !torch.int\n" +" %32 = torch.aten.__getitem__.t %arg0, %31 : !torch.list, !torch.int -> !torch.int\n" +" %33 = torch.aten.lt.int %29, %32 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %33 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" " %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int\n" -" %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n" @@ -10649,6 +10681,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" @@ -10670,7 +10703,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" @@ -10678,17 +10711,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" -" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -10697,37 +10728,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %7 : !torch.list\n" +" %2 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" -" %int-1 = torch.constant.int -1\n" -" %int-2 = torch.constant.int -2\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" " %int4 = torch.constant.int 4\n" -" %int0 = torch.constant.int 0\n" -" %int3 = torch.constant.int 3\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -10736,48 +10746,42 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %11 -> () {\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad3d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %int6 = torch.constant.int 6\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %13 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %14 : !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9c2a80187c93..91d6b5eb17fc 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -7902,17 +7902,25 @@ class DecomposeAtenPadOp : public OpRewritePattern { if (mode == "reflect") { // only support for relectionpad 1d and 2d - if (numPadDims == 2) { + switch (numPadDims) { + case 1: + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), usefulPads); + break; + case 2: rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), usefulPads); - return success(); - } - if (numPadDims == 1) { - rewriter.replaceOpWithNewOp( + break; + case 3: + rewriter.replaceOpWithNewOp( op, op.getType(), op.getSelf(), usefulPads); - return success(); + break; + default: + return rewriter.notifyMatchFailure( + op, "unsupported number of dims for 'reflect' mode: " + + std::to_string(numPadDims)); } - return failure(); + return success(); } if (mode == "replicate") { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a73d188d7168..497c35defd3e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2153,12 +2153,14 @@ def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int # TODO: This should be upstreamed. # See https://github.com/pytorch/pytorch/pull/76889 for an example. -def pad_shape_fn(input: List[int], pad: List[int]): +def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False): assert len(pad) % 2 == 0, "Must have paired low-high pad amount values" assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension" # The `pad` list takes the form of Low-high pairs starting at the # *rightmost* dimension of `self`. for i in range(len(pad) // 2): + if validate_pad: + assert pad[2*i] < input[-(i+1)] and pad[2 * i + 1] < input[-(i+1)] input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1] return input @@ -2193,11 +2195,11 @@ def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - hdim = self[-1] - padding_left = padding[0] - padding_right = padding[1] - assert padding_left < hdim and padding_right < hdim - return pad_shape_fn(self, padding) + # hdim = self[-1] + # padding_left = padding[0] + # padding_right = padding[1] + # assert padding_left < hdim and padding_right < hdim + return pad_shape_fn(self, padding, validate_pad=True) # Padding size must be smaller than corresponding dimension @@ -2210,18 +2212,41 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - vdim = self[-2] - hdim = self[-1] + # vdim = self[-2] + # hdim = self[-1] assert len(padding) == 4, 'padding size expected to be 4' - padding_left = padding[0] - padding_right = padding[1] - padding_top = padding[2] - padding_bottom = padding[3] - assert padding_left < hdim and padding_right < hdim - assert padding_top < vdim and padding_bottom < vdim + # padding_left = padding[0] + # padding_right = padding[1] + # padding_top = padding[2] + # padding_bottom = padding[3] + # assert padding_left < hdim and padding_right < hdim + # assert padding_top < vdim and padding_bottom < vdim - return pad_shape_fn(self, padding) + return pad_shape_fn(self, padding, validate_pad=True) + +# Padding size must be smaller than corresponding dimension +@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,2,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,3]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1]), + Invocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,2]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,2,1,1,1])]) +def aten〇reflection_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 3 + # vdim = self[-2] + # hdim = self[-1] + + assert len(padding) == 6, 'padding size expected to be 6' + # padding_left = padding[0] + # padding_right = padding[1] + # padding_top = padding[2] + # padding_bottom = padding[3] + # assert padding_left < hdim and padding_right < hdim + # assert padding_top < vdim and padding_bottom < vdim + + return pad_shape_fn(self, padding, validate_pad=True) # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 930979b3c939..8a9c990de9a0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -789,6 +789,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) From 74a52104213a337f1e93986e89019d8c6c986c5d Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 24 Dec 2024 12:20:43 -0500 Subject: [PATCH 2/2] Add tests and update xfail_sets.py --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 235 ++++-------------- projects/pt1/e2e_testing/xfail_sets.py | 19 ++ .../build_tools/abstract_interp_lib_gen.py | 24 -- .../torch_mlir_e2e_test/test_suite/padding.py | 161 ++++++++++++ test/Conversion/TorchToTosa/basic.mlir | 31 +++ 5 files changed, 262 insertions(+), 208 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 8853f2f15414..73d78d3f89ab 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -7342,29 +7342,40 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -Value reflectionPadLeftRight(Value input, ArrayRef unpaddedShape, - int64_t paddingLeft, int64_t paddingRight, - TensorType resultType, Location loc, +Value reflectionPadAlongAxis(Value input, ArrayRef unpaddedShape, + int64_t paddingAxisLeft, int64_t paddingAxisRight, + int64_t axis, TensorType resultType, Location loc, ConversionPatternRewriter &rewriter) { SmallVector resultTensors; + auto resultShape = resultType.getShape(); auto inputType = dyn_cast(input.getType()); auto inputRank = inputType.getRank(); auto inputElemTy = inputType.getElementType(); + assert(inputRank == resultType.getRank()); + int64_t axisOffset = inputRank - axis - 1; + // Use tosa.slice and tosa.reverse to get the reflection pads based on the // padding size - if (paddingLeft > 0) { + if (paddingAxisLeft > 0) { SmallVector leftStartSlice(inputRank, 0); - SmallVector leftSizeSlice(unpaddedShape); + SmallVector leftSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - axisOffset); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + leftSizeSlice.push_back(resultShape[inputRank - iDim - 1]); + } - leftStartSlice[inputRank - 1] = 1; - leftSizeSlice[inputRank - 1] = paddingLeft; + leftStartSlice[axis] = 1; + leftSizeSlice[axis] = paddingAxisLeft; SmallVector leftPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 1); - leftPadShape.push_back(paddingLeft); + unpaddedShape.end() - (axisOffset + 1)); + leftPadShape.push_back(paddingAxisLeft); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + leftPadShape.push_back(resultShape[inputRank - iDim - 1]); + } auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy); @@ -7373,25 +7384,30 @@ Value reflectionPadLeftRight(Value input, ArrayRef unpaddedShape, rewriter.getDenseI64ArrayAttr(leftSizeSlice)); auto leftPad = rewriter.create( - loc, leftPadType, leftPadSlice.getResult(), - static_cast(inputRank - 1)); + loc, leftPadType, leftPadSlice.getResult(), static_cast(axis)); resultTensors.push_back(leftPad.getResult()); } resultTensors.push_back(input); - if (paddingRight > 0) { + if (paddingAxisRight > 0) { SmallVector rightStartSlice(inputRank, 0); - SmallVector rightSizeSlice(unpaddedShape); + SmallVector rightSizeSlice(unpaddedShape.begin(), + unpaddedShape.end() - axisOffset); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + rightSizeSlice.push_back(resultShape[inputRank - iDim - 1]); + } - rightStartSlice[inputRank - 1] = - unpaddedShape[inputRank - 1] - paddingRight - 1; - rightSizeSlice[inputRank - 1] = paddingRight; + rightStartSlice[axis] = unpaddedShape[axis] - paddingAxisRight - 1; + rightSizeSlice[axis] = paddingAxisRight; SmallVector rightPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 1); - rightPadShape.push_back(paddingRight); + unpaddedShape.end() - (axisOffset + 1)); + rightPadShape.push_back(paddingAxisRight); + for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) { + rightPadShape.push_back(resultShape[inputRank - iDim - 1]); + } auto rightPadType = RankedTensorType::get(rightPadShape, inputElemTy); @@ -7402,13 +7418,13 @@ Value reflectionPadLeftRight(Value input, ArrayRef unpaddedShape, auto rightPad = rewriter.create( loc, rightPadType, rightPadSlice.getResult(), - static_cast(inputRank - 1)); + static_cast(axis)); resultTensors.push_back(rightPad.getResult()); } return tosa::CreateOpAndInfer(rewriter, loc, resultType, - resultTensors, inputRank - 1); + resultTensors, axis); } // Legalization for aten.reflection_pad1d @@ -7448,88 +7464,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } auto result = - reflectionPadLeftRight(self, selfShape, paddingLeft, paddingRight, - resultType, op->getLoc(), rewriter); + reflectionPadAlongAxis(self, selfShape, paddingLeft, paddingRight, + selfRank - 1, resultType, op->getLoc(), rewriter); rewriter.replaceOp(op, result); return success(); } -Value reflectionPadTopBottom(Value leftRightPaddedInput, - ArrayRef unpaddedShape, - int64_t paddingTop, int64_t paddingBottom, - TensorType resultType, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector resultTensors; - auto resultShape = resultType.getShape(); - - auto inputType = dyn_cast(leftRightPaddedInput.getType()); - auto inputRank = inputType.getRank(); - auto inputElemTy = inputType.getElementType(); - - if (paddingTop > 0) { - SmallVector topStartSlice(inputRank, 0); - SmallVector topSizeSlice(unpaddedShape.begin(), - unpaddedShape.end() - 1); - topSizeSlice.push_back(resultShape.back()); - - topStartSlice[inputRank - 2] = 1; - topSizeSlice[inputRank - 2] = paddingTop; - - SmallVector topPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 2); - topPadShape.push_back(paddingTop); - topPadShape.push_back(resultShape.back()); - - auto topPadType = RankedTensorType::get(topPadShape, inputElemTy); - - auto topPadSlice = rewriter.create( - loc, topPadType, leftRightPaddedInput, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); - - auto topPad = rewriter.create( - loc, topPadType, topPadSlice.getResult(), - static_cast(inputRank - 2)); - - resultTensors.push_back(topPad.getResult()); - } - - resultTensors.push_back(leftRightPaddedInput); - - if (paddingBottom > 0) { - SmallVector bottomStartSlice(inputRank, 0); - SmallVector bottomSizeSlice(unpaddedShape.begin(), - unpaddedShape.end() - 1); - bottomSizeSlice.push_back(resultShape.back()); - - bottomStartSlice[inputRank - 2] = - unpaddedShape[inputRank - 2] - paddingBottom - 1; - bottomSizeSlice[inputRank - 2] = paddingBottom; - - SmallVector bottomPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 2); - bottomPadShape.push_back(paddingBottom); - bottomPadShape.push_back(resultShape.back()); - - auto bottomPadType = RankedTensorType::get(bottomPadShape, inputElemTy); - - auto bottomPadSlice = rewriter.create( - loc, bottomPadType, leftRightPaddedInput, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); - - auto bottomPad = rewriter.create( - loc, bottomPadType, bottomPadSlice.getResult(), - static_cast(inputRank - 2)); - - resultTensors.push_back(bottomPad.getResult()); - } - - return tosa::CreateOpAndInfer(rewriter, loc, resultType, - resultTensors, inputRank - 2); -} - // Legalization for aten.reflection_pad2d template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -7577,94 +7518,19 @@ LogicalResult ConvertAtenOp::matchAndRewrite( selfShape.end() - 1); selfSidePaddedShape.push_back(resultShape.back()); - auto selfSidePadded = reflectionPadLeftRight( - self, selfShape, paddingLeft, paddingRight, + auto selfSidePadded = reflectionPadAlongAxis( + self, selfShape, paddingLeft, paddingRight, selfRank - 1, RankedTensorType::get(selfSidePaddedShape, selfElemTy), op->getLoc(), rewriter); - auto result = - reflectionPadTopBottom(selfSidePadded, selfShape, paddingTop, - paddingBottom, resultType, op->getLoc(), rewriter); + auto result = reflectionPadAlongAxis(selfSidePadded, selfShape, paddingTop, + paddingBottom, selfRank - 2, resultType, + op->getLoc(), rewriter); rewriter.replaceOp(op, result); return success(); } -Value reflectionPadFrontBack(Value lrtbPaddedInput, - ArrayRef unpaddedShape, - int64_t paddingFront, int64_t paddingBack, - TensorType resultType, Location loc, - ConversionPatternRewriter &rewriter) { - SmallVector resultTensors; - auto resultShape = resultType.getShape(); - - auto inputType = dyn_cast(lrtbPaddedInput.getType()); - auto inputRank = inputType.getRank(); - auto inputElemTy = inputType.getElementType(); - - if (paddingFront > 0) { - SmallVector topStartSlice(inputRank, 0); - SmallVector topSizeSlice(unpaddedShape.begin(), - unpaddedShape.end() - 1); - topSizeSlice.push_back(resultShape.back()); - - topStartSlice[inputRank - 2] = 1; - topSizeSlice[inputRank - 2] = paddingFront; - - SmallVector topPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 2); - topPadShape.push_back(paddingFront); - topPadShape.push_back(resultShape.back()); - - auto topPadType = RankedTensorType::get(topPadShape, inputElemTy); - - auto topPadSlice = rewriter.create( - loc, topPadType, lrtbPaddedInput, - rewriter.getDenseI64ArrayAttr(topStartSlice), - rewriter.getDenseI64ArrayAttr(topSizeSlice)); - - auto topPad = rewriter.create( - loc, topPadType, topPadSlice.getResult(), - static_cast(inputRank - 2)); - - resultTensors.push_back(topPad.getResult()); - } - - resultTensors.push_back(lrtbPaddedInput); - - if (paddingBack > 0) { - SmallVector bottomStartSlice(inputRank, 0); - SmallVector bottomSizeSlice(unpaddedShape.begin(), - unpaddedShape.end() - 1); - bottomSizeSlice.push_back(resultShape.back()); - - bottomStartSlice[inputRank - 2] = - unpaddedShape[inputRank - 2] - paddingBack - 1; - bottomSizeSlice[inputRank - 2] = paddingBack; - - SmallVector bottomPadShape(unpaddedShape.begin(), - unpaddedShape.end() - 2); - bottomPadShape.push_back(paddingBack); - bottomPadShape.push_back(resultShape.back()); - - auto bottomPadType = RankedTensorType::get(bottomPadShape, inputElemTy); - - auto bottomPadSlice = rewriter.create( - loc, bottomPadType, lrtbPaddedInput, - rewriter.getDenseI64ArrayAttr(bottomStartSlice), - rewriter.getDenseI64ArrayAttr(bottomSizeSlice)); - - auto bottomPad = rewriter.create( - loc, bottomPadType, bottomPadSlice.getResult(), - static_cast(inputRank - 2)); - - resultTensors.push_back(bottomPad.getResult()); - } - - return tosa::CreateOpAndInfer(rewriter, loc, resultType, - resultTensors, inputRank - 2); -} - // Legalization for aten.reflection_pad3d template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -7716,24 +7582,24 @@ LogicalResult ConvertAtenOp::matchAndRewrite( selfShape.end() - 1); self1dPaddedShape.push_back(resultShape.back()); - auto self1dPadded = reflectionPadLeftRight( - self, selfShape, paddingLeft, paddingRight, + auto self1dPadded = reflectionPadAlongAxis( + self, selfShape, paddingLeft, paddingRight, selfRank - 1, RankedTensorType::get(self1dPaddedShape, selfElemTy), op->getLoc(), rewriter); SmallVector self2dPaddedShape(selfShape.begin(), selfShape.end() - 2); - self2dPaddedShape.push_back(*(resultShape.end() - 1)); + self2dPaddedShape.push_back(resultShape[resultShape.size() - 2]); self2dPaddedShape.push_back(resultShape.back()); - auto self2dPadded = reflectionPadTopBottom( - self1dPadded, selfShape, paddingTop, paddingBottom, + auto self2dPadded = reflectionPadAlongAxis( + self1dPadded, selfShape, paddingTop, paddingBottom, selfRank - 2, RankedTensorType::get(self2dPaddedShape, selfElemTy), op->getLoc(), rewriter); auto result = - reflectionPadFrontBack(self2dPadded, selfShape, paddingFront, paddingBack, - resultType, op->getLoc(), rewriter); + reflectionPadAlongAxis(self2dPadded, selfShape, paddingFront, paddingBack, + selfRank - 3, resultType, op->getLoc(), rewriter); rewriter.replaceOp(op, result); return success(); @@ -8932,6 +8798,7 @@ std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( INSERT_ATENOP_PATTERN(PrimsCollapseOp); INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad3dOp); INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); INSERT_ATENOP_PATTERN(PrimsSplitDimOp); INSERT_ATENOP_PATTERN(AtenOuterOp); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 1dce55f06158..38eb1f573362 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -506,6 +506,13 @@ "BernoulliTensorModule_basic", "UniformModule_basic", "UniformStaticShapeModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleTop_basic", + "ReflectionPad3dModuleBottom_basic", + "ReflectionPad3dModuleLeft_basic", + "ReflectionPad3dModuleRight_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", } FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | { @@ -801,6 +808,13 @@ "ReflectionPad2dModule_Right", "ReflectionPad2dModule_Top", "ReflectionPad2dModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleTop_basic", + "ReflectionPad3dModuleBottom_basic", + "ReflectionPad3dModuleLeft_basic", + "ReflectionPad3dModuleRight_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", "ReplicationPad2dModule_basic", "ReplicationPad2dModule_bottom0", "ReplicationPad2dModule_left0", @@ -3114,6 +3128,9 @@ "ReduceL1NormComplexModule_basic", "ReduceL2NormComplexModule_basic", "ReduceL3NormKeepDimComplexModule_basic", + "ReflectionPad3dModule_basic", + "ReflectionPad3dModuleFront_basic", + "ReflectionPad3dModuleBack_basic", "RreluWithNoiseBackwardEvalModule_basic", "RreluWithNoiseBackwardEvalStaticModule_basic", "RreluWithNoiseBackwardTrainModule_basic", @@ -3449,6 +3466,7 @@ "ElementwiseSignbitModule_basic", "Aten_TrilinearModuleVaryingRanks_basic", "Aten_TrilinearModuleZerodDimBug_basic", + "AtenNonzero1DDynamicModule_basic", "MaxPool3dEmptyStrideStaticModule_basic", "MaxPool3dLargeDatadModule_basic", "MaxPool3dModuleRandomSimple_basic", @@ -3907,6 +3925,7 @@ ONNX_TOSA_XFAIL_SET = { "AtenFftRfft2DLastDim_basic", "AtenFftRfft2DMiddleDim_basic", + "AtenNonzero1DDynamicModule_basic", "PowFloatIntModule_basic", "PowIntFloatModule_basic", "PowIntIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 497c35defd3e..d0170b1bf9b0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -2195,10 +2195,6 @@ def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - # hdim = self[-1] - # padding_left = padding[0] - # padding_right = padding[1] - # assert padding_left < hdim and padding_right < hdim return pad_shape_fn(self, padding, validate_pad=True) @@ -2212,17 +2208,7 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - # vdim = self[-2] - # hdim = self[-1] - assert len(padding) == 4, 'padding size expected to be 4' - # padding_left = padding[0] - # padding_right = padding[1] - # padding_top = padding[2] - # padding_bottom = padding[3] - # assert padding_left < hdim and padding_right < hdim - # assert padding_top < vdim and padding_bottom < vdim - return pad_shape_fn(self, padding, validate_pad=True) # Padding size must be smaller than corresponding dimension @@ -2235,17 +2221,7 @@ def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,2,1,1,1])]) def aten〇reflection_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 3 - # vdim = self[-2] - # hdim = self[-1] - assert len(padding) == 6, 'padding size expected to be 6' - # padding_left = padding[0] - # padding_right = padding[1] - # padding_top = padding[2] - # padding_bottom = padding[3] - # assert padding_left < hdim and padding_right < hdim - # assert padding_top < vdim and padding_bottom < vdim - return pad_shape_fn(self, padding, validate_pad=True) # TODO: upstream this diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index a97d7f09eda6..b9c58551d657 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -123,3 +123,164 @@ def forward(self, x): @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) def ReflectionPad2dModule_Right(module, tu: TestUtils): module.forward(tu.rand(2, 3, 20, 20)) + + +# ============================================================================== + + +class ReflectionPad3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 20, 20, 20, 20], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (10, 10, 10, 10, 10, 10)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModule()) +def ReflectionPad3dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, 20, 20, low=-1)) + + +# ============================================================================== + + +class ReflectionPad3dModuleTop(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 3, 4, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 2, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleTop()) +def ReflectionPad3dModuleTop_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 4, 5, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleBottom(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 10, 10, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 5, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleBottom()) +def ReflectionPad3dModuleBottom_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 10, 10, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleLeft(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (9, 0, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleLeft()) +def ReflectionPad3dModuleLeft_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 10)) + + +# ============================================================================== + + +class ReflectionPad3dModuleRight(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 11, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleRight()) +def ReflectionPad3dModuleRight_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) + + +# ============================================================================== + + +class ReflectionPad3dModuleFront(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 5, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleFront()) +def ReflectionPad3dModuleFront_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) + + +# ============================================================================== + + +class ReflectionPad3dModuleBack(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 0, 7)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleBack()) +def ReflectionPad3dModuleBack_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index a3d52166385a..1899e09a835b 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -2493,6 +2493,37 @@ func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32 return %1 : !torch.vtensor<[1,40,40],f32> } + +// ----- +// CHECK-LABEL: func.func @torch.aten.reflection_pad3d$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { +// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32> +// CHECK: %[[VAL_1:.*]] = torch.constant.int 2 +// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list +// CHECK: %[[SLICE_L:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[REVERSE_L:.*]] = tosa.reverse %[[SLICE_L]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[SLICE_R:.*]] = tosa.slice %[[VAL_0]] {size = array, start = array} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[REVERSE_R:.*]] = tosa.reverse %[[SLICE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32> +// CHECK: %[[CONCAT_LR:.*]] = tosa.concat %[[REVERSE_L]], %[[VAL_0]], %[[REVERSE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32> +// CHECK: %[[SLICE_T:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[REVERSE_T:.*]] = tosa.reverse %[[SLICE_T]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[SLICE_B:.*]] = tosa.slice %[[CONCAT_LR]] {size = array, start = array} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[REVERSE_B:.*]] = tosa.reverse %[[SLICE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32> +// CHECK: %[[CONCAT_TB:.*]] = tosa.concat %[[REVERSE_T]], %[[CONCAT_LR]], %[[REVERSE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32> +// CHECK: %[[SLICE_F:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[REVERSE_F:.*]] = tosa.reverse %[[SLICE_F]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[SLICE_BACK:.*]] = tosa.slice %[[CONCAT_TB]] {size = array, start = array} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[REVERSE_BACK:.*]] = tosa.reverse %[[SLICE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32> +// CHECK: %[[CONCAT_FB:.*]] = tosa.concat %[[REVERSE_F]], %[[CONCAT_TB]], %[[REVERSE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONCAT_FB]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32> +// CHECK: return %[[RESULT]] +func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> { + %int2 = torch.constant.int 2 + %0 = torch.prim.ListConstruct %int2, %int2, %int2, %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + %1 = torch.aten.reflection_pad3d %arg0, %0 : !torch.vtensor<[4,5,7,3,4],f32>, !torch.list -> !torch.vtensor<[4,5,11,7,8],f32> + return %1 : !torch.vtensor<[4,5,11,7,8],f32> +} + // ----- // CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic(