From de4cb0d41567cacb28be7e16012033bc10afeb79 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Fri, 2 Dec 2022 23:13:54 +0000 Subject: [PATCH 1/3] Initial Commit --- .../backend/rocm/conv2d/__init__.py | 2 + .../aitemplate/backend/rocm/conv2d/common.py | 6 +- .../backend/rocm/conv2d/conv2d_bias_silu.py | 192 +++ .../aitemplate/compiler/ops/conv/__init__.py | 1 + .../compiler/ops/conv/conv2d_bias_silu.py | 77 + .../compiler/transform/fuse_conv_patterns.py | 8 + .../aitemplate/frontend/nn/conv2d/__init__.py | 1 + .../frontend/nn/conv2d/conv2d_bias_silu.py | 45 + .../utils/mk_ck_lib/conv2d_operation.py | 2 + .../aitemplate/utils/mk_ck_lib/generator.py | 7 + python/aitemplate/utils/mk_ck_lib/library.py | 5 + .../compiler/test_fuse_conv_elementwise.py | 1249 +++++++++-------- 12 files changed, 991 insertions(+), 604 deletions(-) create mode 100644 python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py create mode 100644 python/aitemplate/compiler/ops/conv/conv2d_bias_silu.py create mode 100644 python/aitemplate/frontend/nn/conv2d/conv2d_bias_silu.py diff --git a/python/aitemplate/backend/rocm/conv2d/__init__.py b/python/aitemplate/backend/rocm/conv2d/__init__.py index b40cf91b0..786068c34 100644 --- a/python/aitemplate/backend/rocm/conv2d/__init__.py +++ b/python/aitemplate/backend/rocm/conv2d/__init__.py @@ -22,6 +22,7 @@ conv2d_bias_add_relu, conv2d_bias_relu, conv2d_bias_sigmoid, + conv2d_bias_silu, transposed_conv2d, transposed_conv2d_bias_relu, ) @@ -33,6 +34,7 @@ "conv2d_bias_add_relu", "conv2d_bias_relu", "conv2d_bias_sigmoid", + "conv2d_bias_silu", "transposed_conv2d", "transposed_conv2d_bias_relu", ] diff --git a/python/aitemplate/backend/rocm/conv2d/common.py b/python/aitemplate/backend/rocm/conv2d/common.py index 8e8dc3d7c..315fa2435 100644 --- a/python/aitemplate/backend/rocm/conv2d/common.py +++ b/python/aitemplate/backend/rocm/conv2d/common.py @@ -40,7 +40,7 @@ {% if conv2d_flag == "" %} {{indent}} {}, -{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid"] %} +{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid", "bias_silu"] %} {{indent}} std::array{static_cast(bias_ptr)}, {% elif conv2d_flag in ["bias_add_relu", "bias_add_identity"] %} {{indent}} std::array{static_cast(bias_ptr), static_cast(res_ptr)}, @@ -52,7 +52,7 @@ {{indent}} b_g_k_c_xs_strides, {% if conv2d_flag == "" %} {{indent}} {}, {}, -{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid"] %} +{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid", "bias_silu"] %} {{indent}} std::array, 1>{ {d_g_n_k_wos_lengths} }, {{indent}} std::array, 1>{ {d_g_n_k_wos_strides} }, {% elif conv2d_flag in ["bias_add_relu", "bias_add_identity"] %} @@ -75,6 +75,8 @@ {{indent}} ck::tensor_operation::element_wise::AddRelu{} {% elif conv2d_flag == "bias_sigmoid" %} {{indent}} ck::tensor_operation::element_wise::AddSigmoid{} +{% elif conv2d_flag == "bias_silu" %} +{{indent}} ck::tensor_operation::element_wise::AddSiLU{} {% elif conv2d_flag == "bias_add_identity" %} {{indent}} ck::tensor_operation::element_wise::AddAdd{} {% elif conv2d_flag == "bias_add_relu" %} diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py new file mode 100644 index 000000000..f17ccf41b --- /dev/null +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py @@ -0,0 +1,192 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +ROCM codegen functions for conv2d_bias_silu. +""" +import jinja2 + +from ... import registry +from . import common + +# pylint: disable=C0103,C0415,W0613 + +EXTRA_CODE = jinja2.Template( + """ +#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" + +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +namespace ck { +namespace tensor_operation { +namespace element_wise { +namespace { +struct AddSiLU +{ + template + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const{ + T a; + ck::tensor_operation::element_wise::Sigmoid{}(a, x0 + x1); + y = a * (x0 + x1); + }; +}; +} // namespace +} // namespace element_wise +} // namespace tensor_operation +} // namespace ck +""" +) + + +@registry.reg("rocm.conv2d_bias_silu.config") +def conv2d_config(func_attrs): + """Extracts (operation name, operation instance) pair from + all operation candidates. + + + Parameters + ---------- + func_attrs : Dict + Operation attributes. + + Returns + ------- + Dict + Extracted (operation name, operation instance) pair + from all operation candidates. + """ + import ck_lib + + op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu + extra_kind = ck_lib.library.TensorOperation.AddSiLU + func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind) + + +@registry.reg("rocm.conv2d_bias_silu.gen_profiler") +def conv2d_gen_profiler(func_attrs, workdir, shape_template): + """Generates standalone executables for profiler. + + Parameters + ---------- + func_attrs : Dict + Operation attributes. + workdir : str + Directory to store the generated outputs. + shape_template : jinja2.Template + Generates shape calculation. + The template is passed from compiler/ops/pool. + """ + return common.gen_profiler( + func_attrs=func_attrs, + workdir=workdir, + shape_template=shape_template, + conv2d_flag="bias_silu", + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("rocm.conv2d_bias_silu.gen_function") +def conv2d_gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, +): + """Generates function body. + + Parameters + ---------- + func_attrs : Dict + Operation attributes. + exec_cond_remplate : jinja2.Template + Generates if statement to execute kernel. + shape_eval_template : jinja2.Template + Generates shape calculation. + The template is passed from compiler/ops/pool. + shape_save_template : jinja2.Template + Generates output dimensions. + The template is passed from compiler/ops/pool. + + Returns + ------- + str + The rendered template of generated function body. + """ + return common.gen_function( + func_attrs, + exec_cond_remplate, + shape_eval_template, + shape_save_template, + "bias_silu", + extra_code=EXTRA_CODE.render(), + ) + + +@registry.reg("rocm.conv2d_bias_silu.func_decl") +def conv2d_gen_function_decl(func_attrs): + """Generates function declarations. 
+
+    Parameters
+    ----------
+    func_attrs : Dict
+        Operation attributes.
+
+    Returns
+    -------
+    str
+        The rendered template of function declaration.
+    """
+    func_name = func_attrs["name"]
+    return common.gen_function_decl(func_name=func_name, conv2d_flag="bias_silu")
+
+
+@registry.reg("rocm.conv2d_bias_silu.func_call")
+def conv2d_gen_function_call(func_attrs, indent="  "):
+    """Generates function call.
+
+    Parameters
+    ----------
+    func_attrs : Dict
+        Stores the operation attributes.
+    indent : str, optional
+        Indent for codegen, target dependent e.g. C++, python, etc., by default "  ".
+
+    Returns
+    -------
+    str
+        The rendered template of generated function call.
+    """
+    return common.gen_function_call(func_attrs, indent, conv2d_flag="bias_silu")
+
+
+@registry.reg("rocm.conv2d_bias_silu.filter")
+def conv2d_function_filter(cfg, func_attrs, x_shape):
+    """Generates function filter.
+
+    Parameters
+    ----------
+    cfg: str
+        The filename generated for profiler.
+    func_attrs : Dict
+        Stores the operation attributes.
+    x_shape:
+        Input shapes.
+
+    Returns
+    -------
+    bool
+        If input cfg should be filtered.
+    """
+    # return common.function_filter(cfg, func_attrs, 3)
+    return True
diff --git a/python/aitemplate/compiler/ops/conv/__init__.py b/python/aitemplate/compiler/ops/conv/__init__.py
index 1233111c6..4f785e790 100644
--- a/python/aitemplate/compiler/ops/conv/__init__.py
+++ b/python/aitemplate/compiler/ops/conv/__init__.py
@@ -27,6 +27,7 @@
 from .conv2d_bias_relu import conv2d_bias_relu
 from .conv2d_bias_relu_few_channels import conv2d_bias_relu_few_channels
 from .conv2d_bias_sigmoid import conv2d_bias_sigmoid
+from .conv2d_bias_silu import conv2d_bias_silu
 from .conv2d_depthwise import conv2d_depthwise
 from .conv2d_depthwise_bias import conv2d_depthwise_bias
 from .conv3d import conv3d
diff --git a/python/aitemplate/compiler/ops/conv/conv2d_bias_silu.py b/python/aitemplate/compiler/ops/conv/conv2d_bias_silu.py
new file mode 100644
index 000000000..30bd05d94
--- /dev/null
+++ b/python/aitemplate/compiler/ops/conv/conv2d_bias_silu.py
@@ -0,0 +1,77 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+Fused conv2d_bias_silu op.
+"""
+from .common_conv2d_bias_activation import conv2d_bias_activation
+
+
+# pylint: disable=C0103
+class conv2d_bias_silu(conv2d_bias_activation):
+    r"""Conv2d with bias + silu.
+
+    Applies a 2D convolution on input in shape (N, H, W, C_in), adds a bias in shape (C_out), applies SiLU and produces output in shape (N, H_out, W_out, C_out). N is batch size, H, W are the height and width of the input images in pixels, and C is the number of channels.
+
+    Args:
+        input: input tensor of shape :math:`(N , H , W, \text{in\_channels})`
+
+        weight: filters of shape :math:`(\text{out\_channels} , K_h, K_w, \frac{\text{in\_channels}}{\text{groups}})`
+
+        bias: optional bias tensor of shape :math:`(\text{out\_channels})`
+
+    This operator uses "channels_last" data format. Below is an example and its equivalence in PyTorch:
+
+    .. highlight:: python
+    .. code-block:: python
+
+        X = Tensor(shape=[N, H, W, C_in], dtype="float16", name="images", is_input=True)
+        W = Tensor(shape=[C_out, K_h, K_w, C_in], dtype="float16", name="weight", is_input=True)
+        B = Tensor(shape=[C_out], dtype="float16", name="bias", is_input=True)
+        OP = aitemplate.compiler.ops.conv2d_bias_silu(stride=1, pad=1, dilate=1)
+        Result_ait = OP(X, W, B)
+
+    .. highlight:: python
+    .. code-block:: python
+
+        X_pt = NHWC2NCHW(X_ait)
+        W_pt = NHWC2NCHW(W_ait)
+        B_pt = NHWC2NCHW(B_ait)
+
+        Y = torch.nn.functional.conv2d(X_pt, W_pt, bias=B_pt)
+        Result_pt = torch.nn.functional.silu(Y)
+        Result_ait = NCHW2NHWC(Result_pt)
+    """
+
+    def __init__(self, stride, pad, dilate=1, group=1) -> None:
+        """Conv2d_bias_silu constructor.
+
+        Parameters
+        ----------
+        stride : int
+            Stride of the convolution
+        pad : int
+            Size of padding to add to the input
+        dilate : int, optional
+            Size of spacing between kernel elements, by default 1
+        group : int, optional
+            Number of input channels to process to compute one output channel, by default 1
+        """
+        super().__init__("silu", stride, pad, dilate=dilate, group=group)
+
+    def _get_op_attributes(self):
+        attr = super()._get_op_attributes()
+        del attr["activation"]
+
+        return attr
diff --git a/python/aitemplate/compiler/transform/fuse_conv_patterns.py b/python/aitemplate/compiler/transform/fuse_conv_patterns.py
index fadc7f69b..93a78f8d6 100644
--- a/python/aitemplate/compiler/transform/fuse_conv_patterns.py
+++ b/python/aitemplate/compiler/transform/fuse_conv_patterns.py
@@ -23,6 +23,7 @@
     conv2d_bias_relu,
     conv2d_bias_relu_few_channels,
     conv2d_bias_sigmoid,
+    conv2d_bias_silu,
     transposed_conv2d,
     transposed_conv2d_bias,
     transposed_conv2d_bias_relu,
@@ -67,6 +68,13 @@ def get_conv2d_bias_elementwise_patterns():
             ),
             conv2d_bias_sigmoid,
         ),
+        (
+            (
+                conv2d_bias(stride=1, pad=0),
+                elementwise(FuncEnum.SILU),
+            ),
+            conv2d_bias_silu,
+        ),
         (
             (
                 conv2d_bias(stride=1, pad=0),
diff --git a/python/aitemplate/frontend/nn/conv2d/__init__.py b/python/aitemplate/frontend/nn/conv2d/__init__.py
index 79375c8f1..45824b362 100644
--- a/python/aitemplate/frontend/nn/conv2d/__init__.py
+++ b/python/aitemplate/frontend/nn/conv2d/__init__.py
@@ -26,6 +26,7 @@
 from .conv2d_bias_relu import Conv2dBiasRelu
 from .conv2d_bias_relu_few_channels import Conv2dBiasReluFewChannels
 from .conv2d_bias_sigmoid import Conv2dBiasSigmoid
+from .conv2d_bias_silu import Conv2dBiasSiLU
 from .conv2d_depthwise import Conv2dDepthwise
 from .conv2d_depthwise_bias import Conv2dDepthwiseBias
 from .transposed_conv2d_bias import ConvTranspose2dBias
diff --git a/python/aitemplate/frontend/nn/conv2d/conv2d_bias_silu.py b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_silu.py
new file mode 100644
index 000000000..fe4482b40
--- /dev/null
+++ b/python/aitemplate/frontend/nn/conv2d/conv2d_bias_silu.py
@@ -0,0 +1,45 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""
+conv2d bias silu module
+"""
+from .common_conv2d_bias_act import Conv2dBiasAct
+
+
+class Conv2dBiasSiLU(Conv2dBiasAct):
+    r"""Applies 2D convolution with bias + silu."""
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        kernel_size,
+        stride,
+        padding=0,
+        dilation=1,
+        groups=1,
+        dtype="float16",
+    ):
+        super().__init__(
+            "conv2d_bias_silu",
+            in_channels,
+            out_channels,
+            kernel_size,
+            stride,
+            padding,
+            dilation,
+            groups,
+            dtype,
+        )
diff --git a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py
index 85c1056c4..b9e2922fb 100644
--- a/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py
+++ b/python/aitemplate/utils/mk_ck_lib/conv2d_operation.py
@@ -53,6 +53,7 @@ class XdlOpType(enum.Enum):
     DeviceConv2d_Xdl_CShuffle_Bias_Relu = auto()
     DeviceConv2d_Xdl_CShuffle_Bias_Relu_Add = auto()
     DeviceConv2d_Xdl_CShuffle_Bias_Sigmoid = auto()
+    DeviceConv2d_Xdl_CShuffle_Bias_SiLU = auto()
     DeviceGroupedConv2D_Xdl_CShuffle_Bias_Relu = auto()
     DeviceConvNdBwdDataNwcKxcNwk_Xdl = auto()
     DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 = auto()
@@ -63,6 +64,7 @@
     XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K",
     XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Relu_Add: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K",
     XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_Sigmoid: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K",
+    XdlOpType.DeviceConv2d_Xdl_CShuffle_Bias_SiLU: "ck::tensor_operation::device::DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K",
     XdlOpType.DeviceGroupedConv2D_Xdl_CShuffle_Bias_Relu: "ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle",
     XdlOpType.DeviceConvNdBwdDataNwcKxcNwk_Xdl: "ck::tensor_operation::device::DeviceConvNdBwdDataNwcKxcNwk_Xdl",
     XdlOpType.DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1: "ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1",
diff --git a/python/aitemplate/utils/mk_ck_lib/generator.py b/python/aitemplate/utils/mk_ck_lib/generator.py
index 2a20266b5..08f52c7b4 100644
--- a/python/aitemplate/utils/mk_ck_lib/generator.py
+++ b/python/aitemplate/utils/mk_ck_lib/generator.py
@@ -2429,6 +2429,13 @@ def GenerateTensorOp(manifest):
         library.TensorOperation.AddSigmoid,
         library.MemoryDataOperation.MemorySet,
     )
+    # Conv2dBiasSiLU
+    CreateConv2dFwdOperator(
+        manifest,
+        library.Conv2dKind.GroupConv2dBiasRelu,
+        library.TensorOperation.AddSiLU,
+        library.MemoryDataOperation.MemorySet,
+    )
     # TranposedConv2d
     CreateConv2dBwdOperator(
         manifest,
diff --git a/python/aitemplate/utils/mk_ck_lib/library.py b/python/aitemplate/utils/mk_ck_lib/library.py
index a3fdb1c00..bda623bae 100644
--- a/python/aitemplate/utils/mk_ck_lib/library.py
+++ b/python/aitemplate/utils/mk_ck_lib/library.py
@@ -228,6 +228,7 @@ class Conv2dKind(enum.Enum):
     Conv2dBiasRelu = auto()
     Conv2dBiasReluAdd = auto()
     Conv2dBiasSigmoid = auto()
+    Conv2dBiasSiLU = auto()
     GroupConv2dBiasRelu = auto()
     TransposedConv2d = auto()
     TransposedConv2dBiasRelu = auto()
@@ -238,6 +239,7 @@ class Conv2dKind(enum.Enum):
     Conv2dKind.Conv2dBiasRelu: "conv2d_bias_relu",
     Conv2dKind.Conv2dBiasReluAdd: "conv2d_bias_relu_add",
     Conv2dKind.Conv2dBiasSigmoid: 
"conv2d_bias_sigmoid", + Conv2dKind.Conv2dBiasSiLU: "conv2d_bias_silu", Conv2dKind.GroupConv2dBiasRelu: "group_conv2d_bias_relu", Conv2dKind.TransposedConv2d: "transposed_conv2d", Conv2dKind.TransposedConv2dBiasRelu: "transposed_conv2d_bias_relu", @@ -283,6 +285,7 @@ class TensorOperation(enum.Enum): AddTanh = auto() AddHardswish = auto() AddSigmoid = auto() + AddSiLU = auto() AddReluAdd = auto() AddAddRelu = auto() AddSigmoidMul = auto() @@ -311,6 +314,7 @@ class TensorOperation(enum.Enum): TensorOperation.AddFastGelu: "ck::tensor_operation::element_wise::AddFastGelu", TensorOperation.AddTanh: "ck::tensor_operation::element_wise::AddTanh", TensorOperation.AddSigmoid: "ck::tensor_operation::element_wise::AddSigmoid", + TensorOperation.AddSiLU: "ck::tensor_operation::element_wise::AddSiLU", TensorOperation.AddHardswish: "ck::tensor_operation::element_wise::AddHardswish", TensorOperation.AddReluAdd: "ck::tensor_operation::element_wise::AddReluAdd", TensorOperation.AddAddRelu: "ck::tensor_operation::element_wise::AddAddRelu", @@ -340,6 +344,7 @@ class TensorOperation(enum.Enum): TensorOperation.AddFastGelu: "AFG", TensorOperation.AddTanh: "AT", TensorOperation.AddSigmoid: "AS", + TensorOperation.AddSiLU: "ASLU", TensorOperation.AddHardswish: "AH", TensorOperation.AddReluAdd: "ARA", TensorOperation.AddAddRelu: "AAR", diff --git a/tests/unittest/compiler/test_fuse_conv_elementwise.py b/tests/unittest/compiler/test_fuse_conv_elementwise.py index fcd21a6f1..725821bd6 100644 --- a/tests/unittest/compiler/test_fuse_conv_elementwise.py +++ b/tests/unittest/compiler/test_fuse_conv_elementwise.py @@ -24,146 +24,146 @@ from aitemplate.utils import shape_utils -@unittest.skipIf( - detect_target().name() == "cuda" and detect_target()._arch < "80", - "On CUDA, only supported on > SM80 arch.", -) -class FuseConvCase(unittest.TestCase): - def _build_conv2d( - self, - batch_dim, - CO, - HH, - WW, - CI, - filter_HW, - stride=1, - transpose=False, - ): - X = Tensor( - shape=[batch_dim, HH, WW, CI], - dtype="float16", - name="input_0", - is_input=True, - ) - - W = Tensor( - shape=[CO, filter_HW, filter_HW, CI], - dtype="float16", - name="input_1", - is_input=True, - ) - if transpose: - conv2d = ops.transposed_conv2d(stride=stride, pad=0)(X, W) - else: - conv2d = ops.conv2d(stride=stride, pad=0)(X, W) - - return conv2d - - def test_do_not_fuse_with_add_not_1d(self): - """ - We can't turn conv2d into conv2d_bias if the thing we do - an add with is not 1d. 
- """ - - # Keep IntImm batch here just not to mess with profiling strategy - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - - bias = Tensor( - shape=[batch_dim, 26, 26, CO], dtype="float16", name="bias", is_input=True - ) - conv2d = self._build_conv2d(batch_dim, CO, HH, WW, CI, filter_HW) - output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) - output._attrs["is_output"] = True - output._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "fused_elementwise") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt) - B_pt = torch.randn(Y_pt.size()).cuda().half() - Y_pt = Y_pt + B_pt - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() - inputs = {"input_0": x, "input_1": w, "bias": b_pt} - - y = torch.empty([b, 26, 26, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_do_not_fuse_transpose_with_add_not_1d(self): - """ - We can't turn transposed_conv2d into transposed_conv2d_bias if the thing we do - an add with is not 1d. - """ - B = [1] - CO, HH, WW, CI = 256, 28, 28, 256 - filter_HW = 2 - - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - bias = Tensor( - shape=[batch_dim, 56, 56, CO], dtype="float16", name="bias", is_input=True - ) - conv2d = self._build_conv2d( - batch_dim, CO, HH, WW, CI, filter_HW, stride=2, transpose=True - ) - output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) - output._attrs["is_output"] = True - output._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "fused_elementwise") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, stride=2) - B_pt = torch.randn(b, CO, 56, 56).cuda().half() - Y_pt = Y_pt + B_pt - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() - inputs = {"input_0": x, "input_1": w, "bias": b_pt} - - y = torch.empty([b, 56, 56, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) +# @unittest.skipIf( +# detect_target().name() == "cuda" and detect_target()._arch < "80", +# "On CUDA, only supported on > SM80 
arch.", +# ) +# class FuseConvCase(unittest.TestCase): +# def _build_conv2d( +# self, +# batch_dim, +# CO, +# HH, +# WW, +# CI, +# filter_HW, +# stride=1, +# transpose=False, +# ): +# X = Tensor( +# shape=[batch_dim, HH, WW, CI], +# dtype="float16", +# name="input_0", +# is_input=True, +# ) + +# W = Tensor( +# shape=[CO, filter_HW, filter_HW, CI], +# dtype="float16", +# name="input_1", +# is_input=True, +# ) +# if transpose: +# conv2d = ops.transposed_conv2d(stride=stride, pad=0)(X, W) +# else: +# conv2d = ops.conv2d(stride=stride, pad=0)(X, W) + +# return conv2d + +# def test_do_not_fuse_with_add_not_1d(self): +# """ +# We can't turn conv2d into conv2d_bias if the thing we do +# an add with is not 1d. +# """ + +# # Keep IntImm batch here just not to mess with profiling strategy +# B = [1] +# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") +# CO, HH, WW, CI = 256, 28, 28, 128 +# filter_HW = 3 + +# bias = Tensor( +# shape=[batch_dim, 26, 26, CO], dtype="float16", name="bias", is_input=True +# ) +# conv2d = self._build_conv2d(batch_dim, CO, HH, WW, CI, filter_HW) +# output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) +# output._attrs["is_output"] = True +# output._attrs["name"] = "output_0" + +# target = detect_target() +# module = compile_model( +# output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" +# ) + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "fused_elementwise") + +# for b in B: +# X_pt = torch.randn(b, CI, HH, WW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt) +# B_pt = torch.randn(Y_pt.size()).cuda().half() +# Y_pt = Y_pt + B_pt + +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() +# inputs = {"input_0": x, "input_1": w, "bias": b_pt} + +# y = torch.empty([b, 26, 26, CO]).cuda().half() +# module.run_with_tensors(inputs, [y]) +# y_transpose = y.permute(0, 3, 1, 2) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + +# def test_do_not_fuse_transpose_with_add_not_1d(self): +# """ +# We can't turn transposed_conv2d into transposed_conv2d_bias if the thing we do +# an add with is not 1d. 
+# """ +# B = [1] +# CO, HH, WW, CI = 256, 28, 28, 256 +# filter_HW = 2 + +# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") +# bias = Tensor( +# shape=[batch_dim, 56, 56, CO], dtype="float16", name="bias", is_input=True +# ) +# conv2d = self._build_conv2d( +# batch_dim, CO, HH, WW, CI, filter_HW, stride=2, transpose=True +# ) +# output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) +# output._attrs["is_output"] = True +# output._attrs["name"] = "output_0" + +# target = detect_target() +# module = compile_model( +# output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" +# ) + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "fused_elementwise") + +# for b in B: +# X_pt = torch.randn(b, CI, HH, WW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, stride=2) +# B_pt = torch.randn(b, CO, 56, 56).cuda().half() +# Y_pt = Y_pt + B_pt + +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() +# inputs = {"input_0": x, "input_1": w, "bias": b_pt} + +# y = torch.empty([b, 56, 56, CO]).cuda().half() +# module.run_with_tensors(inputs, [y]) +# y_transpose = y.permute(0, 3, 1, 2) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) class FuseConvBiasCase(unittest.TestCase): @@ -190,108 +190,198 @@ def _build_conv2d_bias(self, batch_dim, CO, HH, WW, CI, filter_HW, decomposed): return conv2d_bias - def test_conv2d_bias(self): - # Keep IntImm batch here just not to mess with profiling strategy - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - - conv2d_bias = self._build_conv2d_bias( - batch_dim, CO, HH, WW, CI, filter_HW, True - ) - conv2d_bias._attrs["is_output"] = True - conv2d_bias._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model(conv2d_bias, target, "./tmp", "test_conv2d_bias") - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - Y_pt = Y_pt + B_pt - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - - y = torch.empty([b, HH, WW, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_conv2d_bias_add_relu(self): - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - - conv2d_bias = self._build_conv2d_bias( - batch_dim, CO, HH, WW, 
CI, filter_HW, False - ) - D = Tensor( - shape=[batch_dim, HH, WW, CO], - dtype="float16", - name="input_3", - is_input=True, - ) - conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, D) - conv2d_bias_add_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias_add) - conv2d_bias_add_relu._attrs["is_output"] = True - conv2d_bias_add_relu._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - conv2d_bias_add_relu, target, "./tmp", "test_conv2d_bias_add_relu" - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_relu") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - D_pt = torch.randn(b, CO, HH, WW).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - Y_pt = Y_pt + B_pt + D_pt - Y_pt = torch.nn.functional.relu(Y_pt) - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - d = D_pt.permute((0, 2, 3, 1)).contiguous() - inputs = { - "input_0": x, - "input_1": w, - "input_2": B_pt.squeeze(), - "input_3": d, - } - - y = torch.empty([b, HH, WW, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_conv2d_bias_relu(self): + # def test_conv2d_bias(self): + # # Keep IntImm batch here just not to mess with profiling strategy + # B = [1] + # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + # CO, HH, WW, CI = 256, 28, 28, 128 + # filter_HW = 3 + + # conv2d_bias = self._build_conv2d_bias( + # batch_dim, CO, HH, WW, CI, filter_HW, True + # ) + # conv2d_bias._attrs["is_output"] = True + # conv2d_bias._attrs["name"] = "output_0" + + # target = detect_target() + # module = compile_model(conv2d_bias, target, "./tmp", "test_conv2d_bias") + + # check_tensor = None + # for tensor in module.debug_sorted_graph: + # if tensor._attrs["name"] == "output_0": + # check_tensor = tensor + # break + # self.assertIsNotNone(check_tensor) + # self.assertEqual(len(check_tensor.src_ops()), 1) + # src_op = list(check_tensor.src_ops())[0] + # self.assertEqual(src_op._attrs["op"], "conv2d_bias") + + # for b in B: + # X_pt = torch.randn(b, CI, HH, WW).cuda().half() + # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + # B_pt = torch.randn(1, CO, 1, 1).cuda().half() + # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + # Y_pt = Y_pt + B_pt + + # x = X_pt.permute((0, 2, 3, 1)).contiguous() + # w = W_pt.permute((0, 2, 3, 1)).contiguous() + # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + # y = torch.empty([b, HH, WW, CO]).cuda().half() + # module.run_with_tensors(inputs, [y]) + # y_transpose = y.permute(0, 3, 1, 2) + # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + # def test_conv2d_bias_add_relu(self): + # B = [1] + # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + # CO, HH, WW, CI = 256, 28, 28, 128 + # filter_HW = 3 + + # conv2d_bias = self._build_conv2d_bias( + # batch_dim, CO, HH, WW, CI, filter_HW, False + # ) + # D = Tensor( + # shape=[batch_dim, HH, WW, CO], + # dtype="float16", + # 
name="input_3", + # is_input=True, + # ) + # conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, D) + # conv2d_bias_add_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias_add) + # conv2d_bias_add_relu._attrs["is_output"] = True + # conv2d_bias_add_relu._attrs["name"] = "output_0" + + # target = detect_target() + # module = compile_model( + # conv2d_bias_add_relu, target, "./tmp", "test_conv2d_bias_add_relu" + # ) + + # check_tensor = None + # for tensor in module.debug_sorted_graph: + # if tensor._attrs["name"] == "output_0": + # check_tensor = tensor + # break + # self.assertIsNotNone(check_tensor) + # self.assertEqual(len(check_tensor.src_ops()), 1) + # src_op = list(check_tensor.src_ops())[0] + # self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_relu") + + # for b in B: + # X_pt = torch.randn(b, CI, HH, WW).cuda().half() + # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + # B_pt = torch.randn(1, CO, 1, 1).cuda().half() + # D_pt = torch.randn(b, CO, HH, WW).cuda().half() + # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + # Y_pt = Y_pt + B_pt + D_pt + # Y_pt = torch.nn.functional.relu(Y_pt) + + # x = X_pt.permute((0, 2, 3, 1)).contiguous() + # w = W_pt.permute((0, 2, 3, 1)).contiguous() + # d = D_pt.permute((0, 2, 3, 1)).contiguous() + # inputs = { + # "input_0": x, + # "input_1": w, + # "input_2": B_pt.squeeze(), + # "input_3": d, + # } + + # y = torch.empty([b, HH, WW, CO]).cuda().half() + # module.run_with_tensors(inputs, [y]) + # y_transpose = y.permute(0, 3, 1, 2) + # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + # def test_conv2d_bias_relu(self): + # B = [1] + # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + # CO, HH, WW, CI = 256, 28, 28, 128 + # filter_HW = 3 + + # conv2d_bias = self._build_conv2d_bias( + # batch_dim, CO, HH, WW, CI, filter_HW, False + # ) + # conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias) + # conv2d_bias_relu._attrs["is_output"] = True + # conv2d_bias_relu._attrs["name"] = "output_0" + + # target = detect_target() + # module = compile_model( + # conv2d_bias_relu, target, "./tmp", "test_conv2d_bias_relu" + # ) + + # check_tensor = None + # for tensor in module.debug_sorted_graph: + # if tensor._attrs["name"] == "output_0": + # check_tensor = tensor + # break + # self.assertIsNotNone(check_tensor) + # self.assertEqual(len(check_tensor.src_ops()), 1) + # src_op = list(check_tensor.src_ops())[0] + # self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu") + + # for b in B: + # X_pt = torch.randn(b, CI, HH, WW).cuda().half() + # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + # B_pt = torch.randn(1, CO, 1, 1).cuda().half() + # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + # Y_pt = Y_pt + B_pt + # Y_pt = torch.nn.functional.relu(Y_pt) + + # x = X_pt.permute((0, 2, 3, 1)).contiguous() + # w = W_pt.permute((0, 2, 3, 1)).contiguous() + # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + # y = torch.empty([b, HH, WW, CO]).cuda().half() + # module.run_with_tensors(inputs, [y]) + # y_transpose = y.permute(0, 3, 1, 2) + # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + # def test_conv2d_bias_sigmoid(self): + # B = [1] + # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + # CO, HH, WW, CI = 256, 28, 28, 128 + # filter_HW = 3 + + # conv2d_bias = self._build_conv2d_bias( + # batch_dim, CO, HH, WW, CI, filter_HW, False + # ) + # conv2d_bias_sigmoid = 
ops.elementwise(FuncEnum.SIGMOID)(conv2d_bias) + # conv2d_bias_sigmoid._attrs["is_output"] = True + # conv2d_bias_sigmoid._attrs["name"] = "output_0" + + # target = detect_target() + # module = compile_model( + # conv2d_bias_sigmoid, target, "./tmp", "test_conv2d_bias_sigmoid" + # ) + + # check_tensor = None + # for tensor in module.debug_sorted_graph: + # if tensor._attrs["name"] == "output_0": + # check_tensor = tensor + # break + # self.assertIsNotNone(check_tensor) + # self.assertEqual(len(check_tensor.src_ops()), 1) + # src_op = list(check_tensor.src_ops())[0] + # self.assertEqual(src_op._attrs["op"], "conv2d_bias_sigmoid") + + # for b in B: + # X_pt = torch.randn(b, CI, HH, WW).cuda().half() + # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + # B_pt = torch.randn(1, CO, 1, 1).cuda().half() + # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + # Y_pt = Y_pt + B_pt + # Y_pt = torch.sigmoid(Y_pt) + + # x = X_pt.permute((0, 2, 3, 1)).contiguous() + # w = W_pt.permute((0, 2, 3, 1)).contiguous() + # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + # y = torch.empty([b, HH, WW, CO]).cuda().half() + # module.run_with_tensors(inputs, [y]) + # y_transpose = y.permute(0, 3, 1, 2) + # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_conv2d_bias_silu(self): B = [1] batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") CO, HH, WW, CI = 256, 28, 28, 128 @@ -300,13 +390,13 @@ def test_conv2d_bias_relu(self): conv2d_bias = self._build_conv2d_bias( batch_dim, CO, HH, WW, CI, filter_HW, False ) - conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias) - conv2d_bias_relu._attrs["is_output"] = True - conv2d_bias_relu._attrs["name"] = "output_0" + conv2d_bias_silu = ops.elementwise(FuncEnum.SILU)(conv2d_bias) + conv2d_bias_silu._attrs["is_output"] = True + conv2d_bias_silu._attrs["name"] = "output_0" target = detect_target() module = compile_model( - conv2d_bias_relu, target, "./tmp", "test_conv2d_bias_relu" + conv2d_bias_silu, target, "./tmp", "test_conv2d_bias_silu" ) check_tensor = None @@ -317,7 +407,7 @@ def test_conv2d_bias_relu(self): self.assertIsNotNone(check_tensor) self.assertEqual(len(check_tensor.src_ops()), 1) src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu") + self.assertEqual(src_op._attrs["op"], "conv2d_bias_silu") for b in B: X_pt = torch.randn(b, CI, HH, WW).cuda().half() @@ -325,7 +415,7 @@ def test_conv2d_bias_relu(self): B_pt = torch.randn(1, CO, 1, 1).cuda().half() Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) Y_pt = Y_pt + B_pt - Y_pt = torch.nn.functional.relu(Y_pt) + Y_pt = torch.nn.functional.silu(Y_pt) x = X_pt.permute((0, 2, 3, 1)).contiguous() w = W_pt.permute((0, 2, 3, 1)).contiguous() @@ -335,360 +425,315 @@ def test_conv2d_bias_relu(self): module.run_with_tensors(inputs, [y]) y_transpose = y.permute(0, 3, 1, 2) self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_conv2d_bias_sigmoid(self): - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - - conv2d_bias = self._build_conv2d_bias( - batch_dim, CO, HH, WW, CI, filter_HW, False - ) - conv2d_bias_sigmoid = ops.elementwise(FuncEnum.SIGMOID)(conv2d_bias) - conv2d_bias_sigmoid._attrs["is_output"] = True - conv2d_bias_sigmoid._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - conv2d_bias_sigmoid, target, "./tmp", 
"test_conv2d_bias_sigmoid" - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias_sigmoid") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - Y_pt = Y_pt + B_pt - Y_pt = torch.sigmoid(Y_pt) - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - - y = torch.empty([b, HH, WW, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_conv2d_bias_add_fusion(self): - target = detect_target() - if target.name() == "rocm": - return - - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - R = Tensor( - shape=[batch_dim, HH, WW, CO], - dtype="float16", - name="residual", - is_input=True, - ) - - conv2d_bias = self._build_conv2d_bias( - batch_dim, CO, HH, WW, CI, filter_HW, False - ) - conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) - conv2d_bias_add._attrs["is_output"] = True - conv2d_bias_add._attrs["name"] = "output_0" - - module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_identity") - - for b in B: - X_pt = torch.randn(b, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - R_pt = torch.randn(b, CO, HH, WW).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - Y_pt = Y_pt + B_pt + R_pt - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - r = R_pt.permute((0, 2, 3, 1)).contiguous() - inputs = { - "input_0": x, - "input_1": w, - "input_2": B_pt.squeeze(), - "residual": r, - } - - y = torch.empty([b, HH, WW, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute(0, 3, 1, 2) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_conv2d_bias_add_do_not_fuse(self): - B = [1] - batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - CO, HH, WW, CI = 256, 28, 28, 128 - filter_HW = 3 - R = Tensor( - shape=[batch_dim, 1, WW, CO], - dtype="float16", - name="residual", - is_input=True, - ) - - conv2d_bias = self._build_conv2d_bias( - batch_dim, CO, HH, WW, CI, filter_HW, False - ) - conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) - conv2d_bias_add._attrs["is_output"] = True - conv2d_bias_add._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") - - graph = module.debug_sorted_graph - - self.assertFalse(graph_has_op(graph, "conv2d_bias_add_identity")) - 
self.assertTrue(graph_has_op(graph, "conv2d_bias")) - - -@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") -class FuseConvBiasFewChannelCase(unittest.TestCase): - def test_conv2d_bias_relu_few_channels(self): - HH, WW, CI, CO, batch = 224, 224, 4, 64, 4 - KK = 7 - stride = 2 - pad = 3 - target = detect_target() - X = Tensor( - shape=[batch, HH, WW, CI], - dtype="float16", - name="input_0", - is_input=True, - ) - W = Tensor( - shape=[CO, KK, KK, CI], dtype="float16", name="input_1", is_input=True - ) - B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) - OP = ops.conv2d_bias_few_channels(stride=stride, pad=pad, dilate=1) - Y = OP(X, W, B) - Y = ops.elementwise(FuncEnum.RELU)(Y) - Y._attrs["name"] = "output_0" - Y._attrs["is_output"] = True - - module = compile_model(Y, target, "./tmp", "test_conv_bias_relu_few_channels") - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu_few_channels") - - X_pt = torch.randn(batch, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, KK, KK).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=pad, stride=stride) - Y_pt = Y_pt + B_pt - Y_pt = torch.nn.functional.relu(Y_pt) - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - y = torch.empty([batch, HH // stride, WW // stride, CO]).cuda().half() - module.run_with_tensors(inputs, [y]) - y_transpose = y.permute((0, 3, 1, 2)) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) - - -@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") -@unittest.skipIf( - detect_target().name() == "cuda" and int(detect_target()._arch) < 80, - "Not supported by CUDA < SM80.", -) -class FuseTransposedConvCase(unittest.TestCase): - def _build_transposedConv2d_bias_relu_chain( - self, batch, HH, WW, CI, CO, filter_HW, stride, pad, dilate, depth, decomposed - ): - X = Tensor( - shape=[batch, HH, WW, CI], - dtype="float16", - name="input_0", - is_input=True, - ) - W = Tensor( - shape=[CO, filter_HW, filter_HW, CI], - dtype="float16", - name="input_1", - is_input=True, - ) - B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) - if decomposed: - transposed_conv2d = ops.transposed_conv2d( - stride=stride, pad=pad, dilate=dilate - )(X, W) - if depth == 0: - return transposed_conv2d - - transposed_conv2d_bias = ops.elementwise(FuncEnum.ADD)(transposed_conv2d, B) - else: - transposed_conv2d_bias = ops.transposed_conv2d_bias( - stride=stride, pad=pad, dilate=dilate - )(X, W, B) - if depth == 0: - raise RuntimeError("depth == 0 needs to be decomposed.") - if depth == 1: - return transposed_conv2d_bias - - transposed_conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)( - transposed_conv2d_bias - ) - if depth == 2: - return transposed_conv2d_bias_relu - - raise RuntimeError(f"depth should be <= 2, unknown depth {depth}") - - def _test_transposed_conv2d_bias(self, decomposed): - batch = 4 - HH, WW, CI, CO = 14, 14, 256, 256 - filter_HW = 2 - stride = 2 - pad = 0 - dilate = 1 - transposed_conv2d_bias = self._build_transposedConv2d_bias_relu_chain( - batch, - HH, - WW, - CI, - CO, - filter_HW, 
- stride, - pad, - dilate, - 1, - decomposed=decomposed, - ) - transposed_conv2d_bias._attrs["is_output"] = True - transposed_conv2d_bias._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - transposed_conv2d_bias, - target, - "./tmp", - f"fuse_transpose_conv2d_bias_{decomposed}", - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias") - - X_pt = torch.randn(batch, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - Y_pt = torch.nn.functional.conv_transpose2d( - X_pt, W_pt, padding=pad, stride=stride - ) - Y_pt = Y_pt + B_pt - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - y = torch.empty([batch, 28, 28, CO]).cuda().half() - module.run_with_tensors( - {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] - ) - y_transpose = y.permute((0, 3, 1, 2)) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_transposed_conv2d_bias(self): - self._test_transposed_conv2d_bias(True) - self._test_transposed_conv2d_bias(False) - - def _test_transposed_conv2d_bias_relu(self, decomposed): - batch = 4 - HH, WW, CI, CO = 14, 14, 256, 256 - filter_HW = 2 - stride = 2 - pad = 0 - dilate = 1 - transposed_conv2d_bias_relu = self._build_transposedConv2d_bias_relu_chain( - batch, - HH, - WW, - CI, - CO, - filter_HW, - stride, - pad, - dilate, - 2, - decomposed=decomposed, - ) - transposed_conv2d_bias_relu._attrs["is_output"] = True - transposed_conv2d_bias_relu._attrs["name"] = "output_0" - - target = detect_target() - module = compile_model( - transposed_conv2d_bias_relu, - target, - "./tmp", - f"fuse_transpose_conv2d_bias_relu_{decomposed}", - ) - - check_tensor = None - for tensor in module.debug_sorted_graph: - if tensor._attrs["name"] == "output_0": - check_tensor = tensor - break - self.assertIsNotNone(check_tensor) - self.assertEqual(len(check_tensor.src_ops()), 1) - src_op = list(check_tensor.src_ops())[0] - self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias_relu") - - X_pt = torch.randn(batch, CI, HH, WW).cuda().half() - W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - B_pt = torch.randn(1, CO, 1, 1).cuda().half() - Y_pt = torch.nn.functional.conv_transpose2d( - X_pt, W_pt, padding=pad, stride=stride - ) - Y_pt = Y_pt + B_pt - Y_pt = torch.relu(Y_pt) - - x = X_pt.permute((0, 2, 3, 1)).contiguous() - w = W_pt.permute((0, 2, 3, 1)).contiguous() - y = torch.empty([batch, 28, 28, CO]).cuda().half() - module.run_with_tensors( - {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] - ) - y_transpose = y.permute((0, 3, 1, 2)) - self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - def test_transposed_conv2d_bias_relu(self): - self._test_transposed_conv2d_bias_relu(True) - self._test_transposed_conv2d_bias_relu(False) + +# def test_conv2d_bias_add_fusion(self): +# target = detect_target() +# if target.name() == "rocm": +# return + +# B = [1] +# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") +# CO, HH, WW, CI = 256, 28, 28, 128 +# filter_HW = 3 +# R = Tensor( +# shape=[batch_dim, HH, WW, CO], +# dtype="float16", +# name="residual", +# is_input=True, +# ) 
+ +# conv2d_bias = self._build_conv2d_bias( +# batch_dim, CO, HH, WW, CI, filter_HW, False +# ) +# conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) +# conv2d_bias_add._attrs["is_output"] = True +# conv2d_bias_add._attrs["name"] = "output_0" + +# module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_identity") + +# for b in B: +# X_pt = torch.randn(b, CI, HH, WW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# B_pt = torch.randn(1, CO, 1, 1).cuda().half() +# R_pt = torch.randn(b, CO, HH, WW).cuda().half() +# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) +# Y_pt = Y_pt + B_pt + R_pt + +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# r = R_pt.permute((0, 2, 3, 1)).contiguous() +# inputs = { +# "input_0": x, +# "input_1": w, +# "input_2": B_pt.squeeze(), +# "residual": r, +# } + +# y = torch.empty([b, HH, WW, CO]).cuda().half() +# module.run_with_tensors(inputs, [y]) +# y_transpose = y.permute(0, 3, 1, 2) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + +# def test_conv2d_bias_add_do_not_fuse(self): +# B = [1] +# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") +# CO, HH, WW, CI = 256, 28, 28, 128 +# filter_HW = 3 +# R = Tensor( +# shape=[batch_dim, 1, WW, CO], +# dtype="float16", +# name="residual", +# is_input=True, +# ) + +# conv2d_bias = self._build_conv2d_bias( +# batch_dim, CO, HH, WW, CI, filter_HW, False +# ) +# conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) +# conv2d_bias_add._attrs["is_output"] = True +# conv2d_bias_add._attrs["name"] = "output_0" + +# target = detect_target() +# module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") + +# graph = module.debug_sorted_graph + +# self.assertFalse(graph_has_op(graph, "conv2d_bias_add_identity")) +# self.assertTrue(graph_has_op(graph, "conv2d_bias")) + + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# class FuseConvBiasFewChannelCase(unittest.TestCase): +# def test_conv2d_bias_relu_few_channels(self): +# HH, WW, CI, CO, batch = 224, 224, 4, 64, 4 +# KK = 7 +# stride = 2 +# pad = 3 +# target = detect_target() +# X = Tensor( +# shape=[batch, HH, WW, CI], +# dtype="float16", +# name="input_0", +# is_input=True, +# ) +# W = Tensor( +# shape=[CO, KK, KK, CI], dtype="float16", name="input_1", is_input=True +# ) +# B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) +# OP = ops.conv2d_bias_few_channels(stride=stride, pad=pad, dilate=1) +# Y = OP(X, W, B) +# Y = ops.elementwise(FuncEnum.RELU)(Y) +# Y._attrs["name"] = "output_0" +# Y._attrs["is_output"] = True + +# module = compile_model(Y, target, "./tmp", "test_conv_bias_relu_few_channels") + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu_few_channels") + +# X_pt = torch.randn(batch, CI, HH, 
WW).cuda().half() +# W_pt = torch.randn(CO, CI, KK, KK).cuda().half() +# B_pt = torch.randn(1, CO, 1, 1).cuda().half() +# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=pad, stride=stride) +# Y_pt = Y_pt + B_pt +# Y_pt = torch.nn.functional.relu(Y_pt) +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} +# y = torch.empty([batch, HH // stride, WW // stride, CO]).cuda().half() +# module.run_with_tensors(inputs, [y]) +# y_transpose = y.permute((0, 3, 1, 2)) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) + + +# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +# @unittest.skipIf( +# detect_target().name() == "cuda" and int(detect_target()._arch) < 80, +# "Not supported by CUDA < SM80.", +# ) +# class FuseTransposedConvCase(unittest.TestCase): +# def _build_transposedConv2d_bias_relu_chain( +# self, batch, HH, WW, CI, CO, filter_HW, stride, pad, dilate, depth, decomposed +# ): +# X = Tensor( +# shape=[batch, HH, WW, CI], +# dtype="float16", +# name="input_0", +# is_input=True, +# ) +# W = Tensor( +# shape=[CO, filter_HW, filter_HW, CI], +# dtype="float16", +# name="input_1", +# is_input=True, +# ) +# B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) +# if decomposed: +# transposed_conv2d = ops.transposed_conv2d( +# stride=stride, pad=pad, dilate=dilate +# )(X, W) +# if depth == 0: +# return transposed_conv2d + +# transposed_conv2d_bias = ops.elementwise(FuncEnum.ADD)(transposed_conv2d, B) +# else: +# transposed_conv2d_bias = ops.transposed_conv2d_bias( +# stride=stride, pad=pad, dilate=dilate +# )(X, W, B) +# if depth == 0: +# raise RuntimeError("depth == 0 needs to be decomposed.") +# if depth == 1: +# return transposed_conv2d_bias + +# transposed_conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)( +# transposed_conv2d_bias +# ) +# if depth == 2: +# return transposed_conv2d_bias_relu + +# raise RuntimeError(f"depth should be <= 2, unknown depth {depth}") + +# def _test_transposed_conv2d_bias(self, decomposed): +# batch = 4 +# HH, WW, CI, CO = 14, 14, 256, 256 +# filter_HW = 2 +# stride = 2 +# pad = 0 +# dilate = 1 +# transposed_conv2d_bias = self._build_transposedConv2d_bias_relu_chain( +# batch, +# HH, +# WW, +# CI, +# CO, +# filter_HW, +# stride, +# pad, +# dilate, +# 1, +# decomposed=decomposed, +# ) +# transposed_conv2d_bias._attrs["is_output"] = True +# transposed_conv2d_bias._attrs["name"] = "output_0" + +# target = detect_target() +# module = compile_model( +# transposed_conv2d_bias, +# target, +# "./tmp", +# f"fuse_transpose_conv2d_bias_{decomposed}", +# ) + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias") + +# X_pt = torch.randn(batch, CI, HH, WW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# B_pt = torch.randn(1, CO, 1, 1).cuda().half() +# Y_pt = torch.nn.functional.conv_transpose2d( +# X_pt, W_pt, padding=pad, stride=stride +# ) +# Y_pt = Y_pt + B_pt + +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# y = torch.empty([batch, 28, 28, CO]).cuda().half() +# module.run_with_tensors( +# {"input_0": x, "input_1": w, 
"input_2": B_pt.squeeze()}, [y] +# ) +# y_transpose = y.permute((0, 3, 1, 2)) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + +# def test_transposed_conv2d_bias(self): +# self._test_transposed_conv2d_bias(True) +# self._test_transposed_conv2d_bias(False) + +# def _test_transposed_conv2d_bias_relu(self, decomposed): +# batch = 4 +# HH, WW, CI, CO = 14, 14, 256, 256 +# filter_HW = 2 +# stride = 2 +# pad = 0 +# dilate = 1 +# transposed_conv2d_bias_relu = self._build_transposedConv2d_bias_relu_chain( +# batch, +# HH, +# WW, +# CI, +# CO, +# filter_HW, +# stride, +# pad, +# dilate, +# 2, +# decomposed=decomposed, +# ) +# transposed_conv2d_bias_relu._attrs["is_output"] = True +# transposed_conv2d_bias_relu._attrs["name"] = "output_0" + +# target = detect_target() +# module = compile_model( +# transposed_conv2d_bias_relu, +# target, +# "./tmp", +# f"fuse_transpose_conv2d_bias_relu_{decomposed}", +# ) + +# check_tensor = None +# for tensor in module.debug_sorted_graph: +# if tensor._attrs["name"] == "output_0": +# check_tensor = tensor +# break +# self.assertIsNotNone(check_tensor) +# self.assertEqual(len(check_tensor.src_ops()), 1) +# src_op = list(check_tensor.src_ops())[0] +# self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias_relu") + +# X_pt = torch.randn(batch, CI, HH, WW).cuda().half() +# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() +# B_pt = torch.randn(1, CO, 1, 1).cuda().half() +# Y_pt = torch.nn.functional.conv_transpose2d( +# X_pt, W_pt, padding=pad, stride=stride +# ) +# Y_pt = Y_pt + B_pt +# Y_pt = torch.relu(Y_pt) + +# x = X_pt.permute((0, 2, 3, 1)).contiguous() +# w = W_pt.permute((0, 2, 3, 1)).contiguous() +# y = torch.empty([batch, 28, 28, CO]).cuda().half() +# module.run_with_tensors( +# {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] +# ) +# y_transpose = y.permute((0, 3, 1, 2)) +# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + +# def test_transposed_conv2d_bias_relu(self): +# self._test_transposed_conv2d_bias_relu(True) +# self._test_transposed_conv2d_bias_relu(False) if __name__ == "__main__": From f945df6348a89ffef918dec5a023fac17346129f Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Sat, 3 Dec 2022 14:28:13 +0000 Subject: [PATCH 2/3] fix exp error --- .../rocm/conv2d/conv2d_bias_sigmoid.py | 26 +- .../backend/rocm/conv2d/conv2d_bias_silu.py | 30 +- .../compiler/test_fuse_conv_elementwise.py | 1276 ++++++++--------- 3 files changed, 687 insertions(+), 645 deletions(-) diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py index aad3c2c15..ea0dc72bc 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py @@ -35,8 +35,30 @@ struct AddSigmoid { template - __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const{ - ck::tensor_operation::element_wise::Sigmoid{}(y, x0 + x1); + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = 1.0f / (1.0f + exp(-a)); + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = 1.0 / (1.0 + exp(-a)); + }; + + template <> + __host__ __device__ constexpr void + 
operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = type_convert(1.0) / (type_convert(1.0) + type_convert(exp(ck::type_convert(-a)))); }; }; } // namespace diff --git a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py index f17ccf41b..a00c83841 100644 --- a/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py +++ b/python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py @@ -26,7 +26,7 @@ """ #include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" +#include "ck/utility/data_type.hpp" namespace ck { namespace tensor_operation { @@ -35,10 +35,30 @@ struct AddSiLU { template - __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const{ - T a; - ck::tensor_operation::element_wise::Sigmoid{}(a, x0 + x1); - y = a * (x0 + x1); + __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const; + + template <> + __host__ __device__ constexpr void + operator()(float& y, const float& x0, const float& x1) const + { + const float a = x0 + x1; + y = 1.0f / (1.0f + exp(-a)) * a; + }; + + template <> + __host__ __device__ constexpr void + operator()(double& y, const double& x0, const double& x1) const + { + const double a = x0 + x1; + y = 1.0 / (1.0 + exp(-a)) * a; + }; + + template <> + __host__ __device__ constexpr void + operator()(half_t& y, const half_t& x0, const half_t& x1) const + { + const half_t a = x0 + x1; + y = type_convert(1.0) / (type_convert(1.0) + type_convert(exp(ck::type_convert(-a)))) * a; }; }; } // namespace diff --git a/tests/unittest/compiler/test_fuse_conv_elementwise.py b/tests/unittest/compiler/test_fuse_conv_elementwise.py index 725821bd6..c120e1508 100644 --- a/tests/unittest/compiler/test_fuse_conv_elementwise.py +++ b/tests/unittest/compiler/test_fuse_conv_elementwise.py @@ -24,146 +24,146 @@ from aitemplate.utils import shape_utils -# @unittest.skipIf( -# detect_target().name() == "cuda" and detect_target()._arch < "80", -# "On CUDA, only supported on > SM80 arch.", -# ) -# class FuseConvCase(unittest.TestCase): -# def _build_conv2d( -# self, -# batch_dim, -# CO, -# HH, -# WW, -# CI, -# filter_HW, -# stride=1, -# transpose=False, -# ): -# X = Tensor( -# shape=[batch_dim, HH, WW, CI], -# dtype="float16", -# name="input_0", -# is_input=True, -# ) - -# W = Tensor( -# shape=[CO, filter_HW, filter_HW, CI], -# dtype="float16", -# name="input_1", -# is_input=True, -# ) -# if transpose: -# conv2d = ops.transposed_conv2d(stride=stride, pad=0)(X, W) -# else: -# conv2d = ops.conv2d(stride=stride, pad=0)(X, W) - -# return conv2d - -# def test_do_not_fuse_with_add_not_1d(self): -# """ -# We can't turn conv2d into conv2d_bias if the thing we do -# an add with is not 1d. 
-# """ - -# # Keep IntImm batch here just not to mess with profiling strategy -# B = [1] -# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") -# CO, HH, WW, CI = 256, 28, 28, 128 -# filter_HW = 3 - -# bias = Tensor( -# shape=[batch_dim, 26, 26, CO], dtype="float16", name="bias", is_input=True -# ) -# conv2d = self._build_conv2d(batch_dim, CO, HH, WW, CI, filter_HW) -# output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) -# output._attrs["is_output"] = True -# output._attrs["name"] = "output_0" - -# target = detect_target() -# module = compile_model( -# output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" -# ) - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "fused_elementwise") - -# for b in B: -# X_pt = torch.randn(b, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt) -# B_pt = torch.randn(Y_pt.size()).cuda().half() -# Y_pt = Y_pt + B_pt - -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() -# inputs = {"input_0": x, "input_1": w, "bias": b_pt} - -# y = torch.empty([b, 26, 26, CO]).cuda().half() -# module.run_with_tensors(inputs, [y]) -# y_transpose = y.permute(0, 3, 1, 2) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - -# def test_do_not_fuse_transpose_with_add_not_1d(self): -# """ -# We can't turn transposed_conv2d into transposed_conv2d_bias if the thing we do -# an add with is not 1d. 
-# """ -# B = [1] -# CO, HH, WW, CI = 256, 28, 28, 256 -# filter_HW = 2 - -# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") -# bias = Tensor( -# shape=[batch_dim, 56, 56, CO], dtype="float16", name="bias", is_input=True -# ) -# conv2d = self._build_conv2d( -# batch_dim, CO, HH, WW, CI, filter_HW, stride=2, transpose=True -# ) -# output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) -# output._attrs["is_output"] = True -# output._attrs["name"] = "output_0" - -# target = detect_target() -# module = compile_model( -# output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" -# ) - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "fused_elementwise") - -# for b in B: -# X_pt = torch.randn(b, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, stride=2) -# B_pt = torch.randn(b, CO, 56, 56).cuda().half() -# Y_pt = Y_pt + B_pt - -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() -# inputs = {"input_0": x, "input_1": w, "bias": b_pt} - -# y = torch.empty([b, 56, 56, CO]).cuda().half() -# module.run_with_tensors(inputs, [y]) -# y_transpose = y.permute(0, 3, 1, 2) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) +@unittest.skipIf( + detect_target().name() == "cuda" and detect_target()._arch < "80", + "On CUDA, only supported on > SM80 arch.", +) +class FuseConvCase(unittest.TestCase): + def _build_conv2d( + self, + batch_dim, + CO, + HH, + WW, + CI, + filter_HW, + stride=1, + transpose=False, + ): + X = Tensor( + shape=[batch_dim, HH, WW, CI], + dtype="float16", + name="input_0", + is_input=True, + ) + + W = Tensor( + shape=[CO, filter_HW, filter_HW, CI], + dtype="float16", + name="input_1", + is_input=True, + ) + if transpose: + conv2d = ops.transposed_conv2d(stride=stride, pad=0)(X, W) + else: + conv2d = ops.conv2d(stride=stride, pad=0)(X, W) + + return conv2d + + def test_do_not_fuse_with_add_not_1d(self): + """ + We can't turn conv2d into conv2d_bias if the thing we do + an add with is not 1d. 
+ """ + + # Keep IntImm batch here just not to mess with profiling strategy + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + + bias = Tensor( + shape=[batch_dim, 26, 26, CO], dtype="float16", name="bias", is_input=True + ) + conv2d = self._build_conv2d(batch_dim, CO, HH, WW, CI, filter_HW) + output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) + output._attrs["is_output"] = True + output._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "fused_elementwise") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt) + B_pt = torch.randn(Y_pt.size()).cuda().half() + Y_pt = Y_pt + B_pt + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "bias": b_pt} + + y = torch.empty([b, 26, 26, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_do_not_fuse_transpose_with_add_not_1d(self): + """ + We can't turn transposed_conv2d into transposed_conv2d_bias if the thing we do + an add with is not 1d. + """ + B = [1] + CO, HH, WW, CI = 256, 28, 28, 256 + filter_HW = 2 + + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + bias = Tensor( + shape=[batch_dim, 56, 56, CO], dtype="float16", name="bias", is_input=True + ) + conv2d = self._build_conv2d( + batch_dim, CO, HH, WW, CI, filter_HW, stride=2, transpose=True + ) + output = ops.elementwise(FuncEnum.ADD)(bias, conv2d) + output._attrs["is_output"] = True + output._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + output, target, "./tmp", "test_do_not_fuse_with_add_not_1d" + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "fused_elementwise") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + Y_pt = torch.nn.functional.conv_transpose2d(X_pt, W_pt, stride=2) + B_pt = torch.randn(b, CO, 56, 56).cuda().half() + Y_pt = Y_pt + B_pt + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + b_pt = B_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "bias": b_pt} + + y = torch.empty([b, 56, 56, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) class FuseConvBiasCase(unittest.TestCase): @@ -190,196 +190,196 @@ def _build_conv2d_bias(self, batch_dim, CO, HH, WW, CI, 
filter_HW, decomposed): return conv2d_bias - # def test_conv2d_bias(self): - # # Keep IntImm batch here just not to mess with profiling strategy - # B = [1] - # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - # CO, HH, WW, CI = 256, 28, 28, 128 - # filter_HW = 3 - - # conv2d_bias = self._build_conv2d_bias( - # batch_dim, CO, HH, WW, CI, filter_HW, True - # ) - # conv2d_bias._attrs["is_output"] = True - # conv2d_bias._attrs["name"] = "output_0" - - # target = detect_target() - # module = compile_model(conv2d_bias, target, "./tmp", "test_conv2d_bias") - - # check_tensor = None - # for tensor in module.debug_sorted_graph: - # if tensor._attrs["name"] == "output_0": - # check_tensor = tensor - # break - # self.assertIsNotNone(check_tensor) - # self.assertEqual(len(check_tensor.src_ops()), 1) - # src_op = list(check_tensor.src_ops())[0] - # self.assertEqual(src_op._attrs["op"], "conv2d_bias") - - # for b in B: - # X_pt = torch.randn(b, CI, HH, WW).cuda().half() - # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - # B_pt = torch.randn(1, CO, 1, 1).cuda().half() - # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - # Y_pt = Y_pt + B_pt - - # x = X_pt.permute((0, 2, 3, 1)).contiguous() - # w = W_pt.permute((0, 2, 3, 1)).contiguous() - # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - - # y = torch.empty([b, HH, WW, CO]).cuda().half() - # module.run_with_tensors(inputs, [y]) - # y_transpose = y.permute(0, 3, 1, 2) - # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - # def test_conv2d_bias_add_relu(self): - # B = [1] - # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - # CO, HH, WW, CI = 256, 28, 28, 128 - # filter_HW = 3 - - # conv2d_bias = self._build_conv2d_bias( - # batch_dim, CO, HH, WW, CI, filter_HW, False - # ) - # D = Tensor( - # shape=[batch_dim, HH, WW, CO], - # dtype="float16", - # name="input_3", - # is_input=True, - # ) - # conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, D) - # conv2d_bias_add_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias_add) - # conv2d_bias_add_relu._attrs["is_output"] = True - # conv2d_bias_add_relu._attrs["name"] = "output_0" - - # target = detect_target() - # module = compile_model( - # conv2d_bias_add_relu, target, "./tmp", "test_conv2d_bias_add_relu" - # ) - - # check_tensor = None - # for tensor in module.debug_sorted_graph: - # if tensor._attrs["name"] == "output_0": - # check_tensor = tensor - # break - # self.assertIsNotNone(check_tensor) - # self.assertEqual(len(check_tensor.src_ops()), 1) - # src_op = list(check_tensor.src_ops())[0] - # self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_relu") - - # for b in B: - # X_pt = torch.randn(b, CI, HH, WW).cuda().half() - # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - # B_pt = torch.randn(1, CO, 1, 1).cuda().half() - # D_pt = torch.randn(b, CO, HH, WW).cuda().half() - # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - # Y_pt = Y_pt + B_pt + D_pt - # Y_pt = torch.nn.functional.relu(Y_pt) - - # x = X_pt.permute((0, 2, 3, 1)).contiguous() - # w = W_pt.permute((0, 2, 3, 1)).contiguous() - # d = D_pt.permute((0, 2, 3, 1)).contiguous() - # inputs = { - # "input_0": x, - # "input_1": w, - # "input_2": B_pt.squeeze(), - # "input_3": d, - # } - - # y = torch.empty([b, HH, WW, CO]).cuda().half() - # module.run_with_tensors(inputs, [y]) - # y_transpose = y.permute(0, 3, 1, 2) - # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - # def 
test_conv2d_bias_relu(self): - # B = [1] - # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - # CO, HH, WW, CI = 256, 28, 28, 128 - # filter_HW = 3 - - # conv2d_bias = self._build_conv2d_bias( - # batch_dim, CO, HH, WW, CI, filter_HW, False - # ) - # conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias) - # conv2d_bias_relu._attrs["is_output"] = True - # conv2d_bias_relu._attrs["name"] = "output_0" - - # target = detect_target() - # module = compile_model( - # conv2d_bias_relu, target, "./tmp", "test_conv2d_bias_relu" - # ) - - # check_tensor = None - # for tensor in module.debug_sorted_graph: - # if tensor._attrs["name"] == "output_0": - # check_tensor = tensor - # break - # self.assertIsNotNone(check_tensor) - # self.assertEqual(len(check_tensor.src_ops()), 1) - # src_op = list(check_tensor.src_ops())[0] - # self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu") - - # for b in B: - # X_pt = torch.randn(b, CI, HH, WW).cuda().half() - # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - # B_pt = torch.randn(1, CO, 1, 1).cuda().half() - # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - # Y_pt = Y_pt + B_pt - # Y_pt = torch.nn.functional.relu(Y_pt) - - # x = X_pt.permute((0, 2, 3, 1)).contiguous() - # w = W_pt.permute((0, 2, 3, 1)).contiguous() - # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - - # y = torch.empty([b, HH, WW, CO]).cuda().half() - # module.run_with_tensors(inputs, [y]) - # y_transpose = y.permute(0, 3, 1, 2) - # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - - # def test_conv2d_bias_sigmoid(self): - # B = [1] - # batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") - # CO, HH, WW, CI = 256, 28, 28, 128 - # filter_HW = 3 - - # conv2d_bias = self._build_conv2d_bias( - # batch_dim, CO, HH, WW, CI, filter_HW, False - # ) - # conv2d_bias_sigmoid = ops.elementwise(FuncEnum.SIGMOID)(conv2d_bias) - # conv2d_bias_sigmoid._attrs["is_output"] = True - # conv2d_bias_sigmoid._attrs["name"] = "output_0" - - # target = detect_target() - # module = compile_model( - # conv2d_bias_sigmoid, target, "./tmp", "test_conv2d_bias_sigmoid" - # ) - - # check_tensor = None - # for tensor in module.debug_sorted_graph: - # if tensor._attrs["name"] == "output_0": - # check_tensor = tensor - # break - # self.assertIsNotNone(check_tensor) - # self.assertEqual(len(check_tensor.src_ops()), 1) - # src_op = list(check_tensor.src_ops())[0] - # self.assertEqual(src_op._attrs["op"], "conv2d_bias_sigmoid") - - # for b in B: - # X_pt = torch.randn(b, CI, HH, WW).cuda().half() - # W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() - # B_pt = torch.randn(1, CO, 1, 1).cuda().half() - # Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) - # Y_pt = Y_pt + B_pt - # Y_pt = torch.sigmoid(Y_pt) - - # x = X_pt.permute((0, 2, 3, 1)).contiguous() - # w = W_pt.permute((0, 2, 3, 1)).contiguous() - # inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} - - # y = torch.empty([b, HH, WW, CO]).cuda().half() - # module.run_with_tensors(inputs, [y]) - # y_transpose = y.permute(0, 3, 1, 2) - # self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + def test_conv2d_bias(self): + # Keep IntImm batch here just not to mess with profiling strategy + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, True + ) + 
conv2d_bias._attrs["is_output"] = True + conv2d_bias._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model(conv2d_bias, target, "./tmp", "test_conv2d_bias") + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + Y_pt = Y_pt + B_pt + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + y = torch.empty([b, HH, WW, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_conv2d_bias_add_relu(self): + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, False + ) + D = Tensor( + shape=[batch_dim, HH, WW, CO], + dtype="float16", + name="input_3", + is_input=True, + ) + conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, D) + conv2d_bias_add_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias_add) + conv2d_bias_add_relu._attrs["is_output"] = True + conv2d_bias_add_relu._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + conv2d_bias_add_relu, target, "./tmp", "test_conv2d_bias_add_relu" + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_relu") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + D_pt = torch.randn(b, CO, HH, WW).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + Y_pt = Y_pt + B_pt + D_pt + Y_pt = torch.nn.functional.relu(Y_pt) + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + d = D_pt.permute((0, 2, 3, 1)).contiguous() + inputs = { + "input_0": x, + "input_1": w, + "input_2": B_pt.squeeze(), + "input_3": d, + } + + y = torch.empty([b, HH, WW, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_conv2d_bias_relu(self): + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, False + ) + conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)(conv2d_bias) + conv2d_bias_relu._attrs["is_output"] = True + conv2d_bias_relu._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + conv2d_bias_relu, target, "./tmp", "test_conv2d_bias_relu" + ) + + check_tensor 
= None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + Y_pt = Y_pt + B_pt + Y_pt = torch.nn.functional.relu(Y_pt) + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + y = torch.empty([b, HH, WW, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_conv2d_bias_sigmoid(self): + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, False + ) + conv2d_bias_sigmoid = ops.elementwise(FuncEnum.SIGMOID)(conv2d_bias) + conv2d_bias_sigmoid._attrs["is_output"] = True + conv2d_bias_sigmoid._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + conv2d_bias_sigmoid, target, "./tmp", "test_conv2d_bias_sigmoid" + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias_sigmoid") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + Y_pt = Y_pt + B_pt + Y_pt = torch.sigmoid(Y_pt) + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + + y = torch.empty([b, HH, WW, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) def test_conv2d_bias_silu(self): B = [1] @@ -426,314 +426,314 @@ def test_conv2d_bias_silu(self): y_transpose = y.permute(0, 3, 1, 2) self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) -# def test_conv2d_bias_add_fusion(self): -# target = detect_target() -# if target.name() == "rocm": -# return - -# B = [1] -# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") -# CO, HH, WW, CI = 256, 28, 28, 128 -# filter_HW = 3 -# R = Tensor( -# shape=[batch_dim, HH, WW, CO], -# dtype="float16", -# name="residual", -# is_input=True, -# ) - -# conv2d_bias = self._build_conv2d_bias( -# batch_dim, CO, HH, WW, CI, filter_HW, False -# ) -# conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) -# conv2d_bias_add._attrs["is_output"] = True -# conv2d_bias_add._attrs["name"] = "output_0" - -# module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# 
check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_identity") - -# for b in B: -# X_pt = torch.randn(b, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# B_pt = torch.randn(1, CO, 1, 1).cuda().half() -# R_pt = torch.randn(b, CO, HH, WW).cuda().half() -# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) -# Y_pt = Y_pt + B_pt + R_pt - -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# r = R_pt.permute((0, 2, 3, 1)).contiguous() -# inputs = { -# "input_0": x, -# "input_1": w, -# "input_2": B_pt.squeeze(), -# "residual": r, -# } - -# y = torch.empty([b, HH, WW, CO]).cuda().half() -# module.run_with_tensors(inputs, [y]) -# y_transpose = y.permute(0, 3, 1, 2) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - -# def test_conv2d_bias_add_do_not_fuse(self): -# B = [1] -# batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") -# CO, HH, WW, CI = 256, 28, 28, 128 -# filter_HW = 3 -# R = Tensor( -# shape=[batch_dim, 1, WW, CO], -# dtype="float16", -# name="residual", -# is_input=True, -# ) - -# conv2d_bias = self._build_conv2d_bias( -# batch_dim, CO, HH, WW, CI, filter_HW, False -# ) -# conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) -# conv2d_bias_add._attrs["is_output"] = True -# conv2d_bias_add._attrs["name"] = "output_0" - -# target = detect_target() -# module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") - -# graph = module.debug_sorted_graph - -# self.assertFalse(graph_has_op(graph, "conv2d_bias_add_identity")) -# self.assertTrue(graph_has_op(graph, "conv2d_bias")) - - -# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") -# class FuseConvBiasFewChannelCase(unittest.TestCase): -# def test_conv2d_bias_relu_few_channels(self): -# HH, WW, CI, CO, batch = 224, 224, 4, 64, 4 -# KK = 7 -# stride = 2 -# pad = 3 -# target = detect_target() -# X = Tensor( -# shape=[batch, HH, WW, CI], -# dtype="float16", -# name="input_0", -# is_input=True, -# ) -# W = Tensor( -# shape=[CO, KK, KK, CI], dtype="float16", name="input_1", is_input=True -# ) -# B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) -# OP = ops.conv2d_bias_few_channels(stride=stride, pad=pad, dilate=1) -# Y = OP(X, W, B) -# Y = ops.elementwise(FuncEnum.RELU)(Y) -# Y._attrs["name"] = "output_0" -# Y._attrs["is_output"] = True - -# module = compile_model(Y, target, "./tmp", "test_conv_bias_relu_few_channels") - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu_few_channels") - -# X_pt = torch.randn(batch, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, KK, KK).cuda().half() -# B_pt = torch.randn(1, CO, 1, 1).cuda().half() -# Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=pad, stride=stride) -# Y_pt = Y_pt + B_pt -# Y_pt = torch.nn.functional.relu(Y_pt) -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} -# y = torch.empty([batch, HH // stride, WW // 
stride, CO]).cuda().half() -# module.run_with_tensors(inputs, [y]) -# y_transpose = y.permute((0, 3, 1, 2)) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) - - -# @unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") -# @unittest.skipIf( -# detect_target().name() == "cuda" and int(detect_target()._arch) < 80, -# "Not supported by CUDA < SM80.", -# ) -# class FuseTransposedConvCase(unittest.TestCase): -# def _build_transposedConv2d_bias_relu_chain( -# self, batch, HH, WW, CI, CO, filter_HW, stride, pad, dilate, depth, decomposed -# ): -# X = Tensor( -# shape=[batch, HH, WW, CI], -# dtype="float16", -# name="input_0", -# is_input=True, -# ) -# W = Tensor( -# shape=[CO, filter_HW, filter_HW, CI], -# dtype="float16", -# name="input_1", -# is_input=True, -# ) -# B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) -# if decomposed: -# transposed_conv2d = ops.transposed_conv2d( -# stride=stride, pad=pad, dilate=dilate -# )(X, W) -# if depth == 0: -# return transposed_conv2d - -# transposed_conv2d_bias = ops.elementwise(FuncEnum.ADD)(transposed_conv2d, B) -# else: -# transposed_conv2d_bias = ops.transposed_conv2d_bias( -# stride=stride, pad=pad, dilate=dilate -# )(X, W, B) -# if depth == 0: -# raise RuntimeError("depth == 0 needs to be decomposed.") -# if depth == 1: -# return transposed_conv2d_bias - -# transposed_conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)( -# transposed_conv2d_bias -# ) -# if depth == 2: -# return transposed_conv2d_bias_relu - -# raise RuntimeError(f"depth should be <= 2, unknown depth {depth}") - -# def _test_transposed_conv2d_bias(self, decomposed): -# batch = 4 -# HH, WW, CI, CO = 14, 14, 256, 256 -# filter_HW = 2 -# stride = 2 -# pad = 0 -# dilate = 1 -# transposed_conv2d_bias = self._build_transposedConv2d_bias_relu_chain( -# batch, -# HH, -# WW, -# CI, -# CO, -# filter_HW, -# stride, -# pad, -# dilate, -# 1, -# decomposed=decomposed, -# ) -# transposed_conv2d_bias._attrs["is_output"] = True -# transposed_conv2d_bias._attrs["name"] = "output_0" - -# target = detect_target() -# module = compile_model( -# transposed_conv2d_bias, -# target, -# "./tmp", -# f"fuse_transpose_conv2d_bias_{decomposed}", -# ) - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias") - -# X_pt = torch.randn(batch, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# B_pt = torch.randn(1, CO, 1, 1).cuda().half() -# Y_pt = torch.nn.functional.conv_transpose2d( -# X_pt, W_pt, padding=pad, stride=stride -# ) -# Y_pt = Y_pt + B_pt - -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# y = torch.empty([batch, 28, 28, CO]).cuda().half() -# module.run_with_tensors( -# {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] -# ) -# y_transpose = y.permute((0, 3, 1, 2)) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - -# def test_transposed_conv2d_bias(self): -# self._test_transposed_conv2d_bias(True) -# self._test_transposed_conv2d_bias(False) - -# def _test_transposed_conv2d_bias_relu(self, decomposed): -# batch = 4 -# HH, WW, CI, CO = 14, 14, 256, 256 -# filter_HW = 2 -# stride = 2 -# pad = 0 -# dilate = 1 -# 
transposed_conv2d_bias_relu = self._build_transposedConv2d_bias_relu_chain( -# batch, -# HH, -# WW, -# CI, -# CO, -# filter_HW, -# stride, -# pad, -# dilate, -# 2, -# decomposed=decomposed, -# ) -# transposed_conv2d_bias_relu._attrs["is_output"] = True -# transposed_conv2d_bias_relu._attrs["name"] = "output_0" - -# target = detect_target() -# module = compile_model( -# transposed_conv2d_bias_relu, -# target, -# "./tmp", -# f"fuse_transpose_conv2d_bias_relu_{decomposed}", -# ) - -# check_tensor = None -# for tensor in module.debug_sorted_graph: -# if tensor._attrs["name"] == "output_0": -# check_tensor = tensor -# break -# self.assertIsNotNone(check_tensor) -# self.assertEqual(len(check_tensor.src_ops()), 1) -# src_op = list(check_tensor.src_ops())[0] -# self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias_relu") - -# X_pt = torch.randn(batch, CI, HH, WW).cuda().half() -# W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() -# B_pt = torch.randn(1, CO, 1, 1).cuda().half() -# Y_pt = torch.nn.functional.conv_transpose2d( -# X_pt, W_pt, padding=pad, stride=stride -# ) -# Y_pt = Y_pt + B_pt -# Y_pt = torch.relu(Y_pt) - -# x = X_pt.permute((0, 2, 3, 1)).contiguous() -# w = W_pt.permute((0, 2, 3, 1)).contiguous() -# y = torch.empty([batch, 28, 28, CO]).cuda().half() -# module.run_with_tensors( -# {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] -# ) -# y_transpose = y.permute((0, 3, 1, 2)) -# self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) - -# def test_transposed_conv2d_bias_relu(self): -# self._test_transposed_conv2d_bias_relu(True) -# self._test_transposed_conv2d_bias_relu(False) + def test_conv2d_bias_add_fusion(self): + target = detect_target() + if target.name() == "rocm": + return + + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + R = Tensor( + shape=[batch_dim, HH, WW, CO], + dtype="float16", + name="residual", + is_input=True, + ) + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, False + ) + conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) + conv2d_bias_add._attrs["is_output"] = True + conv2d_bias_add._attrs["name"] = "output_0" + + module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias_add_identity") + + for b in B: + X_pt = torch.randn(b, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + R_pt = torch.randn(b, CO, HH, WW).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=1) + Y_pt = Y_pt + B_pt + R_pt + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + r = R_pt.permute((0, 2, 3, 1)).contiguous() + inputs = { + "input_0": x, + "input_1": w, + "input_2": B_pt.squeeze(), + "residual": r, + } + + y = torch.empty([b, HH, WW, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute(0, 3, 1, 2) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_conv2d_bias_add_do_not_fuse(self): + B = [1] + batch_dim = shape_utils.gen_int_var_min_max(B, 
name="batch_dim") + CO, HH, WW, CI = 256, 28, 28, 128 + filter_HW = 3 + R = Tensor( + shape=[batch_dim, 1, WW, CO], + dtype="float16", + name="residual", + is_input=True, + ) + + conv2d_bias = self._build_conv2d_bias( + batch_dim, CO, HH, WW, CI, filter_HW, False + ) + conv2d_bias_add = ops.elementwise(FuncEnum.ADD)(conv2d_bias, R) + conv2d_bias_add._attrs["is_output"] = True + conv2d_bias_add._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model(conv2d_bias_add, target, "./tmp", "test_conv2d_bias_add") + + graph = module.debug_sorted_graph + + self.assertFalse(graph_has_op(graph, "conv2d_bias_add_identity")) + self.assertTrue(graph_has_op(graph, "conv2d_bias")) + + +@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +class FuseConvBiasFewChannelCase(unittest.TestCase): + def test_conv2d_bias_relu_few_channels(self): + HH, WW, CI, CO, batch = 224, 224, 4, 64, 4 + KK = 7 + stride = 2 + pad = 3 + target = detect_target() + X = Tensor( + shape=[batch, HH, WW, CI], + dtype="float16", + name="input_0", + is_input=True, + ) + W = Tensor( + shape=[CO, KK, KK, CI], dtype="float16", name="input_1", is_input=True + ) + B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) + OP = ops.conv2d_bias_few_channels(stride=stride, pad=pad, dilate=1) + Y = OP(X, W, B) + Y = ops.elementwise(FuncEnum.RELU)(Y) + Y._attrs["name"] = "output_0" + Y._attrs["is_output"] = True + + module = compile_model(Y, target, "./tmp", "test_conv_bias_relu_few_channels") + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "conv2d_bias_relu_few_channels") + + X_pt = torch.randn(batch, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, KK, KK).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=pad, stride=stride) + Y_pt = Y_pt + B_pt + Y_pt = torch.nn.functional.relu(Y_pt) + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()} + y = torch.empty([batch, HH // stride, WW // stride, CO]).cuda().half() + module.run_with_tensors(inputs, [y]) + y_transpose = y.permute((0, 3, 1, 2)) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2)) + + +@unittest.skipIf(detect_target().name() == "rocm", "Not supported by ROCM.") +@unittest.skipIf( + detect_target().name() == "cuda" and int(detect_target()._arch) < 80, + "Not supported by CUDA < SM80.", +) +class FuseTransposedConvCase(unittest.TestCase): + def _build_transposedConv2d_bias_relu_chain( + self, batch, HH, WW, CI, CO, filter_HW, stride, pad, dilate, depth, decomposed + ): + X = Tensor( + shape=[batch, HH, WW, CI], + dtype="float16", + name="input_0", + is_input=True, + ) + W = Tensor( + shape=[CO, filter_HW, filter_HW, CI], + dtype="float16", + name="input_1", + is_input=True, + ) + B = Tensor(shape=[CO], dtype="float16", name="input_2", is_input=True) + if decomposed: + transposed_conv2d = ops.transposed_conv2d( + stride=stride, pad=pad, dilate=dilate + )(X, W) + if depth == 0: + return transposed_conv2d + + transposed_conv2d_bias = ops.elementwise(FuncEnum.ADD)(transposed_conv2d, B) + else: + transposed_conv2d_bias = ops.transposed_conv2d_bias( + stride=stride, pad=pad, 
dilate=dilate + )(X, W, B) + if depth == 0: + raise RuntimeError("depth == 0 needs to be decomposed.") + if depth == 1: + return transposed_conv2d_bias + + transposed_conv2d_bias_relu = ops.elementwise(FuncEnum.RELU)( + transposed_conv2d_bias + ) + if depth == 2: + return transposed_conv2d_bias_relu + + raise RuntimeError(f"depth should be <= 2, unknown depth {depth}") + + def _test_transposed_conv2d_bias(self, decomposed): + batch = 4 + HH, WW, CI, CO = 14, 14, 256, 256 + filter_HW = 2 + stride = 2 + pad = 0 + dilate = 1 + transposed_conv2d_bias = self._build_transposedConv2d_bias_relu_chain( + batch, + HH, + WW, + CI, + CO, + filter_HW, + stride, + pad, + dilate, + 1, + decomposed=decomposed, + ) + transposed_conv2d_bias._attrs["is_output"] = True + transposed_conv2d_bias._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + transposed_conv2d_bias, + target, + "./tmp", + f"fuse_transpose_conv2d_bias_{decomposed}", + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias") + + X_pt = torch.randn(batch, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv_transpose2d( + X_pt, W_pt, padding=pad, stride=stride + ) + Y_pt = Y_pt + B_pt + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + y = torch.empty([batch, 28, 28, CO]).cuda().half() + module.run_with_tensors( + {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}, [y] + ) + y_transpose = y.permute((0, 3, 1, 2)) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_transposed_conv2d_bias(self): + self._test_transposed_conv2d_bias(True) + self._test_transposed_conv2d_bias(False) + + def _test_transposed_conv2d_bias_relu(self, decomposed): + batch = 4 + HH, WW, CI, CO = 14, 14, 256, 256 + filter_HW = 2 + stride = 2 + pad = 0 + dilate = 1 + transposed_conv2d_bias_relu = self._build_transposedConv2d_bias_relu_chain( + batch, + HH, + WW, + CI, + CO, + filter_HW, + stride, + pad, + dilate, + 2, + decomposed=decomposed, + ) + transposed_conv2d_bias_relu._attrs["is_output"] = True + transposed_conv2d_bias_relu._attrs["name"] = "output_0" + + target = detect_target() + module = compile_model( + transposed_conv2d_bias_relu, + target, + "./tmp", + f"fuse_transpose_conv2d_bias_relu_{decomposed}", + ) + + check_tensor = None + for tensor in module.debug_sorted_graph: + if tensor._attrs["name"] == "output_0": + check_tensor = tensor + break + self.assertIsNotNone(check_tensor) + self.assertEqual(len(check_tensor.src_ops()), 1) + src_op = list(check_tensor.src_ops())[0] + self.assertEqual(src_op._attrs["op"], "transposed_conv2d_bias_relu") + + X_pt = torch.randn(batch, CI, HH, WW).cuda().half() + W_pt = torch.randn(CO, CI, filter_HW, filter_HW).cuda().half() + B_pt = torch.randn(1, CO, 1, 1).cuda().half() + Y_pt = torch.nn.functional.conv_transpose2d( + X_pt, W_pt, padding=pad, stride=stride + ) + Y_pt = Y_pt + B_pt + Y_pt = torch.relu(Y_pt) + + x = X_pt.permute((0, 2, 3, 1)).contiguous() + w = W_pt.permute((0, 2, 3, 1)).contiguous() + y = torch.empty([batch, 28, 28, CO]).cuda().half() + module.run_with_tensors( + {"input_0": x, "input_1": w, 
"input_2": B_pt.squeeze()}, [y] + ) + y_transpose = y.permute((0, 3, 1, 2)) + self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + + def test_transposed_conv2d_bias_relu(self): + self._test_transposed_conv2d_bias_relu(True) + self._test_transposed_conv2d_bias_relu(False) if __name__ == "__main__": From eb6b555eb61065ead45b71dc6ba8afe498b0fe08 Mon Sep 17 00:00:00 2001 From: Yanxing-Shi Date: Sat, 3 Dec 2022 14:40:56 +0000 Subject: [PATCH 3/3] add unnitest skip cuda --- tests/unittest/compiler/test_fuse_conv_elementwise.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unittest/compiler/test_fuse_conv_elementwise.py b/tests/unittest/compiler/test_fuse_conv_elementwise.py index c120e1508..69e7601ed 100644 --- a/tests/unittest/compiler/test_fuse_conv_elementwise.py +++ b/tests/unittest/compiler/test_fuse_conv_elementwise.py @@ -381,6 +381,7 @@ def test_conv2d_bias_sigmoid(self): y_transpose = y.permute(0, 3, 1, 2) self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-1, rtol=1e-1)) + @unittest.skipIf(detect_target().name() == "cuda", "Not supported by CUDA.") def test_conv2d_bias_silu(self): B = [1] batch_dim = shape_utils.gen_int_var_min_max(B, name="batch_dim")