From b833e4569b9ade9214a85c19fef315b857d34f10 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 29 Sep 2024 10:41:40 -0500 Subject: [PATCH 01/10] Add rewrite_dot --- src/CMakeLists.txt | 1 + src/include/migraphx/rewrite_dot.hpp | 21 ++++++ src/rewrite_dot.cpp | 97 ++++++++++++++++++++++++++++ src/targets/gpu/target.cpp | 2 + 4 files changed, 121 insertions(+) create mode 100644 src/include/migraphx/rewrite_dot.hpp create mode 100644 src/rewrite_dot.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1ffe2eb9f4e..c1d744e7413 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -97,6 +97,7 @@ add_library(migraphx register_target.cpp replace_allocate.cpp rewrite_reduce.cpp + rewrite_dot.cpp simplify_qdq.cpp split_reduce.cpp sqlite.cpp diff --git a/src/include/migraphx/rewrite_dot.hpp b/src/include/migraphx/rewrite_dot.hpp new file mode 100644 index 00000000000..bf09ac3dcee --- /dev/null +++ b/src/include/migraphx/rewrite_dot.hpp @@ -0,0 +1,21 @@ +#ifndef MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP +#define MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP + +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +struct module; + +struct MIGRAPHX_EXPORT rewrite_dot +{ + std::string name() const { return "rewrite_dot"; } + void apply(module& m) const; +}; + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx +#endif // MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP + diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp new file mode 100644 index 00000000000..ccff2a7ded1 --- /dev/null +++ b/src/rewrite_dot.cpp @@ -0,0 +1,97 @@ +/* + * The MIT License (MIT) + * + * Copyright (c) 2015-2024 Advanced Micro Devices, Inc. All rights reserved. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + */ +#include +#include +#include +#include +#include + +namespace migraphx { +inline namespace MIGRAPHX_INLINE_NS { + +namespace { + +MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) +{ + if(ins->name() != "convolution") + return false; + auto v = ins->get_operator().to_value(); + if (not all_of(v.at("stride"), [](const value& x) { + return x.to() == 1; + })) + return false; + if (not all_of(v.at("padding"), [](const value& x) { + return x.to() == 0; + })) + return false; + if (not all_of(v.at("dilation"), [](const value& x) { + return x.to() == 1; + })) + return false; + auto w = ins->inputs().at(1)->get_shape(); + return std::all_of(w.lens().begin() + 2, w.lens().end(), [](std::size_t i) { return i == 1; }); +} + +struct find_1x1_convolution +{ + auto matcher() const { return conv_1x1(); } + + void apply(module& m, const match::matcher_result& r) const + { + auto ins = r.result; + + auto input = ins->inputs().front(); + auto weights = ins->inputs().back(); + auto m_dim = std::accumulate(input->get_shape().lens().begin()+2, input->get_shape().lens().end(), input->get_shape().lens().front(), std::multiplies<>{}); + auto n_dim = weights->get_shape().lens()[0]; + auto k_dim = weights->get_shape().lens()[1]; + + std::vector aperm(ins->get_shape().ndim()); + std::iota(aperm.begin(), aperm.end(), 0); + std::rotate(aperm.begin()+1, aperm.begin()+2, aperm.end()); + auto transpose = m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); + auto a_mat = m.insert_instruction(ins, make_op("reshape", {{"dims", {m_dim, k_dim}}}), transpose); + + auto reshape = m.insert_instruction(ins, make_op("reshape", {{"dims", {n_dim, k_dim}}}), weights); + auto b_mat = m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), reshape); + + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + auto out_dims = transpose->get_shape().lens(); + out_dims.back() = n_dim; + auto creshape = m.insert_instruction(ins, make_op("reshape", {{"dims", out_dims}}), dot); + m.replace_instruction(ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), creshape); + } +}; + + +} // namespace + +void rewrite_dot::apply(module& m) const +{ + match::find_matches(m, find_1x1_convolution{}); +} + +} // namespace MIGRAPHX_INLINE_NS +} // namespace migraphx diff --git a/src/targets/gpu/target.cpp b/src/targets/gpu/target.cpp index b190d11a402..2628c6f8dbd 100644 --- a/src/targets/gpu/target.cpp +++ b/src/targets/gpu/target.cpp @@ -43,6 +43,7 @@ #include #include #include +#include #include #include #include @@ -190,6 +191,7 @@ std::vector target::get_passes(migraphx::context& gctx, const compile_opti dead_code_elimination{}, rewrite_reduce{}, rewrite_low_precision{}, + rewrite_dot{}, dead_code_elimination{}, optimize_module{}, fuse_pointwise_reduce{}, From 3acd4ea60d42f5a6aa104385acaf12f5f58de7b1 Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 29 Sep 2024 10:44:46 -0500 Subject: [PATCH 02/10] Format --- src/include/migraphx/rewrite_dot.hpp | 1 - src/rewrite_dot.cpp | 56 ++++++++++++++-------------- 2 files changed, 27 insertions(+), 30 deletions(-) diff --git a/src/include/migraphx/rewrite_dot.hpp b/src/include/migraphx/rewrite_dot.hpp index bf09ac3dcee..f756cdf62b0 100644 --- a/src/include/migraphx/rewrite_dot.hpp +++ b/src/include/migraphx/rewrite_dot.hpp @@ -18,4 +18,3 @@ struct MIGRAPHX_EXPORT rewrite_dot } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx #endif // MIGRAPHX_GUARD_MIGRAPHX_REWRITE_DOT_HPP - diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index ccff2a7ded1..9f3d4b2b300 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -38,17 +38,11 @@ MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) if(ins->name() != "convolution") return false; auto v = ins->get_operator().to_value(); - if (not all_of(v.at("stride"), [](const value& x) { - return x.to() == 1; - })) + if(not all_of(v.at("stride"), [](const value& x) { return x.to() == 1; })) return false; - if (not all_of(v.at("padding"), [](const value& x) { - return x.to() == 0; - })) + if(not all_of(v.at("padding"), [](const value& x) { return x.to() == 0; })) return false; - if (not all_of(v.at("dilation"), [](const value& x) { - return x.to() == 1; - })) + if(not all_of(v.at("dilation"), [](const value& x) { return x.to() == 1; })) return false; auto w = ins->inputs().at(1)->get_shape(); return std::all_of(w.lens().begin() + 2, w.lens().end(), [](std::size_t i) { return i == 1; }); @@ -60,38 +54,42 @@ struct find_1x1_convolution void apply(module& m, const match::matcher_result& r) const { - auto ins = r.result; - - auto input = ins->inputs().front(); + auto ins = r.result; + + auto input = ins->inputs().front(); auto weights = ins->inputs().back(); - auto m_dim = std::accumulate(input->get_shape().lens().begin()+2, input->get_shape().lens().end(), input->get_shape().lens().front(), std::multiplies<>{}); - auto n_dim = weights->get_shape().lens()[0]; - auto k_dim = weights->get_shape().lens()[1]; + auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, + input->get_shape().lens().end(), + input->get_shape().lens().front(), + std::multiplies<>{}); + auto n_dim = weights->get_shape().lens()[0]; + auto k_dim = weights->get_shape().lens()[1]; std::vector aperm(ins->get_shape().ndim()); std::iota(aperm.begin(), aperm.end(), 0); - std::rotate(aperm.begin()+1, aperm.begin()+2, aperm.end()); - auto transpose = m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); - auto a_mat = m.insert_instruction(ins, make_op("reshape", {{"dims", {m_dim, k_dim}}}), transpose); + std::rotate(aperm.begin() + 1, aperm.begin() + 2, aperm.end()); + auto transpose = + m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); + auto a_mat = + m.insert_instruction(ins, make_op("reshape", {{"dims", {m_dim, k_dim}}}), transpose); - auto reshape = m.insert_instruction(ins, make_op("reshape", {{"dims", {n_dim, k_dim}}}), weights); - auto b_mat = m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), reshape); + auto reshape = + m.insert_instruction(ins, make_op("reshape", {{"dims", {n_dim, k_dim}}}), weights); + auto b_mat = + m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), reshape); - auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); - auto out_dims = transpose->get_shape().lens(); + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + auto out_dims = transpose->get_shape().lens(); out_dims.back() = n_dim; - auto creshape = m.insert_instruction(ins, make_op("reshape", {{"dims", out_dims}}), dot); - m.replace_instruction(ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), creshape); + auto creshape = m.insert_instruction(ins, make_op("reshape", {{"dims", out_dims}}), dot); + m.replace_instruction( + ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), creshape); } }; - } // namespace -void rewrite_dot::apply(module& m) const -{ - match::find_matches(m, find_1x1_convolution{}); -} +void rewrite_dot::apply(module& m) const { match::find_matches(m, find_1x1_convolution{}); } } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx From baaa8e10f0b53804f440331e97069e5d9c2a497a Mon Sep 17 00:00:00 2001 From: Paul Date: Sun, 29 Sep 2024 11:22:22 -0500 Subject: [PATCH 03/10] Add env var to enable --- src/rewrite_dot.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index 9f3d4b2b300..b4a6cac40c9 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -31,6 +31,8 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT); + namespace { MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) @@ -89,7 +91,12 @@ struct find_1x1_convolution } // namespace -void rewrite_dot::apply(module& m) const { match::find_matches(m, find_1x1_convolution{}); } +void rewrite_dot::apply(module& m) const +{ + if(not enabled(MIGRAPHX_ENABLE_REWRITE_DOT{})) + return; + match::find_matches(m, find_1x1_convolution{}); +} } // namespace MIGRAPHX_INLINE_NS } // namespace migraphx From 47048e88bfed77ea32b8a5b4479ee6ef295394b6 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 8 Oct 2024 10:25:57 -0500 Subject: [PATCH 04/10] Use batch gemm --- src/rewrite_dot.cpp | 29 ++++++++++++++++------------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index b4a6cac40c9..f5958c427fd 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -52,7 +52,9 @@ MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) struct find_1x1_convolution { - auto matcher() const { return conv_1x1(); } + bool channels_last = true; + + auto matcher() const { return conv_1x1(match::arg(1)(match::is_constant())); } void apply(module& m, const match::matcher_result& r) const { @@ -70,22 +72,22 @@ struct find_1x1_convolution std::vector aperm(ins->get_shape().ndim()); std::iota(aperm.begin(), aperm.end(), 0); std::rotate(aperm.begin() + 1, aperm.begin() + 2, aperm.end()); - auto transpose = - m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); auto a_mat = - m.insert_instruction(ins, make_op("reshape", {{"dims", {m_dim, k_dim}}}), transpose); + m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); - auto reshape = - m.insert_instruction(ins, make_op("reshape", {{"dims", {n_dim, k_dim}}}), weights); - auto b_mat = - m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), reshape); + std::vector sq_axes(ins->get_shape().ndim() - 2); + std::iota(sq_axes.begin(), sq_axes.end(), 2); + auto squeeze = + m.insert_instruction(ins, make_op("squeeze", {{"axes", sq_axes}}), weights); + auto transpose = + m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), squeeze); + auto b_lens = a_mat->get_shape().lens(); + copy(transpose->get_shape().lens(), b_lens.end() - 2); + auto b_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); - auto out_dims = transpose->get_shape().lens(); - out_dims.back() = n_dim; - auto creshape = m.insert_instruction(ins, make_op("reshape", {{"dims", out_dims}}), dot); m.replace_instruction( - ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), creshape); + ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), dot); } }; @@ -95,7 +97,8 @@ void rewrite_dot::apply(module& m) const { if(not enabled(MIGRAPHX_ENABLE_REWRITE_DOT{})) return; - match::find_matches(m, find_1x1_convolution{}); + match::find_matches(m, find_1x1_convolution{}); + // m.debug_print(); } } // namespace MIGRAPHX_INLINE_NS From d54ae5da187c2908ea68b80e20a32a5d51139a73 Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 8 Oct 2024 10:55:12 -0500 Subject: [PATCH 05/10] Support nchw layout as well --- src/rewrite_dot.cpp | 60 ++++++++++++++++++++++++++------------------- 1 file changed, 35 insertions(+), 25 deletions(-) diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index f5958c427fd..1ef87ef5070 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -52,8 +52,6 @@ MIGRAPHX_PRED_MATCHER(conv_1x1, instruction_ref ins) struct find_1x1_convolution { - bool channels_last = true; - auto matcher() const { return conv_1x1(match::arg(1)(match::is_constant())); } void apply(module& m, const match::matcher_result& r) const @@ -62,32 +60,45 @@ struct find_1x1_convolution auto input = ins->inputs().front(); auto weights = ins->inputs().back(); - auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, - input->get_shape().lens().end(), - input->get_shape().lens().front(), - std::multiplies<>{}); - auto n_dim = weights->get_shape().lens()[0]; - auto k_dim = weights->get_shape().lens()[1]; - - std::vector aperm(ins->get_shape().ndim()); - std::iota(aperm.begin(), aperm.end(), 0); - std::rotate(aperm.begin() + 1, aperm.begin() + 2, aperm.end()); - auto a_mat = - m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); - + std::vector sq_axes(ins->get_shape().ndim() - 2); std::iota(sq_axes.begin(), sq_axes.end(), 2); - auto squeeze = + auto sq_weights = m.insert_instruction(ins, make_op("squeeze", {{"axes", sq_axes}}), weights); - auto transpose = - m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), squeeze); - auto b_lens = a_mat->get_shape().lens(); - copy(transpose->get_shape().lens(), b_lens.end() - 2); - auto b_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); - auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); - m.replace_instruction( - ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), dot); + if(ins->get_shape().transposed()) + { + std::vector aperm(ins->get_shape().ndim()); + std::iota(aperm.begin(), aperm.end(), 0); + std::rotate(aperm.begin() + 1, aperm.begin() + 2, aperm.end()); + auto a_mat = + m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); + + auto transpose = + m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), sq_weights); + auto b_lens = a_mat->get_shape().lens(); + copy(transpose->get_shape().lens(), b_lens.end() - 2); + auto b_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); + + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + m.replace_instruction( + ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), dot); + } + else + { + auto batch_dim = ins->get_shape().lens().front(); + auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, + input->get_shape().lens().end(), + 1, + std::multiplies<>{}); + auto n_dim = weights->get_shape().lens()[0]; + auto k_dim = weights->get_shape().lens()[1]; + auto a_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", {batch_dim, n_dim, k_dim}}}), sq_weights); + auto b_mat = m.insert_instruction(ins, make_op("reshape", {{"dims", {batch_dim, k_dim, m_dim}}}), input); + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), dot); + + } } }; @@ -98,7 +109,6 @@ void rewrite_dot::apply(module& m) const if(not enabled(MIGRAPHX_ENABLE_REWRITE_DOT{})) return; match::find_matches(m, find_1x1_convolution{}); - // m.debug_print(); } } // namespace MIGRAPHX_INLINE_NS From 2f6e0525cfed4327bc3234b060002f6ae3146c3a Mon Sep 17 00:00:00 2001 From: Paul Date: Tue, 8 Oct 2024 10:55:31 -0500 Subject: [PATCH 06/10] Format --- src/rewrite_dot.cpp | 35 ++++++++++++++++++++--------------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index 1ef87ef5070..1e4d76e378d 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -60,7 +60,7 @@ struct find_1x1_convolution auto input = ins->inputs().front(); auto weights = ins->inputs().back(); - + std::vector sq_axes(ins->get_shape().ndim() - 2); std::iota(sq_axes.begin(), sq_axes.end(), 2); auto sq_weights = @@ -74,37 +74,42 @@ struct find_1x1_convolution auto a_mat = m.insert_instruction(ins, make_op("transpose", {{"permutation", aperm}}), input); - auto transpose = - m.insert_instruction(ins, make_op("transpose", {{"permutation", {1, 0}}}), sq_weights); - auto b_lens = a_mat->get_shape().lens(); + auto transpose = m.insert_instruction( + ins, make_op("transpose", {{"permutation", {1, 0}}}), sq_weights); + auto b_lens = a_mat->get_shape().lens(); copy(transpose->get_shape().lens(), b_lens.end() - 2); - auto b_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); + auto b_mat = m.insert_instruction( + ins, make_op("multibroadcast", {{"out_lens", b_lens}}), transpose); - auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); m.replace_instruction( ins, make_op("transpose", {{"permutation", invert_permutation(aperm)}}), dot); } else { auto batch_dim = ins->get_shape().lens().front(); - auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, + auto m_dim = std::accumulate(input->get_shape().lens().begin() + 2, input->get_shape().lens().end(), 1, std::multiplies<>{}); - auto n_dim = weights->get_shape().lens()[0]; - auto k_dim = weights->get_shape().lens()[1]; - auto a_mat = m.insert_instruction(ins, make_op("multibroadcast", {{"out_lens", {batch_dim, n_dim, k_dim}}}), sq_weights); - auto b_mat = m.insert_instruction(ins, make_op("reshape", {{"dims", {batch_dim, k_dim, m_dim}}}), input); - auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); - m.replace_instruction(ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), dot); - + auto n_dim = weights->get_shape().lens()[0]; + auto k_dim = weights->get_shape().lens()[1]; + auto a_mat = m.insert_instruction( + ins, + make_op("multibroadcast", {{"out_lens", {batch_dim, n_dim, k_dim}}}), + sq_weights); + auto b_mat = m.insert_instruction( + ins, make_op("reshape", {{"dims", {batch_dim, k_dim, m_dim}}}), input); + auto dot = m.insert_instruction(ins, make_op("dot"), a_mat, b_mat); + m.replace_instruction( + ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), dot); } } }; } // namespace -void rewrite_dot::apply(module& m) const +void rewrite_dot::apply(module& m) const { if(not enabled(MIGRAPHX_ENABLE_REWRITE_DOT{})) return; From 81cb444f8d3ad0ef0bd5a0afd3d99bae9be3da00 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 24 Oct 2024 11:22:36 -0500 Subject: [PATCH 07/10] Use different heuristic --- src/targets/gpu/fuse_mlir.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 1b759a89c7c..a3c7512e982 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -275,13 +275,18 @@ auto is_mlir_dot(mlir_mode mode) return true; auto a = ins->inputs().front()->get_shape(); auto b = ins->inputs().back()->get_shape(); - // auto m = a.lens()[a.lens().size() - 2]; - // auto n = b.lens().back(); + auto g = std::accumulate(a.lens().begin(), a.lens().end() - 2, 1, std::multiplies<>{}); + auto m = a.lens()[a.lens().size() - 2]; + auto n = b.lens().back(); auto k = a.lens().back(); // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy // to avoid poor-performing GEMM kernels from MLIR // To-do: Investigate a more precise strategy - return k <= 1024; + if (k > 2048) + return false; + if (k < 1024) + return true; + return (g*m*n) < (384*384); }); } From a1540d02b6e6641162713e7eba1184df56789ff9 Mon Sep 17 00:00:00 2001 From: Paul Date: Thu, 24 Oct 2024 17:40:21 -0500 Subject: [PATCH 08/10] Adjust threshold --- src/targets/gpu/fuse_mlir.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index a3c7512e982..5f4e26742f0 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -282,7 +282,7 @@ auto is_mlir_dot(mlir_mode mode) // Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy // to avoid poor-performing GEMM kernels from MLIR // To-do: Investigate a more precise strategy - if (k > 2048) + if (k > 1535) return false; if (k < 1024) return true; From 7b8978ff7faba3129cd57ac14d177012def8637c Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 26 Oct 2024 13:11:38 -0700 Subject: [PATCH 09/10] Add unit tests --- src/rewrite_dot.cpp | 4 +-- test/rewrite_dot.cpp | 68 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 2 deletions(-) create mode 100644 test/rewrite_dot.cpp diff --git a/src/rewrite_dot.cpp b/src/rewrite_dot.cpp index 1e4d76e378d..91ee8404758 100644 --- a/src/rewrite_dot.cpp +++ b/src/rewrite_dot.cpp @@ -31,7 +31,7 @@ namespace migraphx { inline namespace MIGRAPHX_INLINE_NS { -MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_ENABLE_REWRITE_DOT); +MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DISABLE_REWRITE_DOT); namespace { @@ -111,7 +111,7 @@ struct find_1x1_convolution void rewrite_dot::apply(module& m) const { - if(not enabled(MIGRAPHX_ENABLE_REWRITE_DOT{})) + if(enabled(MIGRAPHX_DISABLE_REWRITE_DOT{})) return; match::find_matches(m, find_1x1_convolution{}); } diff --git a/test/rewrite_dot.cpp b/test/rewrite_dot.cpp new file mode 100644 index 00000000000..f5b90ef97c5 --- /dev/null +++ b/test/rewrite_dot.cpp @@ -0,0 +1,68 @@ + +#include +#include +#include +#include +#include +#include +#include + +void run_pass(migraphx::module& m) +{ + migraphx::run_passes(m, {migraphx::rewrite_dot{}, migraphx::dead_code_elimination{}}); +} + +TEST_CASE(nchw_conv_1x1) +{ + migraphx::shape s1{migraphx::shape::float_type, {64, 128, 28, 28}}; + migraphx::shape s2{migraphx::shape::float_type, {512, 128, 1, 1}}; + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + m1.add_return({conv}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); + auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64, 512, 128}}}), squeeze); + auto reshape1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 128, 784}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), broadcast, reshape1); + auto reshape2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 512, 28, 28}}}), dot); + m2.add_return({reshape2}); + } + EXPECT(m1.sort() == m2.sort()); +} + +TEST_CASE(nhwc_conv_1x1) +{ + auto s1 = migraphx::shape::from_permutation(migraphx::shape::float_type, {64, 128, 28, 28}, {0, 2, 3, 1}); + auto s2 = migraphx::shape::from_permutation(migraphx::shape::float_type, {512, 128, 1, 1}, {0, 2, 3, 1}); + migraphx::module m1; + { + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); + m1.add_return({conv}); + } + run_pass(m1); + migraphx::module m2; + { + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); + auto transpose1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), squeeze); + auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64, 28, 128, 512}}}), transpose1); + auto transpose2 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), transpose2, broadcast); + auto transpose3 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), dot); + m2.add_return({transpose3}); + } + EXPECT(m1.sort() == m2.sort()); +} + +int main(int argc, const char* argv[]) { test::run(argc, argv); } From 84a941cf31f0645298e0afc5faa4976addb59031 Mon Sep 17 00:00:00 2001 From: Paul Date: Sat, 26 Oct 2024 13:11:46 -0700 Subject: [PATCH 10/10] Format --- test/rewrite_dot.cpp | 47 ++++++++++++++++++++++++++------------------ 1 file changed, 28 insertions(+), 19 deletions(-) diff --git a/test/rewrite_dot.cpp b/test/rewrite_dot.cpp index f5b90ef97c5..54af2030f77 100644 --- a/test/rewrite_dot.cpp +++ b/test/rewrite_dot.cpp @@ -18,21 +18,24 @@ TEST_CASE(nchw_conv_1x1) migraphx::shape s2{migraphx::shape::float_type, {512, 128, 1, 1}}; migraphx::module m1; { - auto x = m1.add_parameter("x", s1); - auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); m1.add_return({conv}); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", s1); - auto w = m2.add_literal(migraphx::generate_literal(s2)); - auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); - auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64, 512, 128}}}), squeeze); - auto reshape1 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 128, 784}}}), x); + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 512, 128}}}), squeeze); + auto reshape1 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 128, 784}}}), x); auto dot = m2.add_instruction(migraphx::make_op("dot"), broadcast, reshape1); - auto reshape2 = m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 512, 28, 28}}}), dot); + auto reshape2 = + m2.add_instruction(migraphx::make_op("reshape", {{"dims", {64, 512, 28, 28}}}), dot); m2.add_return({reshape2}); } EXPECT(m1.sort() == m2.sort()); @@ -40,26 +43,32 @@ TEST_CASE(nchw_conv_1x1) TEST_CASE(nhwc_conv_1x1) { - auto s1 = migraphx::shape::from_permutation(migraphx::shape::float_type, {64, 128, 28, 28}, {0, 2, 3, 1}); - auto s2 = migraphx::shape::from_permutation(migraphx::shape::float_type, {512, 128, 1, 1}, {0, 2, 3, 1}); + auto s1 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {64, 128, 28, 28}, {0, 2, 3, 1}); + auto s2 = migraphx::shape::from_permutation( + migraphx::shape::float_type, {512, 128, 1, 1}, {0, 2, 3, 1}); migraphx::module m1; { - auto x = m1.add_parameter("x", s1); - auto w = m1.add_literal(migraphx::generate_literal(s2)); + auto x = m1.add_parameter("x", s1); + auto w = m1.add_literal(migraphx::generate_literal(s2)); auto conv = m1.add_instruction(migraphx::make_op("convolution"), x, w); m1.add_return({conv}); } run_pass(m1); migraphx::module m2; { - auto x = m2.add_parameter("x", s1); - auto w = m2.add_literal(migraphx::generate_literal(s2)); + auto x = m2.add_parameter("x", s1); + auto w = m2.add_literal(migraphx::generate_literal(s2)); auto squeeze = m2.add_instruction(migraphx::make_op("squeeze", {{"axes", {2, 3}}}), w); - auto transpose1 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), squeeze); - auto broadcast = m2.add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {64, 28, 128, 512}}}), transpose1); - auto transpose2 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); - auto dot = m2.add_instruction(migraphx::make_op("dot"), transpose2, broadcast); - auto transpose3 = m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), dot); + auto transpose1 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {1, 0}}}), squeeze); + auto broadcast = m2.add_instruction( + migraphx::make_op("multibroadcast", {{"out_lens", {64, 28, 128, 512}}}), transpose1); + auto transpose2 = + m2.add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 3, 1}}}), x); + auto dot = m2.add_instruction(migraphx::make_op("dot"), transpose2, broadcast); + auto transpose3 = m2.add_instruction( + migraphx::make_op("transpose", {{"permutation", {0, 3, 1, 2}}}), dot); m2.add_return({transpose3}); } EXPECT(m1.sort() == m2.sort());