From 1c607bd7934f977ce661e6904ad8c74895ef8ac6 Mon Sep 17 00:00:00 2001
From: Harsha HS
Date: Wed, 31 Jul 2024 06:06:31 -0700
Subject: [PATCH] Fix build break

1. Update OUTPUT_PATH handling in
   third_party/xla/third_party/py/python_repo.bzl.
   Upstreamed in https://github.com/openxla/xla/pull/15557.
2. Use jit=True for kernel generation in
   tensorflow/core/kernels/mlir_generated/build_defs.bzl.
3. Correct the Triton API name in
   third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc.
   Upstreamed in https://github.com/openxla/xla/pull/15477.
---
 tensorflow/core/kernels/mlir_generated/build_defs.bzl       | 4 ++--
 third_party/xla/third_party/py/python_repo.bzl              | 4 ++++
 .../service/gpu/fusions/triton/compilation_pipeline_rocm.cc | 2 +-
 3 files changed, 7 insertions(+), 3 deletions(-)
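Note: the build_defs.bzl change below hardcodes both JIT flags, so the rule's
jit and jit_i64_indexed_for_large_tensors attributes no longer influence the
generated command line. A minimal Python sketch of how the patched argument
list in _gen_kernel_bin_impl renders (data_type is a hypothetical stand-in
for ctx.attr.data_type):

    # Sketch only; mirrors the flag rendering after this patch.
    data_type = "f32"
    arguments = [
        "--enable_ftz=%s" % (data_type == "f32"),
        # Both JIT flags are now fixed; the rule's attributes are ignored.
        "--jit_i64_indexed_for_large_tensors=%s" % "False",
        "--jit=%s" % "True",
    ]
    print(arguments)
    # ['--enable_ftz=True', '--jit_i64_indexed_for_large_tensors=False', '--jit=True']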
diff --git a/tensorflow/core/kernels/mlir_generated/build_defs.bzl b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
index f574a8da8fd99d..09ea0cffa7687b 100644
--- a/tensorflow/core/kernels/mlir_generated/build_defs.bzl
+++ b/tensorflow/core/kernels/mlir_generated/build_defs.bzl
@@ -162,8 +162,8 @@ def _gen_kernel_bin_impl(ctx):
             "--input=%s" % ctx.file.mlir_op.path,
             "--output=%s" % gpu_bin.path,
             "--enable_ftz=%s" % (ctx.attr.data_type == "f32"),
-            "--jit_i64_indexed_for_large_tensors=%s" % ctx.attr.jit_i64_indexed_for_large_tensors,
-            "--jit=%s" % ctx.attr.jit,
+            "--jit_i64_indexed_for_large_tensors=%s" % "False",
+            "--jit=%s" % "True",
         ],
         use_default_shell_env = True,
         mnemonic = "compile",
diff --git a/third_party/xla/third_party/py/python_repo.bzl b/third_party/xla/third_party/py/python_repo.bzl
index f8fdd1033b5e2f..85dbda9c62f11e 100644
--- a/third_party/xla/third_party/py/python_repo.bzl
+++ b/third_party/xla/third_party/py/python_repo.bzl
@@ -14,6 +14,7 @@ def _python_repository_impl(ctx):
     ctx.file("BUILD", "")
     wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow")
     wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False)
+    output_path = ctx.os.environ.get("OUTPUT_PATH", None)
 
     requirements = None
     for i in range(0, len(ctx.attr.requirements_locks)):
@@ -62,12 +63,14 @@ TF_PYTHON_VERSION = "{version}"
 HERMETIC_PYTHON_VERSION = "{version}"
 WHEEL_NAME = "{wheel_name}"
 WHEEL_COLLAB = "{wheel_collab}"
+OUTPUT_PATH = "{output_path}"
 REQUIREMENTS = "{requirements}"
 REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
 """.format(
         version = version,
         wheel_name = wheel_name,
         wheel_collab = wheel_collab,
+        output_path = output_path,
         requirements = str(requirements),
         requirements_with_local_wheels = requirements_with_local_wheels,
     ),
@@ -200,6 +203,7 @@ python_repository = repository_rule(
         "HERMETIC_PYTHON_VERSION",
         "WHEEL_NAME",
         "WHEEL_COLLAB",
+        "OUTPUT_PATH",
     ],
     local = True,
 )
diff --git a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
index 2c12aafb9ac536..e31e29b0e0d29c 100644
--- a/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
+++ b/third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
@@ -107,7 +107,7 @@ absl::Status CreateTritonPipeline(
   pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
       ccRocm.gfx_version()));
   const int custom_lds_size = 0;
-  pm.addPass(mlir::triton::AMD::createOptimizeLdsUsagePass(ccRocm.gfx_version(),
+  pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(),
                                                            custom_lds_size));
   pm.addPass(mlir::createConvertSCFToCFPass());
   pm.addPass(mlir::createConvertIndexToLLVMPass());
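A subtlety in the python_repo.bzl change above: when OUTPUT_PATH is unset,
ctx.os.environ.get returns None, so the generated constant becomes the
literal string "None" rather than an empty value. A minimal Python sketch
mirroring that formatting (environ is a hypothetical stand-in for
ctx.os.environ):

    def render_output_path(environ):
        output_path = environ.get("OUTPUT_PATH", None)
        # An unset variable renders as the literal string "None".
        return 'OUTPUT_PATH = "{output_path}"'.format(output_path = output_path)

    print(render_output_path({}))                              # OUTPUT_PATH = "None"
    print(render_output_path({"OUTPUT_PATH": "/tmp/wheels"}))  # OUTPUT_PATH = "/tmp/wheels"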