From 082da51cd32a94144e37d15f13b9db22e6e0c5f1 Mon Sep 17 00:00:00 2001
From: Zoran Jovanovic
Date: Tue, 22 Oct 2024 21:39:23 +0000
Subject: [PATCH] [ROCm] Fixed linker issues related to fp8 buffer_comparator functions

---
 xla/service/gpu/buffer_comparator.cu.cc | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/xla/service/gpu/buffer_comparator.cu.cc b/xla/service/gpu/buffer_comparator.cu.cc
index b8e5a8e8d1e66..f550deeb9093e 100644
--- a/xla/service/gpu/buffer_comparator.cu.cc
+++ b/xla/service/gpu/buffer_comparator.cu.cc
@@ -108,6 +108,7 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
@@ -123,6 +124,9 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
 
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  abort();
+#endif  // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 }
 
 __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
@@ -130,6 +134,7 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
@@ -145,6 +150,9 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
 
   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  abort();
+#endif  // defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
 }
 
 #endif  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200