diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index ff1ffd7e2b62..7acf4a5ed948 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -9997,6 +9997,30 @@ def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [
   }];
 }
 
+def Torch_AtenReflectionPad3dOp : Torch_Op<"aten.reflection_pad3d", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$padding
+  );
+  let results = (outs
+    AnyTorchOptionalTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenReflectionPad3dOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenReflectionPad3dOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+}
+
 def Torch_AtenPadOp : Torch_Op<"aten.pad", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index 1c2f7d6f2a11..73d78d3f89ab 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -7342,6 +7342,91 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
   return success();
 }
 
+Value reflectionPadAlongAxis(Value input, ArrayRef<int64_t> unpaddedShape,
+                             int64_t paddingAxisLeft, int64_t paddingAxisRight,
+                             int64_t axis, TensorType resultType, Location loc,
+                             ConversionPatternRewriter &rewriter) {
+
+  SmallVector<Value> resultTensors;
+  auto resultShape = resultType.getShape();
+
+  auto inputType = dyn_cast<TensorType>(input.getType());
+  auto inputRank = inputType.getRank();
+  auto inputElemTy = inputType.getElementType();
+
+  assert(inputRank == resultType.getRank());
+  int64_t axisOffset = inputRank - axis - 1;
+
+  // Use tosa.slice and tosa.reverse to get the reflection pads based on the
+  // padding size
+  if (paddingAxisLeft > 0) {
+    SmallVector<int64_t> leftStartSlice(inputRank, 0);
+    SmallVector<int64_t> leftSizeSlice(unpaddedShape.begin(),
+                                       unpaddedShape.end() - axisOffset);
+    for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) {
+      leftSizeSlice.push_back(resultShape[inputRank - iDim - 1]);
+    }
+
+    leftStartSlice[axis] = 1;
+    leftSizeSlice[axis] = paddingAxisLeft;
+
+    SmallVector<int64_t> leftPadShape(unpaddedShape.begin(),
+                                      unpaddedShape.end() - (axisOffset + 1));
+    leftPadShape.push_back(paddingAxisLeft);
+    for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) {
+      leftPadShape.push_back(resultShape[inputRank - iDim - 1]);
+    }
+
+    auto leftPadType = RankedTensorType::get(leftPadShape, inputElemTy);
+
+    auto leftPadSlice = rewriter.create<tosa::SliceOp>(
+        loc, leftPadType, input, rewriter.getDenseI64ArrayAttr(leftStartSlice),
+        rewriter.getDenseI64ArrayAttr(leftSizeSlice));
+
+    auto leftPad = rewriter.create<tosa::ReverseOp>(
+        loc, leftPadType, leftPadSlice.getResult(), static_cast<int32_t>(axis));
+
+    resultTensors.push_back(leftPad.getResult());
+  }
+
+  resultTensors.push_back(input);
+
+  if (paddingAxisRight > 0) {
+    SmallVector<int64_t> rightStartSlice(inputRank, 0);
+    SmallVector<int64_t> rightSizeSlice(unpaddedShape.begin(),
+                                        unpaddedShape.end() - axisOffset);
+    for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) {
+      rightSizeSlice.push_back(resultShape[inputRank - iDim - 1]);
+    }
+
+    rightStartSlice[axis] = unpaddedShape[axis] - paddingAxisRight - 1;
+    rightSizeSlice[axis] = paddingAxisRight;
+
+    SmallVector<int64_t> rightPadShape(unpaddedShape.begin(),
+                                       unpaddedShape.end() - (axisOffset + 1));
+    rightPadShape.push_back(paddingAxisRight);
+    for (int64_t iDim = axisOffset - 1; iDim >= 0; iDim--) {
+      rightPadShape.push_back(resultShape[inputRank - iDim - 1]);
+    }
+
+    auto rightPadType = RankedTensorType::get(rightPadShape, inputElemTy);
+
+    auto rightPadSlice = rewriter.create<tosa::SliceOp>(
+        loc, rightPadType, input,
+        rewriter.getDenseI64ArrayAttr(rightStartSlice),
+        rewriter.getDenseI64ArrayAttr(rightSizeSlice));
+
+    auto rightPad = rewriter.create<tosa::ReverseOp>(
+        loc, rightPadType, rightPadSlice.getResult(),
+        static_cast<int32_t>(axis));
+
+    resultTensors.push_back(rightPad.getResult());
+  }
+
+  return tosa::CreateOpAndInfer<tosa::ConcatOp>(rewriter, loc, resultType,
+                                                resultTensors, axis);
+}
+
 // Legalization for aten.reflection_pad1d
 template <>
 LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
@@ -7355,7 +7440,6 @@ LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
 
   auto selfShape = selfType.getShape();
   auto selfRank = selfType.getRank();
-  auto selfElemTy = selfType.getElementType();
 
   auto resultType =
       dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
@@ -7379,62 +7463,9 @@ LogicalResult ConvertAtenOp<AtenReflectionPad1dOp>::matchAndRewrite(
     return success();
   }
 
-  SmallVector<Value> resultTensors;
-
-  // Use tosa.slice and tosa.reverse to get the reflection pads based on the
-  // padding size
-  if (paddingLeft > 0) {
-    SmallVector<int64_t> leftStartSlice(selfRank, 0);
-    SmallVector<int64_t> leftSizeSlice(selfShape);
-
-    leftStartSlice[selfRank - 1] = 1;
-    leftSizeSlice[selfRank - 1] = paddingLeft;
-
-    SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
-    leftPadShape.push_back(paddingLeft);
-
-    auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
-
-    auto leftPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), leftPadType, self,
-        rewriter.getDenseI64ArrayAttr(leftStartSlice),
-        rewriter.getDenseI64ArrayAttr(leftSizeSlice));
-
-    auto leftPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), leftPadType, leftPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 1));
-
-    resultTensors.push_back(leftPad.getResult());
-  }
-
-  resultTensors.push_back(self);
-
-  if (paddingRight > 0) {
-    SmallVector<int64_t> rightStartSlice(selfRank, 0);
-    SmallVector<int64_t> rightSizeSlice(selfShape);
-
-    rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
-    rightSizeSlice[selfRank - 1] = paddingRight;
-
-    SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
-    rightPadShape.push_back(paddingRight);
-
-    auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
-
-    auto rightPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), rightPadType, self,
-        rewriter.getDenseI64ArrayAttr(rightStartSlice),
-        rewriter.getDenseI64ArrayAttr(rightSizeSlice));
-
-    auto rightPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), rightPadType, rightPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 1));
-
-    resultTensors.push_back(rightPad.getResult());
-  }
-
-  auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
-      rewriter, op->getLoc(), resultType, resultTensors, selfRank - 1);
+  auto result =
+      reflectionPadAlongAxis(self, selfShape, paddingLeft, paddingRight,
+                             selfRank - 1, resultType, op->getLoc(), rewriter);
 
   rewriter.replaceOp(op, result);
   return success();
@@ -7483,129 +7514,92 @@ LogicalResult ConvertAtenOp<AtenReflectionPad2dOp>::matchAndRewrite(
     return success();
   }
 
-  // Use tosa.slice and tosa.reverse to get the reflection pads based on the
-  // padding size
-  SmallVector<Value> sideTensors;
-
-  if (paddingLeft > 0) {
-    SmallVector<int64_t> leftStartSlice(selfRank, 0);
-    SmallVector<int64_t> leftSizeSlice(selfShape);
-
-    leftStartSlice[selfRank - 1] = 1;
-    leftSizeSlice[selfRank - 1] = paddingLeft;
-
-    SmallVector<int64_t> leftPadShape(selfShape.begin(), selfShape.end() - 1);
-    leftPadShape.push_back(paddingLeft);
-
-    auto leftPadType = RankedTensorType::get(leftPadShape, selfElemTy);
-
-    auto leftPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), leftPadType, self,
-        rewriter.getDenseI64ArrayAttr(leftStartSlice),
-        rewriter.getDenseI64ArrayAttr(leftSizeSlice));
-
-    auto leftPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), leftPadType, leftPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 1));
-
-    sideTensors.push_back(leftPad.getResult());
-  }
-
-  sideTensors.push_back(self);
-
-  if (paddingRight > 0) {
-    SmallVector<int64_t> rightStartSlice(selfRank, 0);
-    SmallVector<int64_t> rightSizeSlice(selfShape);
-
-    rightStartSlice[selfRank - 1] = selfShape[selfRank - 1] - paddingRight - 1;
-    rightSizeSlice[selfRank - 1] = paddingRight;
-
-    SmallVector<int64_t> rightPadShape(selfShape.begin(), selfShape.end() - 1);
-    rightPadShape.push_back(paddingRight);
-
-    auto rightPadType = RankedTensorType::get(rightPadShape, selfElemTy);
-
-    auto rightPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), rightPadType, self,
-        rewriter.getDenseI64ArrayAttr(rightStartSlice),
-        rewriter.getDenseI64ArrayAttr(rightSizeSlice));
-
-    auto rightPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), rightPadType, rightPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 1));
-
-    sideTensors.push_back(rightPad.getResult());
-  }
-
   SmallVector<int64_t> selfSidePaddedShape(selfShape.begin(),
                                            selfShape.end() - 1);
   selfSidePaddedShape.push_back(resultShape.back());
 
-  auto selfSidePadded = tosa::CreateOpAndInfer<tosa::ConcatOp>(
-      rewriter, op->getLoc(),
-      RankedTensorType::get(selfSidePaddedShape, selfElemTy), sideTensors,
-      selfRank - 1);
-
-  SmallVector<Value> resultTensors;
-
-  if (paddingTop > 0) {
-    SmallVector<int64_t> topStartSlice(selfRank, 0);
-    SmallVector<int64_t> topSizeSlice(selfShape.begin(), selfShape.end() - 1);
-    topSizeSlice.push_back(resultShape.back());
+  auto selfSidePadded = reflectionPadAlongAxis(
+      self, selfShape, paddingLeft, paddingRight, selfRank - 1,
+      RankedTensorType::get(selfSidePaddedShape, selfElemTy), op->getLoc(),
+      rewriter);
 
-    topStartSlice[selfRank - 2] = 1;
-    topSizeSlice[selfRank - 2] = paddingTop;
+  auto result = reflectionPadAlongAxis(selfSidePadded, selfShape, paddingTop,
+                                       paddingBottom, selfRank - 2, resultType,
+                                       op->getLoc(), rewriter);
 
-    SmallVector<int64_t> topPadShape(selfShape.begin(), selfShape.end() - 2);
-    topPadShape.push_back(paddingTop);
-    topPadShape.push_back(resultShape.back());
+  rewriter.replaceOp(op, result);
+  return success();
+}
 
-    auto topPadType = RankedTensorType::get(topPadShape, selfElemTy);
+// Legalization for aten.reflection_pad3d
+template <>
+LogicalResult ConvertAtenOp<AtenReflectionPad3dOp>::matchAndRewrite(
+    AtenReflectionPad3dOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  auto self = adaptor.getSelf();
 
-    auto topPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), topPadType, selfSidePadded,
-        rewriter.getDenseI64ArrayAttr(topStartSlice),
-        rewriter.getDenseI64ArrayAttr(topSizeSlice));
+  auto selfType = dyn_cast<TensorType>(self.getType());
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
 
-    auto topPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), topPadType, topPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 2));
+  auto selfShape = selfType.getShape();
+  auto selfRank = selfType.getRank();
+  auto selfElemTy = selfType.getElementType();
 
-    resultTensors.push_back(topPad.getResult());
-  }
+  auto resultType =
+      dyn_cast<TensorType>(typeConverter->convertType(op.getType()));
+  auto resultShape = resultType.getShape();
 
-  resultTensors.push_back(selfSidePadded.getResult());
+  SmallVector<int64_t> paddingList;
+  if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(paddingList)))
+    return rewriter.notifyMatchFailure(
+        op, "Non-const padding lists are not supported");
 
-  if (paddingBottom > 0) {
-    SmallVector<int64_t> bottomStartSlice(selfRank, 0);
-    SmallVector<int64_t> bottomSizeSlice(selfShape.begin(),
-                                         selfShape.end() - 1);
-    bottomSizeSlice.push_back(resultShape.back());
+  int64_t paddingLeft = paddingList[0];
+  int64_t paddingRight = paddingList[1];
+  int64_t paddingTop = paddingList[2];
+  int64_t paddingBottom = paddingList[3];
+  int64_t paddingFront = paddingList[4];
+  int64_t paddingBack = paddingList[5];
 
-    bottomStartSlice[selfRank - 2] =
-        selfShape[selfRank - 2] - paddingBottom - 1;
-    bottomSizeSlice[selfRank - 2] = paddingBottom;
+  if (paddingLeft >= selfShape[selfRank - 1] ||
+      paddingRight >= selfShape[selfRank - 1] ||
+      paddingTop >= selfShape[selfRank - 2] ||
+      paddingBottom >= selfShape[selfRank - 2] ||
+      paddingFront >= selfShape[selfRank - 3] ||
+      paddingBack >= selfShape[selfRank - 3])
+    return rewriter.notifyMatchFailure(
+        op, "Padding must be less than the corresponding input dimension");
 
-    SmallVector<int64_t> bottomPadShape(selfShape.begin(), selfShape.end() - 2);
-    bottomPadShape.push_back(paddingBottom);
-    bottomPadShape.push_back(resultShape.back());
+  // Identity case
+  if (paddingLeft == 0 && paddingRight == 0 && paddingTop == 0 &&
+      paddingBottom == 0 && paddingFront == 0 && paddingBack == 0) {
+    rewriter.replaceOp(op, self);
+    return success();
+  }
 
-    auto bottomPadType = RankedTensorType::get(bottomPadShape, selfElemTy);
+  SmallVector<int64_t> self1dPaddedShape(selfShape.begin(),
+                                         selfShape.end() - 1);
+  self1dPaddedShape.push_back(resultShape.back());
 
-    auto bottomPadSlice = rewriter.create<tosa::SliceOp>(
-        op->getLoc(), bottomPadType, selfSidePadded,
-        rewriter.getDenseI64ArrayAttr(bottomStartSlice),
-        rewriter.getDenseI64ArrayAttr(bottomSizeSlice));
+  auto self1dPadded = reflectionPadAlongAxis(
+      self, selfShape, paddingLeft, paddingRight, selfRank - 1,
+      RankedTensorType::get(self1dPaddedShape, selfElemTy), op->getLoc(),
+      rewriter);
 
-    auto bottomPad = rewriter.create<tosa::ReverseOp>(
-        op->getLoc(), bottomPadType, bottomPadSlice.getResult(),
-        static_cast<int32_t>(selfRank - 2));
+  SmallVector<int64_t> self2dPaddedShape(selfShape.begin(),
+                                         selfShape.end() - 2);
+  self2dPaddedShape.push_back(resultShape[resultShape.size() - 2]);
+  self2dPaddedShape.push_back(resultShape.back());
 
-    resultTensors.push_back(bottomPad.getResult());
-  }
+  auto self2dPadded = reflectionPadAlongAxis(
+      self1dPadded, selfShape, paddingTop, paddingBottom, selfRank - 2,
+      RankedTensorType::get(self2dPaddedShape, selfElemTy), op->getLoc(),
+      rewriter);
 
-  auto result = tosa::CreateOpAndInfer<tosa::ConcatOp>(
-      rewriter, op->getLoc(), resultType, resultTensors, selfRank - 2);
+  auto result =
+      reflectionPadAlongAxis(self2dPadded, selfShape, paddingFront, paddingBack,
+                             selfRank - 3, resultType, op->getLoc(), rewriter);
 
   rewriter.replaceOp(op, result);
   return success();
@@ -7798,11 +7792,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite(
     return rewriter.notifyMatchFailure(
         op, "Only constant int outer length value is supported");
 
-  // Technically, I should calculate the output shape based on the dim and outer
-  // length values. However, that would just give the same result as me taking
-  // the result shape straight from resultType and applying tosa::ReshapeOp to
-  // the input. Therefore, I'm opting for the latter approach here, which is
-  // more simple and quicker.
+ // Technically, I should calculate the output shape based on the dim and + // outer length values. However, that would just give the same result as me + // taking the result shape straight from resultType and applying + // tosa::ReshapeOp to the input. Therefore, I'm opting for the latter + // approach here, which is more simple and quicker. rewriter.replaceOpWithNewOp( op, resultType, self, rewriter.getDenseI64ArrayAttr(makeShapeTorchCompatible(resultShape))); @@ -8804,6 +8798,7 @@ std::set torch::populateTorchToTosaConversionPatternsAndIllegalOps( INSERT_ATENOP_PATTERN(PrimsCollapseOp); INSERT_ATENOP_PATTERN(AtenReflectionPad1dOp); INSERT_ATENOP_PATTERN(AtenReflectionPad2dOp); + INSERT_ATENOP_PATTERN(AtenReflectionPad3dOp); INSERT_ATENOP_PATTERN(AtenReplicationPad2dOp); INSERT_ATENOP_PATTERN(PrimsSplitDimOp); INSERT_ATENOP_PATTERN(AtenOuterOp); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 5fd05708961c..ae164e00ab2b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -10599,14 +10599,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %0 : !torch.tuple, list, list>\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.constant_pad_nd\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" func.func @__torch__.pad_shape_fn(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" " %true = torch.constant.bool true\n" -" %str = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" +" %str_0 = torch.constant.str \"AssertionError: Number of padded dimensions must be less than or equal to the input dimension\"\n" " %none = torch.constant.none\n" -" %str_0 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" +" %str_1 = torch.constant.str \"AssertionError: Must have paired low-high pad amount values\"\n" " %int2 = torch.constant.int 2\n" " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" @@ -10616,7 +10619,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %2 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" @@ -10626,18 +10629,47 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" " %7 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" " %8 = torch.aten.floordiv.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" " torch.prim.Loop %8, %true, init() 
{\n" -" ^bb0(%arg2: !torch.int):\n" -" %9 = torch.aten.add.int %arg2, %int1 : !torch.int, !torch.int -> !torch.int\n" +" ^bb0(%arg3: !torch.int):\n" +" torch.prim.If %arg2 -> () {\n" +" %20 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg1, %20 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.aten.neg.int %22 : !torch.int -> !torch.int\n" +" %24 = torch.aten.__getitem__.t %arg0, %23 : !torch.list, !torch.int -> !torch.int\n" +" %25 = torch.aten.lt.int %21, %24 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.bool) {\n" +" %27 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" +" %28 = torch.aten.add.int %27, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %29 = torch.aten.__getitem__.t %arg1, %28 : !torch.list, !torch.int -> !torch.int\n" +" %30 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" +" %31 = torch.aten.neg.int %30 : !torch.int -> !torch.int\n" +" %32 = torch.aten.__getitem__.t %arg0, %31 : !torch.list, !torch.int -> !torch.int\n" +" %33 = torch.aten.lt.int %29, %32 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %33 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %26 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.aten.add.int %arg3, %int1 : !torch.int, !torch.int -> !torch.int\n" " %10 = torch.aten.neg.int %9 : !torch.int -> !torch.int\n" -" %11 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %11 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %12 = torch.aten.__getitem__.t %arg1, %11 : !torch.list, !torch.int -> !torch.int\n" -" %13 = torch.aten.mul.int %int2, %arg2 : !torch.int, !torch.int -> !torch.int\n" +" %13 = torch.aten.mul.int %int2, %arg3 : !torch.int, !torch.int -> !torch.int\n" " %14 = torch.aten.add.int %13, %int1 : !torch.int, !torch.int -> !torch.int\n" " %15 = torch.aten.__getitem__.t %arg1, %14 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.add.int %12, %15 : !torch.int, !torch.int -> !torch.int\n" @@ -10649,6 +10681,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " return %arg0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" @@ -10670,7 +10703,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" @@ -10678,17 +10711,15 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " 
return %0#1 : !torch.int\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list {\n" -" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" %false = torch.constant.bool false\n" +" %0 = call @__torch__.pad_shape_fn(%arg0, %arg1, %false) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" -" %int-1 = torch.constant.int -1\n" +" %true = torch.constant.bool true\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" -" %int0 = torch.constant.int 0\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -10697,37 +10728,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %6 = torch.prim.If %5 -> (!torch.bool) {\n" -" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %8 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %6 -> () {\n" -" torch.prim.If.yield\n" -" } else {\n" -" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" -" torch.prim.If.yield\n" -" }\n" -" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %7 : !torch.list\n" +" %2 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %2 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" -" %false = torch.constant.bool false\n" +" %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" -" %int-1 = torch.constant.int -1\n" -" %int-2 = torch.constant.int -2\n" " %none = torch.constant.none\n" " %str_0 = torch.constant.str \"AssertionError: \"\n" " %int2 = torch.constant.int 2\n" -" %int1 = torch.constant.int 1\n" " %int4 = torch.constant.int 4\n" -" %int0 = torch.constant.int 0\n" -" %int3 = torch.constant.int 3\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" " %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" @@ -10736,48 +10746,42 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" -" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" -" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" -" %5 = torch.aten.eq.int %4, %int4 : 
!torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If %5 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" -" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" -" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" -" %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list, !torch.int -> !torch.int\n" -" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n" -" %11 = torch.prim.If %10 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %11 -> () {\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad3d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %true = torch.constant.bool true\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 6\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %int6 = torch.constant.int 6\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n" -" %13 = torch.prim.If %12 -> (!torch.bool) {\n" -" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n" -" torch.prim.If.yield %15 : !torch.bool\n" -" } else {\n" -" torch.prim.If.yield %false : !torch.bool\n" -" }\n" -" torch.prim.If %13 -> () {\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int6 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" -" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" -" return %14 : !torch.list\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1, %true) : (!torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %4 : !torch.list\n" " }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9c2a80187c93..91d6b5eb17fc 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -7902,17 +7902,25 @@ class DecomposeAtenPadOp : public 
OpRewritePattern<AtenPadOp> {
     if (mode == "reflect") {
       // only support for relectionpad 1d and 2d
-      if (numPadDims == 2) {
+      switch (numPadDims) {
+      case 1:
+        rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
+            op, op.getType(), op.getSelf(), usefulPads);
+        break;
+      case 2:
         rewriter.replaceOpWithNewOp<AtenReflectionPad2dOp>(
             op, op.getType(), op.getSelf(), usefulPads);
-        return success();
-      }
-      if (numPadDims == 1) {
-        rewriter.replaceOpWithNewOp<AtenReflectionPad1dOp>(
+        break;
+      case 3:
+        rewriter.replaceOpWithNewOp<AtenReflectionPad3dOp>(
             op, op.getType(), op.getSelf(), usefulPads);
-        return success();
+        break;
+      default:
+        return rewriter.notifyMatchFailure(
+            op, "unsupported number of dims for 'reflect' mode: " +
+                    std::to_string(numPadDims));
       }
-      return failure();
+      return success();
     }
 
     if (mode == "replicate") {
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 1dce55f06158..38eb1f573362 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -506,6 +506,13 @@
     "BernoulliTensorModule_basic",
     "UniformModule_basic",
     "UniformStaticShapeModule_basic",
+    "ReflectionPad3dModule_basic",
+    "ReflectionPad3dModuleTop_basic",
+    "ReflectionPad3dModuleBottom_basic",
+    "ReflectionPad3dModuleLeft_basic",
+    "ReflectionPad3dModuleRight_basic",
+    "ReflectionPad3dModuleFront_basic",
+    "ReflectionPad3dModuleBack_basic",
 }
 
 FX_IMPORTER_CRASHING_SET = LINALG_CRASHING_SET | {
@@ -801,6 +808,13 @@
     "ReflectionPad2dModule_Right",
     "ReflectionPad2dModule_Top",
     "ReflectionPad2dModule_basic",
+    "ReflectionPad3dModule_basic",
+    "ReflectionPad3dModuleTop_basic",
+    "ReflectionPad3dModuleBottom_basic",
+    "ReflectionPad3dModuleLeft_basic",
+    "ReflectionPad3dModuleRight_basic",
+    "ReflectionPad3dModuleFront_basic",
+    "ReflectionPad3dModuleBack_basic",
     "ReplicationPad2dModule_basic",
     "ReplicationPad2dModule_bottom0",
     "ReplicationPad2dModule_left0",
@@ -3114,6 +3128,9 @@
     "ReduceL1NormComplexModule_basic",
     "ReduceL2NormComplexModule_basic",
     "ReduceL3NormKeepDimComplexModule_basic",
+    "ReflectionPad3dModule_basic",
+    "ReflectionPad3dModuleFront_basic",
+    "ReflectionPad3dModuleBack_basic",
     "RreluWithNoiseBackwardEvalModule_basic",
     "RreluWithNoiseBackwardEvalStaticModule_basic",
     "RreluWithNoiseBackwardTrainModule_basic",
@@ -3449,6 +3466,7 @@
     "ElementwiseSignbitModule_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "AtenNonzero1DDynamicModule_basic",
     "MaxPool3dEmptyStrideStaticModule_basic",
     "MaxPool3dLargeDatadModule_basic",
     "MaxPool3dModuleRandomSimple_basic",
@@ -3907,6 +3925,7 @@
 ONNX_TOSA_XFAIL_SET = {
     "AtenFftRfft2DLastDim_basic",
     "AtenFftRfft2DMiddleDim_basic",
+    "AtenNonzero1DDynamicModule_basic",
     "PowFloatIntModule_basic",
     "PowIntFloatModule_basic",
     "PowIntIntModule_basic",
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index a73d188d7168..d0170b1bf9b0 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -2153,12 +2153,14 @@ def aten〇native_batch_norm〡shape(input: List[int], weight: Optional[List[int
 # TODO: This should be upstreamed.
 # See https://github.com/pytorch/pytorch/pull/76889 for an example.
-def pad_shape_fn(input: List[int], pad: List[int]): +def pad_shape_fn(input: List[int], pad: List[int], validate_pad : bool = False): assert len(pad) % 2 == 0, "Must have paired low-high pad amount values" assert len(pad) // 2 <= len(input), "Number of padded dimensions must be less than or equal to the input dimension" # The `pad` list takes the form of Low-high pairs starting at the # *rightmost* dimension of `self`. for i in range(len(pad) // 2): + if validate_pad: + assert pad[2*i] < input[-(i+1)] and pad[2 * i + 1] < input[-(i+1)] input[-(i + 1)] += pad[2 * i] + pad[2 * i + 1] return input @@ -2193,11 +2195,7 @@ def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - hdim = self[-1] - padding_left = padding[0] - padding_right = padding[1] - assert padding_left < hdim and padding_right < hdim - return pad_shape_fn(self, padding) + return pad_shape_fn(self, padding, validate_pad=True) # Padding size must be smaller than corresponding dimension @@ -2210,18 +2208,21 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: assert len(self) >= 2 - vdim = self[-2] - hdim = self[-1] - assert len(padding) == 4, 'padding size expected to be 4' - padding_left = padding[0] - padding_right = padding[1] - padding_top = padding[2] - padding_bottom = padding[3] - assert padding_left < hdim and padding_right < hdim - assert padding_top < vdim and padding_bottom < vdim + return pad_shape_fn(self, padding, validate_pad=True) - return pad_shape_fn(self, padding) +# Padding size must be smaller than corresponding dimension +@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,2,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1,1,1,1,3]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[2,1]), + Invocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,1,1,1,2]), + ErrorInvocation(TensorOfShape(2, 2, 2, 2), padding=[1,1,2,1,1,1])]) +def aten〇reflection_pad3d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 3 + assert len(padding) == 6, 'padding size expected to be 6' + return pad_shape_fn(self, padding, validate_pad=True) # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 930979b3c939..8a9c990de9a0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -789,6 +789,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad3d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) 
-> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index a97d7f09eda6..b9c58551d657 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -123,3 +123,164 @@ def forward(self, x): @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) def ReflectionPad2dModule_Right(module, tu: TestUtils): module.forward(tu.rand(2, 3, 20, 20)) + + +# ============================================================================== + + +class ReflectionPad3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 20, 20, 20, 20], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (10, 10, 10, 10, 10, 10)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModule()) +def ReflectionPad3dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, 20, 20, low=-1)) + + +# ============================================================================== + + +class ReflectionPad3dModuleTop(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([1, 3, 4, 5, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 2, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleTop()) +def ReflectionPad3dModuleTop_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 4, 5, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleBottom(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 10, 10, 6], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 5, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleBottom()) +def ReflectionPad3dModuleBottom_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 10, 10, 6)) + + +# ============================================================================== + + +class ReflectionPad3dModuleLeft(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 10], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (9, 0, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleLeft()) +def ReflectionPad3dModuleLeft_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 10)) + + +# ============================================================================== + + +class ReflectionPad3dModuleRight(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 3, 20, 20, 12], torch.float32, True), + ] + ) + def forward(self, x): + return torch.ops.aten.reflection_pad3d(x, (0, 11, 0, 0, 0, 0)) + + +@register_test_case(module_factory=lambda: ReflectionPad3dModuleRight()) +def ReflectionPad3dModuleRight_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20, 12)) + + +# ============================================================================== + + +class 
ReflectionPad3dModuleFront(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 20, 20, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 5, 0))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad3dModuleFront())
+def ReflectionPad3dModuleFront_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 20, 20, 12))
+
+
+# ==============================================================================
+
+
+class ReflectionPad3dModuleBack(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args(
+        [
+            None,
+            ([2, 3, 20, 20, 12], torch.float32, True),
+        ]
+    )
+    def forward(self, x):
+        return torch.ops.aten.reflection_pad3d(x, (0, 0, 0, 0, 0, 7))
+
+
+@register_test_case(module_factory=lambda: ReflectionPad3dModuleBack())
+def ReflectionPad3dModuleBack_basic(module, tu: TestUtils):
+    module.forward(tu.rand(2, 3, 20, 20, 12))
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index a3d52166385a..1899e09a835b 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2493,6 +2493,37 @@ func.func @torch.aten.reflection_pad2d$basic(%arg0: !torch.vtensor<[1,20,20],f32
   return %1 : !torch.vtensor<[1,40,40],f32>
 }
 
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.reflection_pad3d$basic(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> {
+// CHECK: %[[VAL_0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[4,5,7,3,4],f32> -> tensor<4x5x7x3x4xf32>
+// CHECK: %[[VAL_1:.*]] = torch.constant.int 2
+// CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]], %[[VAL_1]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[SLICE_L:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 4, 5, 7, 3, 2>, start = array<i64: 0, 0, 0, 0, 1>} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32>
+// CHECK: %[[REVERSE_L:.*]] = tosa.reverse %[[SLICE_L]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32>
+// CHECK: %[[SLICE_R:.*]] = tosa.slice %[[VAL_0]] {size = array<i64: 4, 5, 7, 3, 2>, start = array<i64: 0, 0, 0, 0, 1>} : (tensor<4x5x7x3x4xf32>) -> tensor<4x5x7x3x2xf32>
+// CHECK: %[[REVERSE_R:.*]] = tosa.reverse %[[SLICE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x2xf32>
+// CHECK: %[[CONCAT_LR:.*]] = tosa.concat %[[REVERSE_L]], %[[VAL_0]], %[[REVERSE_R]] {axis = 4 : i32} : (tensor<4x5x7x3x2xf32>, tensor<4x5x7x3x4xf32>, tensor<4x5x7x3x2xf32>) -> tensor<4x5x7x3x8xf32>
+// CHECK: %[[SLICE_T:.*]] = tosa.slice %[[CONCAT_LR]] {size = array<i64: 4, 5, 7, 2, 8>, start = array<i64: 0, 0, 0, 1, 0>} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32>
+// CHECK: %[[REVERSE_T:.*]] = tosa.reverse %[[SLICE_T]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32>
+// CHECK: %[[SLICE_B:.*]] = tosa.slice %[[CONCAT_LR]] {size = array<i64: 4, 5, 7, 2, 8>, start = array<i64: 0, 0, 0, 0, 0>} : (tensor<4x5x7x3x8xf32>) -> tensor<4x5x7x2x8xf32>
+// CHECK: %[[REVERSE_B:.*]] = tosa.reverse %[[SLICE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x2x8xf32>
+// CHECK: %[[CONCAT_TB:.*]] = tosa.concat %[[REVERSE_T]], %[[CONCAT_LR]], %[[REVERSE_B]] {axis = 3 : i32} : (tensor<4x5x7x2x8xf32>, tensor<4x5x7x3x8xf32>, tensor<4x5x7x2x8xf32>) -> tensor<4x5x7x7x8xf32>
+// CHECK: %[[SLICE_F:.*]] = tosa.slice %[[CONCAT_TB]] {size = array<i64: 4, 5, 2, 7, 8>, start = array<i64: 0, 0, 1, 0, 0>} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32>
+// CHECK: %[[REVERSE_F:.*]] = tosa.reverse %[[SLICE_F]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32>
+// CHECK: %[[SLICE_BACK:.*]] = tosa.slice %[[CONCAT_TB]] {size = array<i64: 4, 5, 2, 7, 8>, start = array<i64: 0, 0, 4, 0, 0>} : (tensor<4x5x7x7x8xf32>) -> tensor<4x5x2x7x8xf32>
+// CHECK: %[[REVERSE_BACK:.*]] = tosa.reverse %[[SLICE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>) -> tensor<4x5x2x7x8xf32>
+// CHECK: %[[CONCAT_FB:.*]] = tosa.concat %[[REVERSE_F]], %[[CONCAT_TB]], %[[REVERSE_BACK]] {axis = 2 : i32} : (tensor<4x5x2x7x8xf32>, tensor<4x5x7x7x8xf32>, tensor<4x5x2x7x8xf32>) -> tensor<4x5x11x7x8xf32>
+// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CONCAT_FB]] : tensor<4x5x11x7x8xf32> -> !torch.vtensor<[4,5,11,7,8],f32>
+// CHECK: return %[[RESULT]]
+func.func @torch.aten.reflection_pad3d$basic(%arg0: !torch.vtensor<[4,5,7,3,4],f32>) -> !torch.vtensor<[4,5,11,7,8],f32> {
+  %int2 = torch.constant.int 2
+  %0 = torch.prim.ListConstruct %int2, %int2, %int2, %int2, %int2, %int2 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.aten.reflection_pad3d %arg0, %0 : !torch.vtensor<[4,5,7,3,4],f32>, !torch.list<int> -> !torch.vtensor<[4,5,11,7,8],f32>
+  return %1 : !torch.vtensor<[4,5,11,7,8],f32>
+}
+
 // -----
 
 // CHECK-LABEL: func.func @torch.aten.replication_pad2d$basic(