From 836879eee270f20c6096f9171fcda6cc771bd26b Mon Sep 17 00:00:00 2001
From: Andrew Adams
Date: Wed, 6 Sep 2023 14:29:27 -0700
Subject: [PATCH] Enable emission of float16/32 casts on x86 (#7837)

* Enable emission of float16/32 casts on x86

Fixes #7836
Fixes #4166

* Fix comment

* Don't catch bfloat casts

* Fix missing word in comment
---
 src/CodeGen_X86.cpp                    | 43 +++++++++++++-------------
 test/correctness/simd_op_check_x86.cpp | 19 +++++++++---
 2 files changed, 35 insertions(+), 27 deletions(-)

diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp
index 45d4e224e277..c08e8064bded 100644
--- a/src/CodeGen_X86.cpp
+++ b/src/CodeGen_X86.cpp
@@ -67,8 +67,6 @@ class CodeGen_X86 : public CodeGen_Posix {
 
     int vector_lanes_for_slice(const Type &t) const;
 
-    llvm::Type *llvm_type_of(const Type &t) const override;
-
     using CodeGen_Posix::visit;
 
     void init_module() override;
@@ -488,9 +486,25 @@ void CodeGen_X86::visit(const Select *op) {
 }
 
 void CodeGen_X86::visit(const Cast *op) {
+    Type src = op->value.type();
+    Type dst = op->type;
+
+    if (target.has_feature(Target::F16C) &&
+        dst.code() == Type::Float &&
+        src.code() == Type::Float &&
+        (dst.bits() == 16 || src.bits() == 16)) {
+        // Note that we use code() == Type::Float instead of is_float(), because we
+        // don't want to catch bfloat casts.
+
+        // This target doesn't support full float16 arithmetic, but it *does*
+        // support float16 casts, so we emit a vanilla LLVM cast node.
+        value = codegen(op->value);
+        value = builder->CreateFPCast(value, llvm_type_of(dst));
+        return;
+    }
 
-    if (!op->type.is_vector()) {
-        // We only have peephole optimizations for vectors in here.
+    if (!dst.is_vector()) {
+        // We only have peephole optimizations for vectors after this point.
         CodeGen_Posix::visit(op);
         return;
     }
@@ -513,7 +527,7 @@ void CodeGen_X86::visit(const Cast *op) {
     vector<Expr> matches;
     for (const Pattern &p : patterns) {
         if (expr_match(p.pattern, op, matches)) {
-            value = call_overloaded_intrin(op->type, p.intrin, matches);
+            value = call_overloaded_intrin(dst, p.intrin, matches);
             if (value) {
                 return;
             }
@@ -521,12 +535,12 @@ void CodeGen_X86::visit(const Cast *op) {
     }
 
     if (const Call *mul = Call::as_intrinsic(op->value, {Call::widening_mul})) {
-        if (op->value.type().bits() < op->type.bits() && op->type.bits() <= 32) {
+        if (src.bits() < dst.bits() && dst.bits() <= 32) {
             // LLVM/x86 really doesn't like 8 -> 16 bit multiplication. If we're
             // widening to 32-bits after a widening multiply, LLVM prefers to see a
             // widening multiply directly to 32-bits. This may result in extra
             // casts, so simplify to remove them.
-            value = codegen(simplify(Mul::make(Cast::make(op->type, mul->args[0]), Cast::make(op->type, mul->args[1]))));
+            value = codegen(simplify(Mul::make(Cast::make(dst, mul->args[0]), Cast::make(dst, mul->args[1]))));
             return;
         }
     }
@@ -997,21 +1011,6 @@ int CodeGen_X86::vector_lanes_for_slice(const Type &t) const {
     return slice_bits / t.bits();
 }
 
-llvm::Type *CodeGen_X86::llvm_type_of(const Type &t) const {
-    if (t.is_float() && t.bits() < 32) {
-        // LLVM as of August 2019 has all sorts of issues in the x86
-        // backend for half types. It injects expensive calls to
-        // convert between float and half for seemingly no reason
-        // (e.g. to do a select), and bitcasting to int16 doesn't
-        // help, because it simplifies away the bitcast for you.
-        // See: https://bugs.llvm.org/show_bug.cgi?id=43065
-        // and: https://github.com/halide/Halide/issues/4166
-        return llvm_type_of(t.with_code(halide_type_uint));
-    } else {
-        return CodeGen_Posix::llvm_type_of(t);
-    }
-}
-
 } // namespace
 
 std::unique_ptr<CodeGen_Posix> new_CodeGen_X86(const Target &target) {
diff --git a/test/correctness/simd_op_check_x86.cpp b/test/correctness/simd_op_check_x86.cpp
index f86134d37630..aa18c9685de7 100644
--- a/test/correctness/simd_op_check_x86.cpp
+++ b/test/correctness/simd_op_check_x86.cpp
@@ -45,6 +45,7 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
     void check_sse_and_avx() {
         Expr f64_1 = in_f64(x), f64_2 = in_f64(x + 16), f64_3 = in_f64(x + 32);
         Expr f32_1 = in_f32(x), f32_2 = in_f32(x + 16), f32_3 = in_f32(x + 32);
+        Expr f16_1 = in_f16(x), f16_2 = in_f16(x + 16), f16_3 = in_f16(x + 32);
         Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32);
         Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32);
         Expr i16_1 = in_i16(x), i16_2 = in_i16(x + 16), i16_3 = in_i16(x + 32);
@@ -496,6 +497,11 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
             check_x86_fixed_point("zmm", 2);
         }
 
+        if (target.has_feature(Target::F16C)) {
+            check("vcvtps2ph", 8, cast(Float(16), f32_1));
+            check("vcvtph2ps", 8, cast(Float(32), f16_1));
+        }
+
         check(use_avx512 ? "vpaddq*zmm" : "vpaddq*ymm", 8, i64_1 + i64_2);
         check(use_avx512 ? "vpsubq*zmm" : "vpsubq*ymm", 8, i64_1 - i64_2);
         check(use_avx512 ? "vpmullq" : "vpmuludq*ymm", 8, u64_1 * u64_2);
@@ -638,14 +644,17 @@ int main(int argc, char **argv) {
         {
             Target("x86-32-linux"),
             Target("x86-32-linux-sse41"),
-            Target("x86-64-linux-sse41-avx"),
-            Target("x86-64-linux-sse41-avx-avx2"),
+            // Always turn on f16c when using avx. Sandy Bridge had avx without
+            // f16c, but f16c is orthogonal to everything else, so there's no
+            // real reason to test avx without it.
+            Target("x86-64-linux-sse41-avx-f16c"),
+            Target("x86-64-linux-sse41-avx-f16c-avx2"),
             // See above: don't test avx512 without extra features, the test
             // isn't yet set up to test it properly.
             // Target("x86-64-linux-sse41-avx-avx2-avx512"),
             // Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_knl"),
-            Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake"),
-            Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake-avx512_cannonlake"),
-            Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake-avx512_cannonlake-avx512_sapphirerapids"),
+            Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake"),
+            Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake-avx512_cannonlake"),
+            Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake-avx512_cannonlake-avx512_sapphirerapids"),
         });
 }
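
For reference, here is a minimal sketch (not part of the patch; the pipeline, variable names, and output filename are illustrative only) of the kind of Halide program this change affects. With the f16c feature enabled, the float16 -> float32 cast should now lower to a plain LLVM cast and compile to vcvtph2ps, with the narrowing direction using vcvtps2ph, which is exactly what the new simd_op_check entries above assert.

    #include "Halide.h"
    using namespace Halide;

    int main() {
        // A 1-D float16 input buffer (hypothetical example, not from the test suite).
        ImageParam input(Float(16), 1, "input");
        Func widened("widened");
        Var x("x");

        // Widen float16 -> float32 across an 8-wide vector.
        widened(x) = cast<float>(input(x));
        widened.vectorize(x, 8);

        // Hypothetical target string; the key part is the f16c feature flag.
        Target t("x86-64-linux-sse41-avx-f16c");
        widened.compile_to_assembly("widened.s", {input}, t);
        return 0;
    }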