From 4f16089cfce6a583b1f9fdfad26fc7b5982537da Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Mon, 21 Oct 2024 07:36:49 -0700 Subject: [PATCH] add bf16 support to fused_batchnorm, conv, pooling, and topk --- .../xla/stream_executor/rocm/rocm_dnn.cc | 44 ++ .../xla/stream_executor/rocm/rocm_dnn.h | 30 ++ tensorflow/core/kernels/batch_norm_op_test.cc | 91 ++-- .../core/kernels/conv_grad_filter_ops_3d.cc | 459 +++++++++--------- .../kernels/conv_grad_filter_ops_launcher.cc | 6 +- .../core/kernels/conv_grad_input_ops.cc | 9 +- .../core/kernels/conv_grad_input_ops_3d.cc | 9 +- tensorflow/core/kernels/conv_ops_3d.cc | 10 +- tensorflow/core/kernels/conv_ops_bfloat16.cc | 6 +- tensorflow/core/kernels/cudnn_pooling_gpu.cc | 10 +- tensorflow/core/kernels/depthwise_conv_op.cc | 6 +- .../core/kernels/fused_batch_norm_op.cc | 11 +- tensorflow/core/kernels/gpu_utils.cc | 15 + tensorflow/core/kernels/gpu_utils.h | 4 + tensorflow/core/kernels/pooling_ops_common.cc | 6 +- tensorflow/core/kernels/topk_op_gpu.h | 29 +- tensorflow/python/BUILD | 1 - .../kernel_tests/nn_ops/conv_ops_test.py | 43 +- .../nn_ops/depthwise_conv_op_base.py | 83 ++-- .../python/ops/nn_fused_batchnorm_test.py | 4 +- 20 files changed, 441 insertions(+), 435 deletions(-) diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc index f8c41660b6502c..8b56ab0d26a6d2 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.cc @@ -46,6 +46,8 @@ limitations under the License. #include "tensorflow/tsl/util/determinism.h" #include "tensorflow/tsl/util/env_var.h" #include "rocm/rocm_config.h" +#include +#include namespace { @@ -3806,6 +3808,28 @@ bool MIOpenSupport::GetRnnAlgorithms( return true; } +bool MIOpenSupport::DoBatchNormalizationForward( + Stream* stream, const DeviceMemory& x, + const DeviceMemory& scale, const DeviceMemory& offset, + const DeviceMemory& estimated_mean, + const DeviceMemory& estimated_variance, + const DeviceMemory& side_input, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + const double exponential_average_factor, + dnn::ActivationMode activation_mode, DeviceMemory* y, + DeviceMemory* batch_mean, DeviceMemory* batch_var, + DeviceMemory* saved_mean, DeviceMemory* saved_inv_var, + bool is_training, ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) { + + return DoBatchNormalizationForwardImpl( + stream, dnn::DataType::kBF16, dnn::DataType::kFloat, x, scale, offset, + estimated_mean, estimated_variance, side_input, x_desc, scale_offset_desc, + epsilon, exponential_average_factor, activation_mode, y, batch_mean, + batch_var, saved_mean, saved_inv_var, is_training); +} + bool MIOpenSupport::DoBatchNormalizationForward( Stream* stream, const DeviceMemory& x, const DeviceMemory& scale, const DeviceMemory& offset, @@ -3896,6 +3920,26 @@ bool MIOpenSupport::DoBatchNormalizationForwardImpl( return true; } +bool MIOpenSupport::DoBatchNormalizationBackward( + Stream* stream, const DeviceMemory& y_backprop, + const DeviceMemory& x, const DeviceMemory& scale, + const DeviceMemory& offset, const DeviceMemory& mean, + const DeviceMemory& inv_var, const DeviceMemory& y, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + dnn::ActivationMode activation_mode, + DeviceMemory* x_backprop, + DeviceMemory* scale_backprop, DeviceMemory* 
offset_backprop, + DeviceMemory* side_input_backprop, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator) { + +return DoBatchNormalizationBackwardImpl( + stream, miopenBFloat16, miopenFloat, y_backprop, x, scale, mean, inv_var, + x_desc, scale_offset_desc, epsilon, x_backprop, scale_backprop, + offset_backprop); +} + bool MIOpenSupport::DoBatchNormalizationBackward( Stream* stream, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, diff --git a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h index 5b5ea70f456a5e..42a35637d06039 100644 --- a/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h +++ b/tensorflow/compiler/xla/stream_executor/rocm/rocm_dnn.h @@ -298,6 +298,21 @@ class MIOpenSupport : public dnn::DnnSupport { bool is_training, ScratchAllocator* reserve_space_allocator, ScratchAllocator* workspace_allocator) override; + bool DoBatchNormalizationForward( + Stream* stream, const DeviceMemory& x, + const DeviceMemory& scale, const DeviceMemory& offset, + const DeviceMemory& estimated_mean, + const DeviceMemory& estimated_variance, + const DeviceMemory& side_input, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + const double exponential_average_factor, + dnn::ActivationMode activation_mode, DeviceMemory* y, + DeviceMemory* batch_mean, DeviceMemory* batch_var, + DeviceMemory* saved_mean, DeviceMemory* saved_inv_var, + bool is_training, ScratchAllocator* reserve_space_allocator, + ScratchAllocator* workspace_allocator) override; + bool DoBatchNormalizationBackward( Stream* stream, const DeviceMemory& y_backprop, const DeviceMemory& x, const DeviceMemory& scale, @@ -325,6 +340,21 @@ class MIOpenSupport : public dnn::DnnSupport { DeviceMemory* reserve_space_data, ScratchAllocator* workspace_allocator) override; + bool DoBatchNormalizationBackward( + Stream* stream, const DeviceMemory& y_backprop, + const DeviceMemory& x, const DeviceMemory& scale, + const DeviceMemory& offset, const DeviceMemory& mean, + const DeviceMemory& inv_var, + const DeviceMemory& y, + const dnn::BatchDescriptor& x_desc, + const dnn::BatchDescriptor& scale_offset_desc, const double epsilon, + dnn::ActivationMode activation_mode, + DeviceMemory* x_backprop, + DeviceMemory* scale_backprop, DeviceMemory* offset_backprop, + DeviceMemory* side_input_backprop, + DeviceMemory* reserve_space_data, + ScratchAllocator* workspace_allocator) override; + tsl::Status DoConvolve( dnn::ConvolutionKind kind, dnn::DataType element_type, dnn::DataType output_type, Stream* stream, diff --git a/tensorflow/core/kernels/batch_norm_op_test.cc b/tensorflow/core/kernels/batch_norm_op_test.cc index 45ddc853295557..00db3310122abe 100644 --- a/tensorflow/core/kernels/batch_norm_op_test.cc +++ b/tensorflow/core/kernels/batch_norm_op_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include + #include "tensorflow/core/framework/allocator.h" #include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" @@ -29,60 +30,42 @@ limitations under the License. 
namespace tensorflow { -class BatchNormOpTest : public OpsTestBase {}; - -TEST_F(BatchNormOpTest, Simple) { - TF_EXPECT_OK( - NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Input(FakeInput(DT_FLOAT)) - .Attr("scale_after_normalization", false) - .Attr("variance_epsilon", 0.001) - .Finalize(node_def())); - TF_EXPECT_OK(InitOpWithGraphVersion(8)); - AddInputFromArray(TensorShape({1, 1, 6, 2}), - {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); - AddInputFromArray(TensorShape({2}), {10, 20}); - AddInputFromArray(TensorShape({2}), {0.25f, 0.5f}); - AddInputFromArray(TensorShape({2}), {0.1f, 0.6f}); - AddInputFromArray(TensorShape({2}), {0.0f, 0.0f}); - TF_ASSERT_OK(RunOpKernel()); - - Tensor expected(allocator(), DT_FLOAT, TensorShape({1, 1, 6, 2})); - test::FillValues( - &expected, {-17.86f, -22.00f, -15.87f, -20.59f, -13.87f, -19.18f, -21.86f, - -33.31f, -23.85f, -34.72f, -25.85f, -36.13f}); - test::ExpectTensorNear(expected, *GetOutput(0), 0.01); -} - -TEST_F(BatchNormOpTest, Fp16) { - TF_EXPECT_OK( - NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Input(FakeInput(DT_HALF)) - .Attr("scale_after_normalization", false) - .Attr("variance_epsilon", 0.001) - .Finalize(node_def())); - TF_EXPECT_OK(InitOpWithGraphVersion(8)); - AddInputFromList(TensorShape({1, 1, 6, 2}), - {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); - AddInputFromList(TensorShape({2}), {10, 20}); - AddInputFromList(TensorShape({2}), {0.25, 0.5}); - AddInputFromList(TensorShape({2}), {0.1, 0.6}); - AddInputFromList(TensorShape({2}), {0.0, 0.0}); - TF_ASSERT_OK(RunOpKernel()); +template +struct BatchNormOpTest : public OpsTestBase { + static constexpr auto TValueType = DataTypeToEnum::value; + void run_me() { + TF_EXPECT_OK( + NodeDefBuilder("batch_norm_op", "BatchNormWithGlobalNormalization") + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Input(FakeInput(TValueType)) + .Attr("scale_after_normalization", false) + .Attr("variance_epsilon", 0.001) + .Finalize(node_def())); + TF_EXPECT_OK(InitOpWithGraphVersion(8)); + AddInputFromList(TensorShape({1, 1, 6, 2}), + {1, 4, 2, 5, 3, 6, -1, -4, -2, -5, -3, -6}); + AddInputFromList(TensorShape({2}), {10, 20}); + AddInputFromList(TensorShape({2}), {0.25, 0.5}); + AddInputFromList(TensorShape({2}), {0.1, 0.6}); + AddInputFromList(TensorShape({2}), {0.0, 0.0}); + TF_ASSERT_OK(RunOpKernel()); + double atol = TValueType == DT_FLOAT ? 0.01 : 0.1; + Tensor expected(allocator(), TValueType, TensorShape({1, 1, 6, 2})); + test::FillValues(&expected, + {-17.86f, -22.00f, -15.87f, -20.59f, -13.87f, -19.18f, + -21.86f, -33.31f, -23.85f, -34.72f, -25.85f, -36.13f}); + test::ExpectTensorNear(expected, *GetOutput(0), atol); + } +}; - Tensor expected(allocator(), DT_HALF, TensorShape({1, 1, 6, 2})); - test::FillValues( - &expected, {-17.86, -22.00, -15.87, -20.59, -13.87, -19.18, -21.86, - -33.31, -23.85, -34.72, -25.85, -36.13}); - test::ExpectTensorNear(expected, *GetOutput(0), 0.1); -} +TYPED_TEST_SUITE_P(BatchNormOpTest); +TYPED_TEST_P(BatchNormOpTest, Simple) { this->run_me(); } +REGISTER_TYPED_TEST_SUITE_P(BatchNormOpTest, Simple); +// TODO(ezhulenev): Add support for more data types. 
+using DataTypes = ::testing::Types; //, Eigen::bfloat16>; +INSTANTIATE_TYPED_TEST_SUITE_P(Test, BatchNormOpTest, DataTypes); } // namespace tensorflow diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc index 6de7a7f48e8fcd..7a03ae3816ec68 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_3d.cc @@ -52,10 +52,10 @@ using stream_executor::dnn::DimIndex; #include "tensorflow/core/util/proto/proto_utils.h" #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM #if GOOGLE_CUDA -#include "third_party/gpus/cudnn/cudnn.h" #include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h" #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" +#include "third_party/gpus/cudnn/cudnn.h" #endif // GOOGLE_CUDA namespace { @@ -662,7 +662,6 @@ DECLARE_GPU_SPEC(double); #undef DECLARE_GPU_SPEC } // namespace functor - // A dummy type to group backward filter autotune results together. struct Conv3dBackwardFilterAutotuneGroup { static string name() { return "Conv3dBwdFilter"; } @@ -693,8 +692,7 @@ void LaunchConvBackpropFilterOpImpl( OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); if (DataTypeToEnum::value == DT_BFLOAT16 && - !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE)) { + IsBF16NotSupportedInOps(stream)) { context->SetStatus(errors::Unimplemented( "Conv3DBackpropFilter for GPU with bfloat16 is only supported " "with cuDNN on Ampere GPUs or later.")); @@ -796,86 +794,86 @@ void LaunchConvBackpropFilterOpImpl( << padding_planes << ")"; #if GOOGLE_CUDA - const bool compute_in_nhwc = ComputeInNhwcEnabled( - DataTypeToEnum::value, stream, /*use_4d_tensor=*/false); + const bool compute_in_nhwc = ComputeInNhwcEnabled( + DataTypeToEnum::value, stream, /*use_4d_tensor=*/false); #else - // fast NDHWC implementation is a CUDA only feature - const bool compute_in_nhwc = false; + // fast NDHWC implementation is a CUDA only feature + const bool compute_in_nhwc = false; #endif - const TensorFormat compute_data_format = - (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC - : FORMAT_NCHW; - - VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:" - << " data_format=" << ToString(data_format) - << " compute_data_format=" << ToString(compute_data_format); - - constexpr auto kComputeInNHWC = - std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, - se::dnn::FilterLayout::kOutputYXInput); - constexpr auto kComputeInNCHW = - std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, - se::dnn::FilterLayout::kOutputInputYX); - - se::dnn::DataLayout compute_data_layout; - se::dnn::FilterLayout filter_layout; - - std::tie(compute_data_layout, filter_layout) = - compute_data_format == FORMAT_NHWC ? 
kComputeInNHWC : kComputeInNCHW; - - se::dnn::BatchDescriptor input_desc(3); - input_desc.set_count(dims.batch_size) - .set_spatial_dim(DimIndex::X, - GetTensorDim(compatible_input, data_format, '2')) - .set_spatial_dim(DimIndex::Y, - GetTensorDim(compatible_input, data_format, '1')) - .set_spatial_dim(DimIndex::Z, - GetTensorDim(compatible_input, data_format, '0')) - .set_feature_map_count(dims.in_depth) - .set_layout(compute_data_layout); - se::dnn::BatchDescriptor output_desc(3); - output_desc.set_count(dims.batch_size) - .set_spatial_dim(DimIndex::X, dims.output_size(2)) - .set_spatial_dim(DimIndex::Y, dims.output_size(1)) - .set_spatial_dim(DimIndex::Z, dims.output_size(0)) - .set_feature_map_count(dims.out_depth) - .set_layout(compute_data_layout); - se::dnn::FilterDescriptor filter_desc(3); - filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) - .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) - .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) - .set_input_feature_map_count(filter_shape.dim_size(3)) - .set_output_feature_map_count(filter_shape.dim_size(4)) - .set_layout(filter_layout); - se::dnn::ConvolutionDescriptor conv_desc(3); - conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) - .set_dilation_rate(DimIndex::Y, dims.dilation(1)) - .set_dilation_rate(DimIndex::Z, dims.dilation(0)) - .set_filter_stride(DimIndex::X, dims.stride(2)) - .set_filter_stride(DimIndex::Y, dims.stride(1)) - .set_filter_stride(DimIndex::Z, dims.stride(0)) - .set_zero_padding(DimIndex::X, padding_cols / 2) - .set_zero_padding(DimIndex::Y, padding_rows / 2) - .set_zero_padding(DimIndex::Z, padding_planes / 2) - .set_group_count(dims.in_depth / filter_shape.dim_size(3)); - - Tensor pre_transformed_filter_backprop; - auto dst_format = - compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; - TensorShape dst_shape = - dst_format == FORMAT_OIHW - ? TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), - dims.filter_size(0), dims.filter_size(1), - dims.filter_size(2)}) - : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), - dims.filter_size(1), dims.filter_size(2), - filter_shape.dim_size(3)}); - OP_REQUIRES_OK(context, - context->allocate_temp(DataTypeToEnum::value, dst_shape, - &pre_transformed_filter_backprop)); - - Tensor transformed_out_backprop; - if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + const TensorFormat compute_data_format = + (compute_in_nhwc && data_format == FORMAT_NHWC) ? FORMAT_NHWC + : FORMAT_NCHW; + + VLOG(3) << "Compute Conv3DBackpropFilter with cuDNN:" + << " data_format=" << ToString(data_format) + << " compute_data_format=" << ToString(compute_data_format); + + constexpr auto kComputeInNHWC = + std::make_tuple(se::dnn::DataLayout::kBatchYXDepth, + se::dnn::FilterLayout::kOutputYXInput); + constexpr auto kComputeInNCHW = + std::make_tuple(se::dnn::DataLayout::kBatchDepthYX, + se::dnn::FilterLayout::kOutputInputYX); + + se::dnn::DataLayout compute_data_layout; + se::dnn::FilterLayout filter_layout; + + std::tie(compute_data_layout, filter_layout) = + compute_data_format == FORMAT_NHWC ? 
kComputeInNHWC : kComputeInNCHW; + + se::dnn::BatchDescriptor input_desc(3); + input_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, + GetTensorDim(compatible_input, data_format, '2')) + .set_spatial_dim(DimIndex::Y, + GetTensorDim(compatible_input, data_format, '1')) + .set_spatial_dim(DimIndex::Z, + GetTensorDim(compatible_input, data_format, '0')) + .set_feature_map_count(dims.in_depth) + .set_layout(compute_data_layout); + se::dnn::BatchDescriptor output_desc(3); + output_desc.set_count(dims.batch_size) + .set_spatial_dim(DimIndex::X, dims.output_size(2)) + .set_spatial_dim(DimIndex::Y, dims.output_size(1)) + .set_spatial_dim(DimIndex::Z, dims.output_size(0)) + .set_feature_map_count(dims.out_depth) + .set_layout(compute_data_layout); + se::dnn::FilterDescriptor filter_desc(3); + filter_desc.set_spatial_dim(DimIndex::X, dims.filter_size(2)) + .set_spatial_dim(DimIndex::Y, dims.filter_size(1)) + .set_spatial_dim(DimIndex::Z, dims.filter_size(0)) + .set_input_feature_map_count(filter_shape.dim_size(3)) + .set_output_feature_map_count(filter_shape.dim_size(4)) + .set_layout(filter_layout); + se::dnn::ConvolutionDescriptor conv_desc(3); + conv_desc.set_dilation_rate(DimIndex::X, dims.dilation(2)) + .set_dilation_rate(DimIndex::Y, dims.dilation(1)) + .set_dilation_rate(DimIndex::Z, dims.dilation(0)) + .set_filter_stride(DimIndex::X, dims.stride(2)) + .set_filter_stride(DimIndex::Y, dims.stride(1)) + .set_filter_stride(DimIndex::Z, dims.stride(0)) + .set_zero_padding(DimIndex::X, padding_cols / 2) + .set_zero_padding(DimIndex::Y, padding_rows / 2) + .set_zero_padding(DimIndex::Z, padding_planes / 2) + .set_group_count(dims.in_depth / filter_shape.dim_size(3)); + + Tensor pre_transformed_filter_backprop; + auto dst_format = + compute_data_format == FORMAT_NCHW ? FORMAT_OIHW : FORMAT_OHWI; + TensorShape dst_shape = + dst_format == FORMAT_OIHW + ? 
TensorShape({filter_shape.dim_size(4), filter_shape.dim_size(3), + dims.filter_size(0), dims.filter_size(1), + dims.filter_size(2)}) + : TensorShape({filter_shape.dim_size(4), dims.filter_size(0), + dims.filter_size(1), dims.filter_size(2), + filter_shape.dim_size(3)}); + OP_REQUIRES_OK(context, + context->allocate_temp(DataTypeToEnum::value, dst_shape, + &pre_transformed_filter_backprop)); + + Tensor transformed_out_backprop; + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { VLOG(4) << "Convert the `out_backprop` tensor from NDHWC to NCDHW."; TensorShape nchw_shape = {dims.batch_size, dims.out_depth, dims.output_size(0), dims.output_size(1), @@ -890,11 +888,11 @@ void LaunchConvBackpropFilterOpImpl( } else { CHECK(transformed_out_backprop.CopyFrom(out_backprop, nchw_shape)); } - } else { + } else { transformed_out_backprop = out_backprop; - } - Tensor transformed_input; - if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { + } + Tensor transformed_input; + if (data_format == FORMAT_NHWC && compute_data_format == FORMAT_NCHW) { VLOG(4) << "Convert the `input` tensor from NDHWC to NCDHW."; TensorShape nchw_shape = { dims.batch_size, dims.in_depth, compatible_input.dim_size(1), @@ -910,96 +908,91 @@ void LaunchConvBackpropFilterOpImpl( } else { CHECK(transformed_input.CopyFrom(compatible_input, nchw_shape)); } - } else { + } else { transformed_input = compatible_input; - } + } - auto out_backprop_ptr = - AsDeviceMemory(transformed_out_backprop.template flat().data(), - transformed_out_backprop.template flat().size()); - auto filter_backprop_ptr = AsDeviceMemory( - pre_transformed_filter_backprop.template flat().data(), - pre_transformed_filter_backprop.template flat().size()); - auto input_ptr = - AsDeviceMemory(transformed_input.template flat().data(), - transformed_input.template flat().size()); - - static int64_t ConvolveBackwardFilterScratchSize = - GetDnnWorkspaceLimitOrDefault(); - - const ConvParameters conv_parameters = { - stream->parent(), - dims.batch_size, - dims.in_depth, - {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, - compute_data_format, - dims.out_depth, - {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, - {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, - {{dims.stride(0), dims.stride(1), dims.stride(2)}}, - {{padding_planes, padding_rows, padding_cols}}, - input.dtype(), - conv_desc.group_count(), - }; - - using se::dnn::AlgorithmConfig; - using se::dnn::AlgorithmDesc; - using se::dnn::ProfileResult; - - auto entry_or = AutotuneUnfusedConv( - cudnn_use_autotune, AutotuneConv3dBwdFilter::GetInstance(), - conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER, - input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc, - output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize); - OP_REQUIRES_OK(context, entry_or.status()); - auto autotune_entry = std::move(entry_or).value(); - - DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, - context); - Status cudnn_launch_status = LaunchAutotunedConv( - autotune_entry, &scratch_allocator, - se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, - input_ptr, filter_desc, filter_backprop_ptr, conv_desc, output_desc, - out_backprop_ptr); - if (!cudnn_launch_status.ok()) { - context->SetStatus(cudnn_launch_status); - return; - } + auto out_backprop_ptr = + AsDeviceMemory(transformed_out_backprop.template flat().data(), + transformed_out_backprop.template flat().size()); + auto 
filter_backprop_ptr = + AsDeviceMemory(pre_transformed_filter_backprop.template flat().data(), + pre_transformed_filter_backprop.template flat().size()); + auto input_ptr = AsDeviceMemory(transformed_input.template flat().data(), + transformed_input.template flat().size()); + + static int64_t ConvolveBackwardFilterScratchSize = + GetDnnWorkspaceLimitOrDefault(); + + const ConvParameters conv_parameters = { + stream->parent(), + dims.batch_size, + dims.in_depth, + {{dims.input_size(0), dims.input_size(1), dims.input_size(2)}}, + compute_data_format, + dims.out_depth, + {{dims.filter_size(0), dims.filter_size(1), dims.filter_size(2)}}, + {{dims.dilation(0), dims.dilation(1), dims.dilation(2)}}, + {{dims.stride(0), dims.stride(1), dims.stride(2)}}, + {{padding_planes, padding_rows, padding_cols}}, + input.dtype(), + conv_desc.group_count(), + }; + + using se::dnn::AlgorithmConfig; + using se::dnn::AlgorithmDesc; + using se::dnn::ProfileResult; + + auto entry_or = AutotuneUnfusedConv( + cudnn_use_autotune, AutotuneConv3dBwdFilter::GetInstance(), + conv_parameters, context, se::dnn::ConvolutionKind::BACKWARD_FILTER, + input_desc, input_ptr, filter_desc, filter_backprop_ptr, conv_desc, + output_desc, out_backprop_ptr, ConvolveBackwardFilterScratchSize); + OP_REQUIRES_OK(context, entry_or.status()); + auto autotune_entry = std::move(entry_or).value(); + + DnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize, + context); + Status cudnn_launch_status = LaunchAutotunedConv( + autotune_entry, &scratch_allocator, + se::dnn::ConvolutionKind::BACKWARD_FILTER, stream, input_desc, input_ptr, + filter_desc, filter_backprop_ptr, conv_desc, output_desc, + out_backprop_ptr); + if (!cudnn_launch_status.ok()) { + context->SetStatus(cudnn_launch_status); + return; + } - auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; - functor::ReverseTransformFilter()( - context->eigen_device(), /*src_filter_format=*/dst_format, - toConstTensor(pre_transformed_filter_backprop).template tensor(), - filter_backprop->tensor()); + auto toConstTensor = [](const Tensor& x) -> const Tensor { return x; }; + functor::ReverseTransformFilter()( + context->eigen_device(), /*src_filter_format=*/dst_format, + toConstTensor(pre_transformed_filter_backprop).template tensor(), + filter_backprop->tensor()); } template struct LaunchConvBackpropFilterOp { - static void launch(OpKernelContext* context, bool cudnn_use_autotune, - const Tensor& input, const Tensor& out_backprop, - const std::vector& dilation, - const std::vector& stride, const Padding& padding, - Tensor* filter_backprop, TensorFormat data_format) { - LaunchConvBackpropFilterOpImpl(context, cudnn_use_autotune, input, - out_backprop, dilation, stride, padding, - filter_backprop, data_format); - } + static void launch(OpKernelContext* context, bool cudnn_use_autotune, + const Tensor& input, const Tensor& out_backprop, + const std::vector& dilation, + const std::vector& stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + LaunchConvBackpropFilterOpImpl(context, cudnn_use_autotune, input, + out_backprop, dilation, stride, padding, + filter_backprop, data_format); + } }; template <> struct LaunchConvBackpropFilterOp { - static void launch(OpKernelContext* ctx, bool cudnn_use_autotune, - const Tensor& input, const Tensor& out_backprop, - const std::vector& dilation, - const std::vector& stride, const Padding& padding, - Tensor* filter_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are 
supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. - auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); - - if (cast_to_float) { + static void launch(OpKernelContext* ctx, bool cudnn_use_autotune, + const Tensor& input, const Tensor& out_backprop, + const std::vector& dilation, + const std::vector& stride, const Padding& padding, + Tensor* filter_backprop, TensorFormat data_format) { + auto* stream = ctx->op_device_context()->stream(); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); + if (cast_to_float) { Tensor casted_input = input; Tensor casted_out_backprop = out_backprop; Tensor casted_filter_backprop = *filter_backprop; @@ -1028,96 +1021,96 @@ struct LaunchConvBackpropFilterOp { cast_back(device, filter_backprop->template flat(), casted_filter_backprop_const.template flat()); return; - } - - LaunchConvBackpropFilterOpImpl( - ctx, cudnn_use_autotune, input, out_backprop, dilation, stride, - padding, filter_backprop, data_format); } + + LaunchConvBackpropFilterOpImpl( + ctx, cudnn_use_autotune, input, out_backprop, dilation, stride, padding, + filter_backprop, data_format); + } }; template class Conv3DBackpropFilterOp : public OpKernel { - public: - explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) - : OpKernel(context), - data_format_(FORMAT_NHWC), - takes_shape_(type_string().find("V2") != std::string::npos) { - // data_format is only available in V2. - if (takes_shape_) { + public: + explicit Conv3DBackpropFilterOp(OpKernelConstruction* context) + : OpKernel(context), + data_format_(FORMAT_NHWC), + takes_shape_(type_string().find("V2") != std::string::npos) { + // data_format is only available in V2. 
+ if (takes_shape_) { string data_format; OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format)); OP_REQUIRES(context, FormatFromString(data_format, &data_format_), errors::InvalidArgument("Invalid data format")); - } - OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); - OP_REQUIRES(context, dilation_.size() == 5, - errors::InvalidArgument("Dilation rates field must " - "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(dilation_, data_format_, 'C') == 1 && - GetTensorDim(dilation_, data_format_, 'N') == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "dilation rates in the batch and depth dimensions.")); - OP_REQUIRES( - context, - (GetTensorDim(dilation_, data_format_, '0') > 0 && - GetTensorDim(dilation_, data_format_, '1') > 0 && - GetTensorDim(dilation_, data_format_, '2') > 0), - errors::InvalidArgument("Dilated rates should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); - OP_REQUIRES(context, stride_.size() == 5, - errors::InvalidArgument("Sliding window strides field must " - "specify 5 dimensions")); - OP_REQUIRES(context, - (GetTensorDim(stride_, data_format_, 'C') == 1 && - GetTensorDim(stride_, data_format_, 'N') == 1), - errors::InvalidArgument( - "Current implementation does not yet support " - "strides in the batch and depth dimensions.")); - OP_REQUIRES( - context, - (GetTensorDim(stride_, data_format_, '0') > 0 && - GetTensorDim(stride_, data_format_, '1') > 0 && - GetTensorDim(stride_, data_format_, '2') > 0), - errors::InvalidArgument("Spatial strides should be larger than 0.")); - OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); - cudnn_use_autotune_ = CudnnUseAutotune(); } + OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilation_)); + OP_REQUIRES(context, dilation_.size() == 5, + errors::InvalidArgument("Dilation rates field must " + "specify 5 dimensions")); + OP_REQUIRES(context, + (GetTensorDim(dilation_, data_format_, 'C') == 1 && + GetTensorDim(dilation_, data_format_, 'N') == 1), + errors::InvalidArgument( + "Current implementation does not yet support " + "dilation rates in the batch and depth dimensions.")); + OP_REQUIRES( + context, + (GetTensorDim(dilation_, data_format_, '0') > 0 && + GetTensorDim(dilation_, data_format_, '1') > 0 && + GetTensorDim(dilation_, data_format_, '2') > 0), + errors::InvalidArgument("Dilated rates should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("strides", &stride_)); + OP_REQUIRES(context, stride_.size() == 5, + errors::InvalidArgument("Sliding window strides field must " + "specify 5 dimensions")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, 'C') == 1 && + GetTensorDim(stride_, data_format_, 'N') == 1), + errors::InvalidArgument("Current implementation does not yet support " + "strides in the batch and depth dimensions.")); + OP_REQUIRES( + context, + (GetTensorDim(stride_, data_format_, '0') > 0 && + GetTensorDim(stride_, data_format_, '1') > 0 && + GetTensorDim(stride_, data_format_, '2') > 0), + errors::InvalidArgument("Spatial strides should be larger than 0.")); + OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_)); + cudnn_use_autotune_ = CudnnUseAutotune(); + } - void Compute(OpKernelContext* context) override { - const Tensor& input = context->input(0); - const Tensor& out_backprop = context->input(2); + void Compute(OpKernelContext* context) override { + const Tensor& input = context->input(0); + const Tensor& 
out_backprop = context->input(2); - TensorShape filter_shape; - if (takes_shape_) { + TensorShape filter_shape; + if (takes_shape_) { const Tensor& filter_sizes = context->input(1); OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_sizes.shape()), errors::InvalidArgument( "filter_sizes shape must be rank 1 but is rank ", filter_sizes.shape().dims())); OP_REQUIRES_OK(context, tensor::MakeShape(filter_sizes, &filter_shape)); - } else { + } else { filter_shape = context->input(1).shape(); - } + } - Tensor* filter_backprop; - OP_REQUIRES_OK( - context, context->allocate_output(0, filter_shape, &filter_backprop)); + Tensor* filter_backprop; + OP_REQUIRES_OK(context, + context->allocate_output(0, filter_shape, &filter_backprop)); - LaunchConvBackpropFilterOp::launch( - context, cudnn_use_autotune_, input, out_backprop, dilation_, stride_, - padding_, filter_backprop, data_format_); - } + LaunchConvBackpropFilterOp::launch( + context, cudnn_use_autotune_, input, out_backprop, dilation_, stride_, + padding_, filter_backprop, data_format_); + } - private: - std::vector dilation_; - std::vector stride_; - Padding padding_; - TensorFormat data_format_; - bool takes_shape_; - bool cudnn_use_autotune_; + private: + std::vector dilation_; + std::vector stride_; + Padding padding_; + TensorFormat data_format_; + bool takes_shape_; + bool cudnn_use_autotune_; }; #define REGISTER_GPU_KERNEL(T) \ diff --git a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc index 8428835625ec36..2d5f9e313269a9 100644 --- a/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc +++ b/tensorflow/core/kernels/conv_grad_filter_ops_launcher.cc @@ -540,12 +540,8 @@ operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune, const Padding& padding, const std::vector& explicit_paddings, Tensor* filter_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); - + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input; Tensor casted_out_backprop = out_backprop; diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc index 909f1c901898a9..897bd14d97ae1d 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops.cc @@ -465,15 +465,8 @@ void LaunchConv2DBackpropInputOp::operator()( int col_dilation, int row_stride, int col_stride, const Padding& padding, const std::vector& explicit_paddings, Tensor* in_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. 
auto* stream = ctx->op_device_context()->stream(); -#if GOOGLE_CUDA - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); -#else - const bool cast_to_float = false; -#endif + const bool cast_to_float = IsBF16NotSupportedInOps(stream); Tensor casted_out_backprop = out_backprop; Tensor casted_filter = filter; Tensor casted_in_backprop = *in_backprop; diff --git a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc index a4704157a7b9f1..d1f273e06c11eb 100644 --- a/tensorflow/core/kernels/conv_grad_input_ops_3d.cc +++ b/tensorflow/core/kernels/conv_grad_input_ops_3d.cc @@ -1018,15 +1018,8 @@ struct LaunchConvBackpropInputOp { const std::vector& dilation, const std::vector& strides, const Padding& padding, Tensor* in_backprop, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); -#if GOOGLE_CUDA - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); -#else - const bool cast_to_float = false; -#endif + const bool cast_to_float = IsBF16NotSupportedInOps(stream); Tensor casted_out_backprop = out_backprop; Tensor casted_filter = filter; Tensor casted_in_backprop = *in_backprop; diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc index b1fb1e744af523..6272147cc38f1c 100644 --- a/tensorflow/core/kernels/conv_ops_3d.cc +++ b/tensorflow/core/kernels/conv_ops_3d.cc @@ -557,20 +557,12 @@ struct LaunchConvOp { const std::array& dilations, const std::array& strides, const Padding padding, TensorFormat data_format, Tensor* output) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = ctx->op_device_context()->stream(); -#if GOOGLE_CUDA - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); -#else - const bool cast_to_float = false; -#endif + const bool cast_to_float = IsBF16NotSupportedInOps(stream); Tensor casted_input = input_param; Tensor casted_filter = filter; Tensor casted_out = *output; - if (cast_to_float) { Tensor casted_input = input_param; Tensor casted_filter = filter; diff --git a/tensorflow/core/kernels/conv_ops_bfloat16.cc b/tensorflow/core/kernels/conv_ops_bfloat16.cc index 73eb3158153d42..d00e4449298183 100644 --- a/tensorflow/core/kernels/conv_ops_bfloat16.cc +++ b/tensorflow/core/kernels/conv_ops_bfloat16.cc @@ -91,12 +91,8 @@ void LaunchConv2DOp::operator()( int col_dilation, int row_stride, int col_stride, const Padding& padding, const std::vector& explicit_paddings, Tensor* output, TensorFormat data_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. 
auto* stream = ctx->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); - + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_input = input_param; Tensor casted_filter = filter; diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc index dce0e995be7581..83032fc0e47440 100644 --- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc +++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc @@ -149,11 +149,8 @@ void DnnPooling3dOp::Compute( const std::array& window, const std::array& stride, const std::array& padding, TensorFormat data_format, const Tensor& tensor_in, Tensor* output) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_in; Tensor casted_output; @@ -348,11 +345,8 @@ void DnnPooling3dGradOp::Compute( const std::array& output_size, TensorFormat data_format, const Tensor& out_backprop, const TensorShape& tensor_in_shape, const Tensor* tensor_in, const Tensor* tensor_out, Tensor* input_backprop) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_out_backprop; Tensor casted_tensor_in; diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc index b282855666b4ae..864f7ba603c6b5 100644 --- a/tensorflow/core/kernels/depthwise_conv_op.cc +++ b/tensorflow/core/kernels/depthwise_conv_op.cc @@ -58,16 +58,12 @@ typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; bool UseCudnnWith16BitFloat(OpKernelContext* ctx, DataType dtype) { -#if GOOGLE_CUDA if (dtype == DT_HALF) { return true; } else if (dtype == DT_BFLOAT16) { auto* stream = ctx->op_device_context()->stream(); - if (!stream) return false; - return stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + return !IsBF16NotSupportedInOps(stream); } -#endif return false; } diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc index 61c859855022e6..fd4d591d7ed0aa 100644 --- a/tensorflow/core/kernels/fused_batch_norm_op.cc +++ b/tensorflow/core/kernels/fused_batch_norm_op.cc @@ -37,6 +37,7 @@ limitations under the License. 
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/kernels/gpu_utils.h" #include "tensorflow/core/kernels/cast_op.h" #include "tensorflow/core/kernels/fill_functor.h" #include "tensorflow/core/kernels/fused_batch_norm_op.h" @@ -1059,11 +1060,8 @@ struct FusedBatchNorm { Tensor* batch_mean, Tensor* batch_var, Tensor* saved_mean, Tensor* saved_inv_var, TensorFormat tensor_format, bool use_reserved_space) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_x = x; Tensor casted_side_input; @@ -1327,11 +1325,8 @@ struct FusedBatchNormGrad { Tensor* x_backprop, Tensor* scale_backprop, Tensor* offset_backprop, Tensor* side_input_backprop, bool use_reserved_space, TensorFormat tensor_format) { - // Performant bfloat16 operations are supported for Ampere+ GPUs. For - // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_y_backprop = y_backprop; Tensor casted_x = x; diff --git a/tensorflow/core/kernels/gpu_utils.cc b/tensorflow/core/kernels/gpu_utils.cc index a48991fbfef9d7..1b45aa8828aad6 100644 --- a/tensorflow/core/kernels/gpu_utils.cc +++ b/tensorflow/core/kernels/gpu_utils.cc @@ -33,6 +33,21 @@ limitations under the License. namespace tensorflow { +bool IsBF16NotSupportedInOps(se::Stream* stream) { + if (!stream) { + return true; // no stream: don't know whether it's supported + } +#if GOOGLE_CUDA + // Performant bfloat16 operations are supported for Ampere+ GPUs. For + // pre-Ampere GPUs, we cast inputs to float and outputs back to bfloat16. 
+ return !stream->GetCudaComputeCapability().IsAtLeast( + se::CudaComputeCapability::AMPERE); +#elif TENSORFLOW_USE_ROCM + return true; // so far, we return true meaning that the conversion to float + // is needed +#endif +} + bool RedzoneCheckDisabled() { const char* disable_rz_str = std::getenv("TF_DISABLE_RZ_CHECK"); return disable_rz_str != nullptr && std::strcmp(disable_rz_str, "1") == 0; diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h index cba7aab1878406..dcb2c5bac1a43b 100644 --- a/tensorflow/core/kernels/gpu_utils.h +++ b/tensorflow/core/kernels/gpu_utils.h @@ -38,6 +38,10 @@ class RedzoneAllocator; namespace tensorflow { +// returns true if bfloat16 is not directly supported in Ops and inputs shall be +// casted to floats to perform the computations and then back +bool IsBF16NotSupportedInOps(se::Stream *stream); + class NodeDef; class AutotuneResult; diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc index b48287ae1442a4..d4a25b24276445 100644 --- a/tensorflow/core/kernels/pooling_ops_common.cc +++ b/tensorflow/core/kernels/pooling_ops_common.cc @@ -458,8 +458,7 @@ void DnnPoolingOp::Compute( context->allocate_output(0, tensor_out_shape, &tensor_out)); auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_tensor_in; Tensor casted_tensor_out; @@ -872,8 +871,7 @@ void DnnPoolingGradOp::Compute( OP_REQUIRES_OK(context, context->allocate_output(0, tensor_in_shape, &input_backprop)); auto* stream = context->op_device_context()->stream(); - const bool cast_to_float = !stream->GetCudaComputeCapability().IsAtLeast( - se::CudaComputeCapability::AMPERE); + const bool cast_to_float = IsBF16NotSupportedInOps(stream); if (cast_to_float) { Tensor casted_tensor_in; Tensor casted_tensor_out; diff --git a/tensorflow/core/kernels/topk_op_gpu.h b/tensorflow/core/kernels/topk_op_gpu.h index 278ebeb172f79c..219798a4b072e5 100644 --- a/tensorflow/core/kernels/topk_op_gpu.h +++ b/tensorflow/core/kernels/topk_op_gpu.h @@ -483,25 +483,16 @@ Status LaunchSortKernel(OpKernelContext* ctx, const T* input, int num_rows, bool ran_nonsegmented_version = false; if (num_rows == 1) { -#if GOOGLE_CUDA - constexpr bool is_supported = true; -#else - // GpuRadixSortDescending is not supported on ROCm for fp16/bf16. - constexpr bool is_supported = !std::is_same::value && - !std::is_same::value; -#endif - if constexpr (is_supported) { - // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because - // it only uses 1 SM per segment. Calling the un-segmented version is much - // faster in this case. - TF_RETURN_IF_ERROR( - GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input, - /*keys_out=*/sorted_values_ptr, - /*indices_in=*/input_indices_t.data(), - /*indices_out=*/sorted_indices_ptr, - /*num_bits=*/sizeof(T) * 8)); - ran_nonsegmented_version = true; - } + // Note: DeviceSegmentedRadixSort is very slow when num_segments=1 because + // it only uses 1 SM per segment. Calling the un-segmented version is much + // faster in this case. 
+ TF_RETURN_IF_ERROR( + GpuRadixSortDescending(ctx, num_cols, /*keys_in=*/input, + /*keys_out=*/sorted_values_ptr, + /*indices_in=*/input_indices_t.data(), + /*indices_out=*/sorted_indices_ptr, + /*num_bits=*/sizeof(T) * 8)); + ran_nonsegmented_version = true; } if (!ran_nonsegmented_version) { auto err = gpuprim::DeviceSegmentedRadixSort::SortPairsDescending( diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c4e3dc7a4a492c..056a1d8d161326 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -3759,7 +3759,6 @@ cuda_py_test( main = "ops/nn_fused_batchnorm_test.py", python_version = "PY3", shard_count = 24, - tags = ["no_rocm"], deps = [ ":array_ops", ":client_testlib", diff --git a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py index b62a21c469e12f..fff3284a588f0a 100644 --- a/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py +++ b/tensorflow/python/kernel_tests/nn_ops/conv_ops_test.py @@ -332,7 +332,7 @@ def _VerifyValues(self, gpu_only=False, test_grappler_layout_optimizer=False, tol=1e-5): - if gpu_only and not test.is_gpu_available(cuda_only=True): + if gpu_only and not test.is_gpu_available(): return tensors = [] dilations = list(dilations) @@ -840,8 +840,11 @@ def MakeConv2d(inputs, filters): results[0], results[1], atol=tol_to_use, rtol=tol_to_use) @test_util.run_in_graph_and_eager_modes + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message='MIOpen does not support group conv yet!') def testConv2DGroupConvFwd(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): data_formats = ["NHWC", "NCHW"] else: data_formats = ["NHWC"] @@ -857,7 +860,10 @@ def testConv2DGroupConvFwd(self): dtype=dtypes.float32) @test_util.deprecated_graph_mode_only - @test_util.run_cuda_only + @test_util.run_gpu_only + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message='MIOpen does not support group conv yet!') def testInputGradientGroupConv(self): for data_format in ["NCHW", "NHWC"]: for test_input in [True, False]: @@ -879,7 +885,10 @@ def testInputGradientGroupConv(self): max_err=0.005) @test_util.deprecated_graph_mode_only - @test_util.run_cuda_only + @test_util.run_gpu_only + @test.disable_with_predicate( + pred=test.is_built_with_rocm, + skip_message='MIOpen does not support group conv yet!') def testFilterGradientGroupConv(self): for data_format in ["NCHW", "NHWC"]: for test_input in [True, False]: @@ -917,7 +926,7 @@ def _RunAndVerifyBackpropInput(self, use_gpu, err, dilations=(1, 1)): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return x1 = self._CreateNumpyTensor(filter_sizes) x2 = self._CreateNumpyTensor(output_sizes) @@ -1403,7 +1412,7 @@ def _RunAndVerifyBackpropFilterDilation(self, input_sizes, filter_sizes, @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 6, 1], @@ -1418,7 +1427,7 @@ def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or 
test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -1433,7 +1442,7 @@ def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2DEmptyBackpropFilterDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 2, 3, 1], @@ -1448,7 +1457,7 @@ def testConv2DEmptyBackpropFilterDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 4, 3], @@ -1463,7 +1472,7 @@ def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self): @test_util.deprecated_graph_mode_only def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropFilterDilation( input_sizes=[1, 3, 3, 1], @@ -1478,7 +1487,7 @@ def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 6, 1], @@ -1493,7 +1502,7 @@ def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 2, 3, 1], @@ -1508,7 +1517,7 @@ def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2DEmptyBackpropInputDilation1x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[0, 2, 3, 1], @@ -1523,7 +1532,7 @@ def testConv2DEmptyBackpropInputDilation1x2(self): @test_util.deprecated_graph_mode_only def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): # The GPU version of this test is not very stable. So adjusting the # error threshold to 1e-4. 
@@ -1540,7 +1549,7 @@ def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self): @test_util.deprecated_graph_mode_only def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self): - if test.is_gpu_available(cuda_only=True) or test_util.IsMklEnabled(): + if test.is_gpu_available() or test_util.IsMklEnabled(): for (data_format, use_gpu) in GetTestConfigs(): self._RunAndVerifyBackpropInputDilation( input_sizes=[1, 3, 3, 1], @@ -1563,7 +1572,7 @@ def _RunAndVerifyBackpropInputExplicitPadding(self, use_gpu, dilations=(1, 1), err=2e-5): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return if not use_gpu and dilations != (1, 1): return # Non-default dilations is currently not supported on the CPU. @@ -1725,7 +1734,7 @@ def _RunAndVerifyBackpropFilterExplicitPadding(self, use_gpu, dilations=(1, 1), err=1e-5): - if use_gpu and not test.is_gpu_available(cuda_only=True): + if use_gpu and not test.is_gpu_available(): return if not use_gpu and dilations != (1, 1): return # Non-default dilations is currently not supported on the CPU. diff --git a/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py b/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py index a9f63ad6ce9a94..19be6a7d74a423 100644 --- a/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py +++ b/tensorflow/python/kernel_tests/nn_ops/depthwise_conv_op_base.py @@ -407,7 +407,7 @@ def _VerifyValues(self, interface_result, np_result, atol=tolerance, rtol=tolerance) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DCudnn(self): for index, (input_size, filter_size, _, stride, padding, dilations) in enumerate(ConfigsToTest()): @@ -510,8 +510,8 @@ def testDepthwiseConv2DExplicit(self): "Testing DepthwiseConv2D, %dth config: %r * %r, stride: %d, padding: " "%s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. - data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): data_types.extend([dtypes.float64, dtypes.bfloat16]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] @@ -736,7 +736,7 @@ def _ConstructAndTestGradient(self, self.assertLess(err, tolerance) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DInputGradCudnn(self): for index, (input_size, filter_size, output_size, stride, padding, dilations) in enumerate(CheckGradConfigsToTest()): @@ -832,8 +832,8 @@ def testDepthwiseConv2DInputGradExplicit(self): "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. 
- data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): data_types.extend([dtypes.float64, dtypes.bfloat16]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] @@ -852,7 +852,7 @@ def testDepthwiseConv2DInputGradExplicit(self): dilations=dilations) @test_util.run_v1_only("b/120545219") - @test_util.run_cuda_only + @test_util.run_gpu_only def testDepthwiseConv2DFilterGradCudnn(self): for index, (input_size, filter_size, output_size, stride, padding, dilations) in enumerate(CheckGradConfigsToTest()): @@ -945,8 +945,8 @@ def testDepthwiseConv2DFilterGradExplicit(self): "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) # double datatype is currently not supported for convolution ops - # on the ROCm platform and its support for bfloat16 is unknown. - data_types = [dtypes.float16, dtypes.float32] + # on the ROCm platform + data_types = [dtypes.float16, dtypes.float32, dtypes.bfloat16] if not test.is_built_with_rocm(): data_types.extend([dtypes.float64, dtypes.bfloat16]) data_formats = ["NHWC", "NCHW"] if test.is_gpu_available() else ["NHWC"] @@ -999,14 +999,13 @@ def testDepthwiseConv2DInputGradCompare(self): padding) self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropInput(input_size, filter_size, output_size, stride, - padding, "float64") self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + # So, we skip these tests. + if not test.is_built_with_rocm(): + self._CompareBackpropInput(input_size, filter_size, output_size, stride, + padding, "float64") @test_util.run_gpu_only def testDepthwiseConv2DInputGradExplicitCompare(self): @@ -1020,14 +1019,12 @@ def testDepthwiseConv2DInputGradExplicitCompare(self): padding) self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropInput(input_size, filter_size, output_size, stride, - padding, "float64") self._CompareBackpropInput(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareBackpropInput(input_size, filter_size, output_size, stride, + padding, "float64") def _CompareBackpropFilter(self, input_sizes, filter_sizes, output_sizes, stride, padding, dtype): @@ -1080,15 +1077,12 @@ def testDepthwiseConv2DFilterGradCompare(self): padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. 
+ if not test.is_built_with_rocm(): + self._CompareBackpropFilter(input_size, filter_size, output_size, stride, + padding, "float64") @test_util.run_gpu_only def testDepthwiseConv2DFilterGradExplicitCompare(self): @@ -1102,15 +1096,12 @@ def testDepthwiseConv2DFilterGradExplicitCompare(self): padding) self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareBackpropFilter(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareBackpropFilter(input_size, filter_size, output_size, stride, + padding, "float64") def _CompareForward(self, input_sizes, filter_sizes, output_sizes, stride, padding, dtype): @@ -1146,15 +1137,12 @@ def testDepthwiseConv2DForwardCompare(self): padding) self._CompareForward(input_size, filter_size, output_size, stride, padding, "float32") - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareForward(input_size, filter_size, output_size, stride, - padding, "float64") - self._CompareForward(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. + if not test.is_built_with_rocm(): + self._CompareForward(input_size, filter_size, output_size, stride, + padding, "float64") @test_util.run_gpu_only def testDepthwiseConv2DForwardExplicitCompare(self): @@ -1166,14 +1154,11 @@ def testDepthwiseConv2DForwardExplicitCompare(self): "Testing DepthwiseConv2DForwardCompare, %dth config: %r * %r, " "stride: %d, padding: %s", index, input_size, filter_size, stride, padding) - # Convolutions on the ROCm platform don't support double dtype. And its - # support for bf16 is unknown. So, we skip these tests. - if test.is_built_with_rocm(): - continue - self._CompareForward(input_size, filter_size, output_size, stride, - padding, "float64") self._CompareForward(input_size, filter_size, output_size, stride, padding, "float32") - self._CompareForward(input_size, filter_size, output_size, stride, padding, "bfloat16") + # Convolutions on the ROCm platform don't support double dtype. 
+      if not test.is_built_with_rocm():
+        self._CompareForward(input_size, filter_size, output_size, stride,
+                             padding, "float64")
diff --git a/tensorflow/python/ops/nn_fused_batchnorm_test.py b/tensorflow/python/ops/nn_fused_batchnorm_test.py
index c54969c144225c..94a0e8e50b33c8 100644
--- a/tensorflow/python/ops/nn_fused_batchnorm_test.py
+++ b/tensorflow/python/ops/nn_fused_batchnorm_test.py
@@ -406,7 +406,7 @@ def _runtests(self, x_shape, is_training, gradient_test=False,
     else:
       data_format_list = ['NCDHW', 'NDHWC']
     use_gpu_vals = [False]
-    if test.is_gpu_available(cuda_only=True) and not cpu_only:
+    if test.is_gpu_available() and not cpu_only:
       use_gpu_vals += [True]
     factors = [1.0, 0.6]
     for dtype in [np.float16, np.float32, dtypes.bfloat16.as_numpy_dtype]:
@@ -596,7 +596,7 @@ def _testBatchNormGradGrad(self, config):
       data_format_nhwc, features_nhwc = 'NDHWC', shape[4]
       data_format_nchw, features_nchw = 'NCDHW', shape[1]
     for is_training in [True, False]:
-      if test.is_gpu_available(cuda_only=True):
+      if test.is_gpu_available():
         self._test_grad_grad(
             shape,
            dtype, [features_nhwc],
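
For reviewers who want to exercise the bfloat16 paths this patch touches end to end, the following minimal Python sketch is not part of the patch; the shapes and the use of the Keras mixed_bfloat16 policy are illustrative assumptions, and on a ROCm/CUDA GPU build it routes through the Conv2D, MaxPool, TopK, and fused batch norm kernels changed here.

import tensorflow as tf

# Illustrative shapes only; bfloat16 inputs are the point.
x = tf.cast(tf.random.normal([2, 8, 8, 4]), tf.bfloat16)
filt = tf.cast(tf.random.normal([3, 3, 4, 8]), tf.bfloat16)

y = tf.nn.conv2d(x, filt, strides=1, padding="SAME")          # Conv2D in bf16
p = tf.nn.max_pool2d(y, ksize=2, strides=2, padding="VALID")  # MaxPool in bf16
vals, idx = tf.math.top_k(tf.reshape(p, [2, -1]), k=5)        # TopKV2 in bf16

# Fused batch norm with a bfloat16 compute dtype (variables stay float32).
tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")
bn = tf.keras.layers.BatchNormalization()
out = bn(tf.random.normal([2, 8, 8, 4]), training=True)
print(y.dtype, p.dtype, vals.dtype, out.dtype)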