Add conv2d+bias+silu fusion API for rocm backend #25

Open · wants to merge 4 commits into base: amd-develop
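This PR adds a fused conv2d + bias + SiLU op to the ROCm backend: a Composable Kernel AddSiLU epilogue, a new `aitemplate.compiler.ops.conv2d_bias_silu` front-end op, and a graph-fusion pattern. A minimal usage sketch (shapes and tensor names are illustrative, not taken from this PR, and assume a ROCm build with this change applied):

```python
from aitemplate.compiler import ops
from aitemplate.frontend import Tensor

# NHWC ("channels_last") layout, as used by the AITemplate conv2d ops.
X = Tensor(shape=[8, 56, 56, 64], dtype="float16", name="input", is_input=True)
W = Tensor(shape=[128, 3, 3, 64], dtype="float16", name="weight", is_input=True)
B = Tensor(shape=[128], dtype="float16", name="bias", is_input=True)

# One fused kernel computing silu(conv2d(X, W) + B).
Y = ops.conv2d_bias_silu(stride=1, pad=1, dilate=1)(X, W, B)
```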
2 changes: 2 additions & 0 deletions python/aitemplate/backend/rocm/conv2d/__init__.py
@@ -22,6 +22,7 @@
conv2d_bias_add_relu,
conv2d_bias_relu,
conv2d_bias_sigmoid,
+    conv2d_bias_silu,
transposed_conv2d,
transposed_conv2d_bias_relu,
)
@@ -33,6 +34,7 @@
"conv2d_bias_add_relu",
"conv2d_bias_relu",
"conv2d_bias_sigmoid",
"conv2d_bias_silu",
"transposed_conv2d",
"transposed_conv2d_bias_relu",
]
6 changes: 4 additions & 2 deletions python/aitemplate/backend/rocm/conv2d/common.py
@@ -40,7 +40,7 @@

{% if conv2d_flag == "" %}
{{indent}} {},
{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid"] %}
{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid", "bias_silu"] %}
{{indent}} std::array<const void*, 1>{static_cast<ck::half_t *>(bias_ptr)},
{% elif conv2d_flag in ["bias_add_relu", "bias_add_identity"] %}
{{indent}} std::array<const void*, 2>{static_cast<ck::half_t *>(bias_ptr), static_cast<ck::half_t *>(res_ptr)},
@@ -52,7 +52,7 @@
{{indent}} b_g_k_c_xs_strides,
{% if conv2d_flag == "" %}
{{indent}} {}, {},
{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid"] %}
{% elif conv2d_flag in ["bias", "bias_relu", "bias_sigmoid", "bias_silu"] %}
{{indent}} std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{ {d_g_n_k_wos_lengths} },
{{indent}} std::array<std::array<ck::index_t, NDimSpatial + 3>, 1>{ {d_g_n_k_wos_strides} },
{% elif conv2d_flag in ["bias_add_relu", "bias_add_identity"] %}
@@ -75,6 +75,8 @@
{{indent}} ck::tensor_operation::element_wise::AddRelu{}
{% elif conv2d_flag == "bias_sigmoid" %}
{{indent}} ck::tensor_operation::element_wise::AddSigmoid{}
{% elif conv2d_flag == "bias_silu" %}
{{indent}} ck::tensor_operation::element_wise::AddSiLU{}
{% elif conv2d_flag == "bias_add_identity" %}
{{indent}} ck::tensor_operation::element_wise::AddAdd{}
{% elif conv2d_flag == "bias_add_relu" %}
26 changes: 24 additions & 2 deletions python/aitemplate/backend/rocm/conv2d/conv2d_bias_sigmoid.py
@@ -35,8 +35,30 @@
struct AddSigmoid
{
template <typename T>
-    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const{
-        ck::tensor_operation::element_wise::Sigmoid{}(y, x0 + x1);
+    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<float>(float& y, const float& x0, const float& x1) const
+    {
+        const float a = x0 + x1;
+        y = 1.0f / (1.0f + exp(-a));
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<double>(double& y, const double& x0, const double& x1) const
+    {
+        const double a = x0 + x1;
+        y = 1.0 / (1.0 + exp(-a));
+    };
+
+    template <>
+    __host__ __device__ constexpr void
+    operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
+    {
+        const half_t a = x0 + x1;
+        y = type_convert<half_t>(1.0) / (type_convert<half_t>(1.0) + type_convert<half_t>(exp(ck::type_convert<float>(-a))));
+    };
};
} // namespace
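A note on the half_t path (an illustrative PyTorch sketch, not part of the diff): exp(-a) is evaluated in fp32 and cast back to fp16, while the add and divide stay in fp16:

```python
import torch

x0 = torch.randn(4, 8, dtype=torch.float16)
x1 = torch.randn(4, 8, dtype=torch.float16)
a = x0 + x1
e = torch.exp(-a.float()).half()  # exp evaluated in fp32, cast back to fp16
y = 1.0 / (1.0 + e)               # add and divide stay in fp16
assert torch.allclose(y, torch.sigmoid(a.float()).half(), atol=1e-3)
```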
212 changes: 212 additions & 0 deletions python/aitemplate/backend/rocm/conv2d/conv2d_bias_silu.py
@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
ROCM codegen functions for conv2d_bias_silu.
"""
import jinja2

from ... import registry
from . import common

# pylint: disable=C0103,C0415,W0613

EXTRA_CODE = jinja2.Template(
    """
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d_xdl_cshuffle.hpp"

#include "ck/utility/data_type.hpp"

namespace ck {
namespace tensor_operation {
namespace element_wise {
namespace {
struct AddSiLU
{
    template <typename T>
    __host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;

    template <>
    __host__ __device__ constexpr void
    operator()<float>(float& y, const float& x0, const float& x1) const
    {
        const float a = x0 + x1;
        y = 1.0f / (1.0f + exp(-a)) * a;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<double>(double& y, const double& x0, const double& x1) const
    {
        const double a = x0 + x1;
        y = 1.0 / (1.0 + exp(-a)) * a;
    };

    template <>
    __host__ __device__ constexpr void
    operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
    {
        const half_t a = x0 + x1;
        y = type_convert<half_t>(1.0) / (type_convert<half_t>(1.0) + type_convert<half_t>(exp(ck::type_convert<float>(-a)))) * a;
    };
};
} // namespace
} // namespace element_wise
} // namespace tensor_operation
} // namespace ck
"""
)
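For reference, the AddSiLU epilogue computes y = a * sigmoid(a) with a = x0 + x1, i.e. SiLU applied to the conv output plus bias. A quick PyTorch cross-check of the formula used above (a sketch, not part of the diff):

```python
import torch

a = torch.randn(4, 8)
# The epilogue's formula: 1 / (1 + exp(-a)) * a, which is silu(a).
y = 1.0 / (1.0 + torch.exp(-a)) * a
assert torch.allclose(y, torch.nn.functional.silu(a))
```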


@registry.reg("rocm.conv2d_bias_silu.config")
def conv2d_config(func_attrs):
"""Extracts (operation name, operation instance) pair from
all operation candidates.


Parameters
----------
func_attrs : Dict
Operation attributes.

Returns
-------
Dict
Extracted (operation name, operation instance) pair
from all operation candidates.
"""
import ck_lib

op_kind = ck_lib.library.Conv2dKind.GroupConv2dBiasRelu
extra_kind = ck_lib.library.TensorOperation.AddSiLU
func_attrs["op_instance"] = common.extract_config(op_kind, extra_kind)


@registry.reg("rocm.conv2d_bias_silu.gen_profiler")
def conv2d_gen_profiler(func_attrs, workdir, shape_template):
"""Generates standalone executables for profiler.

Parameters
----------
func_attrs : Dict
Operation attributes.
workdir : str
Directory to store the generated outputs.
shape_template : jinja2.Template
Generates shape calculation.
The template is passed from compiler/ops/pool.
"""
return common.gen_profiler(
func_attrs=func_attrs,
workdir=workdir,
shape_template=shape_template,
conv2d_flag="bias_silu",
extra_code=EXTRA_CODE.render(),
)


@registry.reg("rocm.conv2d_bias_silu.gen_function")
def conv2d_gen_function(
func_attrs,
exec_cond_remplate,
shape_eval_template,
shape_save_template,
):
"""Generates function body.

Parameters
----------
func_attrs : Dict
Operation attributes.
exec_cond_remplate : jinja2.Template
Generates if statement to execute kernel.
shape_eval_template : jinja2.Template
Generates shape calculation.
The template is passed from compiler/ops/pool.
shape_save_template : jinja2.Template
Generates output dimensions.
The template is passed from compiler/ops/pool.

Returns
-------
str
The rendered template of generated function body.
"""
return common.gen_function(
func_attrs,
exec_cond_remplate,
shape_eval_template,
shape_save_template,
"bias_silu",
extra_code=EXTRA_CODE.render(),
)


@registry.reg("rocm.conv2d_bias_silu.func_decl")
def conv2d_gen_function_decl(func_attrs):
"""Generates function declarations.

Parameters
----------
func_attrs : Dict
Operation attributes.

Returns
-------
str
The rentered template of function declaration.
"""
func_name = func_attrs["name"]
return common.gen_function_decl(func_name=func_name, conv2d_flag="bias_sigmoid")


@registry.reg("rocm.conv2d_bias_silu.func_call")
def conv2d_gen_function_call(func_attrs, indent=" "):
"""Generates function call.

Parameters
----------
func_attrs : Dict
Stores the operation attributes.
indent : str, optional
Indent for codegen, target dependent e.g. C++, python, etc., by default " ".

Returns
-------
str
The rendered template of generated function call.
"""
return common.gen_function_call(func_attrs, indent, conv2d_flag="bias_silu")


@registry.reg("rocm.conv2d_bias_silu.filter")
def conv2d_function_filter(cfg, func_attrs, x_shape):
"""Generates function filter.

Parameters
----------
cfg: str
The filename generated for profiler.
func_attrs : Dict
Stores the operation attributes.
x_shape:
Input shapes.

Returns
-------
bool
If input cfg should be filtered.
"""
# return common.function_filter(cfg, func_attrs, 3)
return True
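These registry entries are how the compiler reaches the backend; a sketch of the lookup (assumes AITemplate's `registry.get`, the counterpart of the `registry.reg` decorator used above):

```python
from aitemplate.backend import registry

# Each @registry.reg("rocm.conv2d_bias_silu.<name>") above is resolved
# by key at codegen time, e.g.:
gen_function = registry.get("rocm.conv2d_bias_silu.gen_function")
```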
1 change: 1 addition & 0 deletions python/aitemplate/compiler/ops/conv/__init__.py
@@ -27,6 +27,7 @@
from .conv2d_bias_relu import conv2d_bias_relu
from .conv2d_bias_relu_few_channels import conv2d_bias_relu_few_channels
from .conv2d_bias_sigmoid import conv2d_bias_sigmoid
+from .conv2d_bias_silu import conv2d_bias_silu
from .conv2d_depthwise import conv2d_depthwise
from .conv2d_depthwise_bias import conv2d_depthwise_bias
from .conv3d import conv3d
77 changes: 77 additions & 0 deletions python/aitemplate/compiler/ops/conv/conv2d_bias_silu.py
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
Fused conv2d_bias_silu op.
"""
from .common_conv2d_bias_activation import conv2d_bias_activation


# pylint: disable=C0103
class conv2d_bias_silu(conv2d_bias_activation):
    r"""Conv2d with bias + silu.

    Applies a 2D convolution on input in shape (N, H, W, C_in), adds a bias in shape (C_out), applies SiLU and produces output in shape (N, H_out, W_out, C_out). N is the batch size, H and W are the height and width of the input images in pixels, and C is the number of channels.

    Args:
        input: input tensor of shape :math:`(N , H , W, \text{in\_channels})`

        weight: filters of shape :math:`(\text{out\_channels} , K_h, K_w, \frac{\text{in\_channels}}{\text{groups}})`

        bias: optional bias tensor of shape :math:`(\text{out\_channels})`

    This operator uses "channels_last" data format. Below is an example and its equivalence in PyTorch:

    .. highlight:: python
    .. code-block:: python

        X = Tensor(shape=[N, H, W, C_in], dtype="float16", name="images", is_input=True)
        W = Tensor(shape=[C_out, K_h, K_w, C_in], dtype="float16", name="weight", is_input=True)
        B = Tensor(shape=[C_out], dtype="float16", name="bias", is_input=True)
        OP = aitemplate.compiler.ops.conv2d_bias_silu(stride=1, pad=1, dilate=1)
        Result_ait = OP(X, W, B)

    .. highlight:: python
    .. code-block:: python

        X_pt = NHWC2NCHW(X_ait)
        W_pt = NHWC2NCHW(W_ait)
        B_pt = NHWC2NCHW(B_ait)

        Y = torch.nn.functional.conv2d(X_pt, W_pt, bias=B_pt)
        Result_pt = torch.nn.functional.silu(Y)
        Result_ait = NCHW2NHWC(Result_pt)
    """

    def __init__(self, stride, pad, dilate=1, group=1) -> None:
        """Conv2d_bias_silu constructor.

        Parameters
        ----------
        stride : int
            Stride of the convolution
        pad : int
            Size of padding to add to the input
        dilate : int, optional
            Size of spacing between kernel elements, by default 1
        group : int, optional
            Number of input channels to process to compute one output channel, by default 1
        """
        super().__init__("silu", stride, pad, dilate=dilate, group=group)

    def _get_op_attributes(self):
        # The activation is fixed to "silu" by this op, so it is not a
        # constructor argument and is dropped from the attribute dict.
        attr = super()._get_op_attributes()
        del attr["activation"]

        return attr
8 changes: 8 additions & 0 deletions python/aitemplate/compiler/transform/fuse_conv_patterns.py
@@ -23,6 +23,7 @@
conv2d_bias_relu,
conv2d_bias_relu_few_channels,
conv2d_bias_sigmoid,
+    conv2d_bias_silu,
transposed_conv2d,
transposed_conv2d_bias,
transposed_conv2d_bias_relu,
@@ -67,6 +68,13 @@ def get_conv2d_bias_elementwise_patterns():
),
conv2d_bias_sigmoid,
),
+        (
+            (
+                conv2d_bias(stride=1, pad=0),
+                elementwise(FuncEnum.SILU),
+            ),
+            conv2d_bias_silu,
+        ),
(
(
conv2d_bias(stride=1, pad=0),
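With the pattern registered, a conv2d_bias followed by an elementwise SiLU collapses into one conv2d_bias_silu during the fusion pass. An illustrative sketch (shapes and names assumed, not taken from this PR):

```python
from aitemplate.compiler import ops
from aitemplate.compiler.ops.common.epilogue import FuncEnum
from aitemplate.frontend import Tensor

X = Tensor(shape=[8, 56, 56, 64], dtype="float16", name="input", is_input=True)
W = Tensor(shape=[128, 3, 3, 64], dtype="float16", name="weight", is_input=True)
B = Tensor(shape=[128], dtype="float16", name="bias", is_input=True)

# Written as two ops; the fuse pass rewrites this into a single
# conv2d_bias_silu, matching the (conv2d_bias, elementwise(SILU)) pattern.
Y = ops.conv2d_bias(stride=1, pad=0)(X, W, B)
Z = ops.elementwise(FuncEnum.SILU)(Y)
```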