Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support convolution with valid padding. #3804

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -832,6 +832,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
op, "only support padding from a list construct");
paddingIntValues = getTypeConvertedValues(rewriter, loc, getTypeConverter(),
paddingIntValues);
if (paddingIntValues.size() == 1) {
for (size_t iDim = 1; iDim < numSpatialDims; iDim++) {
paddingIntValues.push_back(paddingIntValues[0]);
}
}

SmallVector<Value> outputPaddingIntValues;
if (!getListConstructElements(op.getOutputPadding(),
outputPaddingIntValues))
Expand Down
5 changes: 5 additions & 0 deletions lib/Conversion/TorchToStablehlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,11 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp<AtenConvolutionOp> {
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
}
if (padding.size() == 1) {
for (auto iDim = 1; iDim < inputTy.getRank() - 2; iDim++) {
padding.push_back(padding[0]);
}
}
SmallVector<int64_t> dilation;
if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilation))) {
return rewriter.notifyMatchFailure(op,
Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2163,6 +2163,9 @@ LogicalResult ConvertAtenOp<AtenConvolutionOp>::matchAndRewrite(
m_TorchListOfConstantInts(padding_2d)))
return rewriter.notifyMatchFailure(op,
"non-const padding list unsupported");
if (padding_2d.size() == 1) {
padding_2d.push_back(padding_2d[0]);
}
// TOSA uses 4D padding {t, b, l, r} while Torch defines 2D padding {t, l}.
// The Torch OFM computation uses 2*pad in each spatial direction, implying
// the same t=b and l=r values for TOSA.
Expand Down
3 changes: 3 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,9 @@
"DeformConv2D_basic",
"ReduceAnyDimFloatModule_basic",
"UnfoldModule_basic",
# TorchScript to the backend contract fails for conv.padding specified as str
"Conv2dWithValidPaddingModule_basic",
"Conv2dWithSamePaddingModule_basic",
}

if torch_version_for_comparison() < version.parse("2.5.0.dev"):
Expand Down
52 changes: 52 additions & 0 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,58 @@ def Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier(
module.forward(tu.rand(5, 4, 10, 20))


class Conv2dWithValidPaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(
1, 1, 1, stride=[1, 1], padding="valid", dilation=[1, 1], groups=1, bias=1
)
self.train(False)

@export
@annotate_args(
[
None,
([1, 5, 6], torch.float32, True),
]
)
def forward(self, x):
return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dWithValidPaddingModule())
def Conv2dWithValidPaddingModule_basic(module, tu: TestUtils):
t = tu.rand(1, 5, 6)
module.forward(t)


class Conv2dWithSamePaddingModule(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(0)
self.conv = torch.nn.Conv2d(
1, 1, 1, stride=[1, 1], padding="same", dilation=[1, 1], groups=1, bias=1
)
self.train(False)

@export
@annotate_args(
[
None,
([1, 5, 6], torch.float32, True),
]
)
def forward(self, x):
return self.conv(x)


@register_test_case(module_factory=lambda: Conv2dWithSamePaddingModule())
def Conv2dWithSamePaddingModule_basic(module, tu: TestUtils):
t = tu.rand(1, 5, 6)
module.forward(t)


# ==============================================================================


Expand Down