Skip to content

Commit

Permalink
Fix build break due to
Browse files Browse the repository at this point in the history
1. Update OUTPUT_PATH in third_party/xla/third_party/py/python_repo.bzl
Upstreamed these changes openxla/xla#15557
2. Use jit=True for kernel generation in tensorflow/core/kernels/mlir_generated/build_defs.bzl
3. Correct triton API name in  third_party/xla/xla/service/gpu/fusions/triton/compilation_pipeline_rocm.cc
Upstreamed using openxla/xla#15477
  • Loading branch information
hsharsha committed Jul 31, 2024
1 parent d88d57c commit 1c607bd
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 3 deletions.
4 changes: 2 additions & 2 deletions tensorflow/core/kernels/mlir_generated/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,8 @@ def _gen_kernel_bin_impl(ctx):
"--input=%s" % ctx.file.mlir_op.path,
"--output=%s" % gpu_bin.path,
"--enable_ftz=%s" % (ctx.attr.data_type == "f32"),
"--jit_i64_indexed_for_large_tensors=%s" % ctx.attr.jit_i64_indexed_for_large_tensors,
"--jit=%s" % ctx.attr.jit,
"--jit_i64_indexed_for_large_tensors=%s" % "False",
"--jit=%s" % "True",
],
use_default_shell_env = True,
mnemonic = "compile",
Expand Down
4 changes: 4 additions & 0 deletions third_party/xla/third_party/py/python_repo.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def _python_repository_impl(ctx):
ctx.file("BUILD", "")
wheel_name = ctx.os.environ.get("WHEEL_NAME", "tensorflow")
wheel_collab = ctx.os.environ.get("WHEEL_COLLAB", False)
output_path = ctx.os.environ.get("OUTPUT_PATH", None)

requirements = None
for i in range(0, len(ctx.attr.requirements_locks)):
Expand Down Expand Up @@ -62,12 +63,14 @@ TF_PYTHON_VERSION = "{version}"
HERMETIC_PYTHON_VERSION = "{version}"
WHEEL_NAME = "{wheel_name}"
WHEEL_COLLAB = "{wheel_collab}"
OUTPUT_PATH = "{output_path}"
REQUIREMENTS = "{requirements}"
REQUIREMENTS_WITH_LOCAL_WHEELS = "{requirements_with_local_wheels}"
""".format(
version = version,
wheel_name = wheel_name,
wheel_collab = wheel_collab,
output_path = output_path,
requirements = str(requirements),
requirements_with_local_wheels = requirements_with_local_wheels,
),
Expand Down Expand Up @@ -200,6 +203,7 @@ python_repository = repository_rule(
"HERMETIC_PYTHON_VERSION",
"WHEEL_NAME",
"WHEEL_COLLAB",
"OUTPUT_PATH",
],
local = True,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ absl::Status CreateTritonPipeline(
pm.addPass(mlir::triton::AMD::createDecomposeUnsupportedConversionsPass(
ccRocm.gfx_version()));
const int custom_lds_size = 0;
pm.addPass(mlir::triton::AMD::createOptimizeLdsUsagePass(ccRocm.gfx_version(),
pm.addPass(mlir::triton::AMD::createOptimizeLDSUsagePass(ccRocm.gfx_version(),
custom_lds_size));
pm.addPass(mlir::createConvertSCFToCFPass());
pm.addPass(mlir::createConvertIndexToLLVMPass());
Expand Down

0 comments on commit 1c607bd

Please sign in to comment.