diff --git a/xla/service/gpu/transforms/BUILD b/xla/service/gpu/transforms/BUILD
index 52eb0df8a91c0..11939858745d4 100644
--- a/xla/service/gpu/transforms/BUILD
+++ b/xla/service/gpu/transforms/BUILD
@@ -1408,6 +1408,37 @@ xla_test(
     ],
 )
 
+cc_library(
+    name = "dot_normalizer",
+    srcs = ["dot_normalizer.cc"],
+    hdrs = ["dot_normalizer.h"],
+    deps = [
+        "//xla:shape_util",
+        "//xla:util",
+        "//xla/hlo/ir:hlo",
+        "//xla/hlo/transforms:op_expander_pass",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings:string_view",
+        "@tsl//tsl/platform:errors",
+    ],
+)
+
+xla_cc_test(
+    name = "dot_normalizer_test",
+    srcs = ["dot_normalizer_test.cc"],
+    deps = [
+        ":dot_normalizer",
+        "//xla/hlo/ir:hlo",
+        "//xla/service:pattern_matcher",
+        "//xla/service:pattern_matcher_gmock",
+        "//xla/tests:hlo_test_base",
+        "@tsl//tsl/platform:status_matchers",
+        "@tsl//tsl/platform:statusor",
+        "@tsl//tsl/platform:test",
+        "@tsl//tsl/platform:test_main",
+    ],
+)
+
 cc_library(
     name = "dot_operand_converter",
     srcs = ["dot_operand_converter.cc"],
diff --git a/xla/service/gpu/transforms/dot_normalizer.cc b/xla/service/gpu/transforms/dot_normalizer.cc
new file mode 100644
index 0000000000000..77768f32711fe
--- /dev/null
+++ b/xla/service/gpu/transforms/dot_normalizer.cc
@@ -0,0 +1,72 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_normalizer.h"
+
+#include "absl/status/statusor.h"
+#include "xla/hlo/ir/hlo_casting_utils.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_instructions.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+#include "tsl/platform/errors.h"
+
+namespace xla::gpu {
+
+// Matches dot instructions whose dimension numbers list no LHS contracting
+// dimensions. Only the LHS list is inspected here.
+bool DotNormalizer::InstructionMatchesPattern(HloInstruction* instruction) {
+  if (instruction->opcode() != HloOpcode::kDot) {
+    return false;
+  }
+  return instruction->dot_dimension_numbers()
+      .lhs_contracting_dimensions()
+      .empty();
+}
+
+// Gives the matched dot a trivial contracting dimension. Each operand is
+// bitcast to its own shape with a size-1 dimension appended, and that new last
+// dimension is recorded as the contracting dimension for its side. The dot is
+// updated in place, so nullptr is returned instead of a replacement
+// instruction.
+absl::StatusOr<HloInstruction*> DotNormalizer::ExpandInstruction(
+    HloInstruction* instruction) {
+  HloDotInstruction* dot = Cast<HloDotInstruction>(instruction);
+
+  // Operand 0: append a size-1 dimension and feed the dot through a bitcast.
+  HloInstruction* lhs = dot->mutable_operand(0);
+  Shape new_lhs_shape = lhs->shape();
+  ShapeUtil::AppendMinorDimension(1, &new_lhs_shape);
+  HloInstruction* normalized_lhs =
+      dot->AddInstruction(HloInstruction::CreateBitcast(new_lhs_shape, lhs));
+  TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(0, normalized_lhs));
+
+  // Operand 1: same treatment.
+  HloInstruction* rhs = dot->mutable_operand(1);
+  Shape new_rhs_shape = rhs->shape();
+  ShapeUtil::AppendMinorDimension(1, &new_rhs_shape);
+  HloInstruction* normalized_rhs =
+      dot->AddInstruction(HloInstruction::CreateBitcast(new_rhs_shape, rhs));
+  TF_RETURN_IF_ERROR(dot->ReplaceOperandWithDifferentShape(1, normalized_rhs));
+
+  // The appended dimension is the last one of each new shape.
+  DotDimensionNumbers* dnums = dot->mutable_dot_dimension_numbers();
+  dnums->add_lhs_contracting_dimensions(new_lhs_shape.rank() - 1);
+  dnums->add_rhs_contracting_dimensions(new_rhs_shape.rank() - 1);
+  return nullptr;
+}
+
+}  // namespace xla::gpu
diff --git a/xla/service/gpu/transforms/dot_normalizer.h b/xla/service/gpu/transforms/dot_normalizer.h
new file mode 100644
index 0000000000000..97e85229195f1
--- /dev/null
+++ b/xla/service/gpu/transforms/dot_normalizer.h
@@ -0,0 +1,53 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_
+#define XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_
+
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/transforms/expanders/op_expander_pass.h"
+#include "xla/util.h"
+
+namespace xla::gpu {
+
+// Ensures that a dot has at least 1 contracting dimension. If there are no
+// contracting dimensions, a trivial 1-sized contracting dimension is added.
+// This pass is expected to be run after layout assignment.
+//
+// `extra_filter` is forwarded unchanged to the OpExpanderPass constructor.
+class DotNormalizer : public OpExpanderPass {
+ public:
+  explicit DotNormalizer(HloPredicate extra_filter = nullptr)
+      : OpExpanderPass(std::move(extra_filter)) {}
+
+  absl::string_view name() const override { return "dot_normalizer"; }
+
+ protected:
+  // Returns true for dot instructions without LHS contracting dimensions.
+  bool InstructionMatchesPattern(HloInstruction* instruction) override;
+
+  // Adds a size-1 contracting dimension to both operands of `instruction`.
+  // The dot is modified in place and nullptr is returned.
+  absl::StatusOr<HloInstruction*> ExpandInstruction(
+      HloInstruction* instruction) override;
+};
+
+}  // namespace xla::gpu
+
+#endif  // XLA_SERVICE_GPU_TRANSFORMS_DOT_NORMALIZER_H_
diff --git a/xla/service/gpu/transforms/dot_normalizer_test.cc b/xla/service/gpu/transforms/dot_normalizer_test.cc
new file mode 100644
index 0000000000000..9242c1f73bf7a
--- /dev/null
+++ b/xla/service/gpu/transforms/dot_normalizer_test.cc
@@ -0,0 +1,76 @@
+/* Copyright 2024 The OpenXLA Authors.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "xla/service/gpu/transforms/dot_normalizer.h"
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/service/pattern_matcher.h"
+#include "xla/service/pattern_matcher_gmock.h"
+#include "xla/tests/hlo_test_base.h"
+#include "tsl/platform/status_matchers.h"
+#include "tsl/platform/statusor.h"
+#include "tsl/platform/test.h"
+
+namespace xla::gpu {
+namespace {
+
+namespace m = ::xla::match;
+
+using DotNormalizerTest = HloTestBase;
+using ::tsl::testing::IsOkAndHolds;
+
+// A dot with only batch dimensions gets both operands bitcast to a shape with
+// a trailing size-1 dimension, which becomes the contracting dimension.
+TEST_F(DotNormalizerTest, DotWithoutContractingDims) {
+  constexpr char kHlo[] = R"(
+    HloModule test
+
+    ENTRY main {
+      p0 = f16[5,15]{1,0} parameter(0)
+      p1 = f16[5,16,17]{2,1,0} parameter(1)
+      ROOT r = f16[5,15,16,17]{3,2,1,0} dot(p0, p1),
+        lhs_batch_dims={0}, rhs_batch_dims={0}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_THAT(DotNormalizer().Run(m.get()), IsOkAndHolds(true));
+  EXPECT_THAT(
+      m->entry_computation()->root_instruction(),
+      GmockMatch(
+          m::Dot(m::Bitcast().WithShape(F16, {5, 15, 1}, {2, 1, 0}),
+                 m::Bitcast().WithShape(F16, {5, 16, 17, 1}, {3, 2, 1, 0}))
+              .WithContractingDims({2}, {3})));
+}
+
+// For a dot that already has contracting dimensions the pass reports no
+// change.
+TEST_F(DotNormalizerTest, DotWithContractingDims) {
+  constexpr char kHlo[] = R"(
+    HloModule test
+
+    ENTRY main {
+      p0 = f16[5,15,3]{2,1,0} parameter(0)
+      p1 = f16[5,17,3]{2,1,0} parameter(1)
+      ROOT r = f16[5,15,17]{2,1,0} dot(p0, p1),
+        lhs_batch_dims={0}, lhs_contracting_dims={2},
+        rhs_batch_dims={0}, rhs_contracting_dims={2}
+    }
+  )";
+  TF_ASSERT_OK_AND_ASSIGN(auto m, ParseAndReturnVerifiedModule(kHlo));
+  EXPECT_THAT(DotNormalizer().Run(m.get()), IsOkAndHolds(false));
+}
+
+}  // namespace
+}  // namespace xla::gpu