Enable emission of float16/32 casts on x86 (#7837)
* Enable emission of float16/32 casts on x86

Fixes #7836
Fixes #4166

* Fix comment

* Don't catch bfloat casts

* Fix missing word in comment
abadams authored Sep 6, 2023
1 parent 02865e2 commit 836879e
Showing 2 changed files with 35 additions and 27 deletions.
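For context, a minimal Halide pipeline that exercises the new path might look like the sketch below (not part of the commit; the schedule, output file name, and target string are illustrative assumptions):

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    // Round-trip a float32 buffer through float16. With this change, on a
    // target with the F16C feature these casts lower to vcvtps2ph /
    // vcvtph2ps instead of falling back to scalar conversion calls.
    ImageParam in(Float(32), 1);
    Var x("x");
    Func f("f");
    f(x) = cast(Float(32), cast(Float(16), in(x)));
    f.vectorize(x, 8);

    // Illustrative target string; any x86-64 target with the f16c feature applies.
    Target t("x86-64-linux-sse41-avx-f16c");
    f.compile_to_assembly("float16_casts.s", {in}, t);
    return 0;
}
```

Inspecting the generated float16_casts.s for vcvtps2ph and vcvtph2ps is one way to confirm the new lowering by hand, mirroring what the updated simd_op_check test automates.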
43 changes: 21 additions & 22 deletions src/CodeGen_X86.cpp
@@ -67,8 +67,6 @@ class CodeGen_X86 : public CodeGen_Posix {

int vector_lanes_for_slice(const Type &t) const;

- llvm::Type *llvm_type_of(const Type &t) const override;

using CodeGen_Posix::visit;

void init_module() override;
@@ -488,9 +486,25 @@ void CodeGen_X86::visit(const Select *op) {
}

void CodeGen_X86::visit(const Cast *op) {
+ Type src = op->value.type();
+ Type dst = op->type;

+ if (target.has_feature(Target::F16C) &&
+ dst.code() == Type::Float &&
+ src.code() == Type::Float &&
+ (dst.bits() == 16 || src.bits() == 16)) {
+ // Note that we use code() == Type::Float instead of is_float(), because we
+ // don't want to catch bfloat casts.

+ // This target doesn't support full float16 arithmetic, but it *does*
+ // support float16 casts, so we emit a vanilla LLVM cast node.
+ value = codegen(op->value);
+ value = builder->CreateFPCast(value, llvm_type_of(dst));
+ return;
+ }

- if (!op->type.is_vector()) {
- // We only have peephole optimizations for vectors in here.
+ if (!dst.is_vector()) {
+ // We only have peephole optimizations for vectors after this point.
CodeGen_Posix::visit(op);
return;
}
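The F16C branch above asks LLVM for an ordinary floating-point cast rather than a target intrinsic. As a rough standalone illustration of what builder->CreateFPCast produces (a sketch against LLVM's C++ API, assuming a recent LLVM; not code from this commit), fptrunc is emitted when narrowing and fpext when widening:

```cpp
#include <llvm/IR/DerivedTypes.h>
#include <llvm/IR/IRBuilder.h>
#include <llvm/IR/LLVMContext.h>
#include <llvm/IR/Module.h>
#include <llvm/IR/Verifier.h>
#include <llvm/Support/raw_ostream.h>

int main() {
    llvm::LLVMContext ctx;
    llvm::Module mod("f16_cast_demo", ctx);
    llvm::IRBuilder<> b(ctx);

    // A function narrowing <8 x float> to <8 x half>, the vector shape the
    // vcvtps2ph test below checks for.
    auto *f32x8 = llvm::FixedVectorType::get(llvm::Type::getFloatTy(ctx), 8);
    auto *f16x8 = llvm::FixedVectorType::get(llvm::Type::getHalfTy(ctx), 8);
    auto *fty = llvm::FunctionType::get(f16x8, {f32x8}, false);
    auto *fn = llvm::Function::Create(fty, llvm::Function::ExternalLinkage,
                                      "narrow_to_half", mod);
    b.SetInsertPoint(llvm::BasicBlock::Create(ctx, "entry", fn));

    // CreateFPCast picks fptrunc here because the destination is narrower;
    // the reverse direction would get fpext.
    b.CreateRet(b.CreateFPCast(fn->getArg(0), f16x8));

    llvm::verifyFunction(*fn, &llvm::errs());
    mod.print(llvm::outs(), nullptr);
    return 0;
}
```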
@@ -513,20 +527,20 @@ void CodeGen_X86::visit(const Cast *op) {
vector<Expr> matches;
for (const Pattern &p : patterns) {
if (expr_match(p.pattern, op, matches)) {
- value = call_overloaded_intrin(op->type, p.intrin, matches);
+ value = call_overloaded_intrin(dst, p.intrin, matches);
if (value) {
return;
}
}
}

if (const Call *mul = Call::as_intrinsic(op->value, {Call::widening_mul})) {
- if (op->value.type().bits() < op->type.bits() && op->type.bits() <= 32) {
+ if (src.bits() < dst.bits() && dst.bits() <= 32) {
// LLVM/x86 really doesn't like 8 -> 16 bit multiplication. If we're
// widening to 32-bits after a widening multiply, LLVM prefers to see a
// widening multiply directly to 32-bits. This may result in extra
// casts, so simplify to remove them.
- value = codegen(simplify(Mul::make(Cast::make(op->type, mul->args[0]), Cast::make(op->type, mul->args[1]))));
+ value = codegen(simplify(Mul::make(Cast::make(dst, mul->args[0]), Cast::make(dst, mul->args[1]))));
return;
}
}
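To show the kind of expression the widening_mul peephole above targets (a hedged sketch, not part of the commit; names, schedule, and target string are made up for illustration), an 8-bit product widened all the way to 32 bits goes through this path:

```cpp
#include "Halide.h"
using namespace Halide;

int main() {
    // Multiply two int8 inputs at 16 bits, then widen the product to int32.
    // Halide's FindIntrinsics pass rewrites the inner multiply as
    // widening_mul(i8, i8); the peephole above then re-expresses the outer
    // cast as a direct 8 -> 32 bit widening multiply, which x86 codegen
    // handles better than a chained 8 -> 16 -> 32 widening.
    ImageParam a(Int(8), 1), b(Int(8), 1);
    Var x("x");
    Func g("g");
    g(x) = cast(Int(32), cast(Int(16), a(x)) * cast(Int(16), b(x)));
    g.vectorize(x, 16);
    g.compile_to_assembly("widening_mul.s", {a, b},
                          Target("x86-64-linux-sse41-avx-avx2"));
    return 0;
}
```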
@@ -997,21 +1011,6 @@ int CodeGen_X86::vector_lanes_for_slice(const Type &t) const {
return slice_bits / t.bits();
}

- llvm::Type *CodeGen_X86::llvm_type_of(const Type &t) const {
- if (t.is_float() && t.bits() < 32) {
- // LLVM as of August 2019 has all sorts of issues in the x86
- // backend for half types. It injects expensive calls to
- // convert between float and half for seemingly no reason
- // (e.g. to do a select), and bitcasting to int16 doesn't
- // help, because it simplifies away the bitcast for you.
- // See: https://bugs.llvm.org/show_bug.cgi?id=43065
- // and: https://github.com/halide/Halide/issues/4166
- return llvm_type_of(t.with_code(halide_type_uint));
- } else {
- return CodeGen_Posix::llvm_type_of(t);
- }
- }

} // namespace

std::unique_ptr<CodeGen_Posix> new_CodeGen_X86(const Target &target) {
19 changes: 14 additions & 5 deletions test/correctness/simd_op_check_x86.cpp
@@ -45,6 +45,7 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
void check_sse_and_avx() {
Expr f64_1 = in_f64(x), f64_2 = in_f64(x + 16), f64_3 = in_f64(x + 32);
Expr f32_1 = in_f32(x), f32_2 = in_f32(x + 16), f32_3 = in_f32(x + 32);
+ Expr f16_1 = in_f16(x), f16_2 = in_f16(x + 16), f16_3 = in_f16(x + 32);
Expr i8_1 = in_i8(x), i8_2 = in_i8(x + 16), i8_3 = in_i8(x + 32);
Expr u8_1 = in_u8(x), u8_2 = in_u8(x + 16), u8_3 = in_u8(x + 32);
Expr i16_1 = in_i16(x), i16_2 = in_i16(x + 16), i16_3 = in_i16(x + 32);
@@ -496,6 +497,11 @@ class SimdOpCheckX86 : public SimdOpCheckTest {
check_x86_fixed_point("zmm", 2);
}

+ if (target.has_feature(Target::F16C)) {
+ check("vcvtps2ph", 8, cast(Float(16), f32_1));
+ check("vcvtph2ps", 8, cast(Float(32), f16_1));
+ }

check(use_avx512 ? "vpaddq*zmm" : "vpaddq*ymm", 8, i64_1 + i64_2);
check(use_avx512 ? "vpsubq*zmm" : "vpsubq*ymm", 8, i64_1 - i64_2);
check(use_avx512 ? "vpmullq" : "vpmuludq*ymm", 8, u64_1 * u64_2);
@@ -638,14 +644,17 @@ int main(int argc, char **argv) {
{
Target("x86-32-linux"),
Target("x86-32-linux-sse41"),
Target("x86-64-linux-sse41-avx"),
Target("x86-64-linux-sse41-avx-avx2"),
// Always turn on f16c when using avx. Sandy Bridge had avx without
// f16c, but f16c is orthogonal to everything else, so there's no
// real reason to test avx without it.
Target("x86-64-linux-sse41-avx-f16c"),
Target("x86-64-linux-sse41-avx-f16c-avx2"),
// See above: don't test avx512 without extra features, the test
// isn't yet set up to test it properly.
// Target("x86-64-linux-sse41-avx-avx2-avx512"),
// Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_knl"),
Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake"),
Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake-avx512_cannonlake"),
Target("x86-64-linux-sse41-avx-avx2-avx512-avx512_skylake-avx512_cannonlake-avx512_sapphirerapids"),
Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake"),
Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake-avx512_cannonlake"),
Target("x86-64-linux-sse41-avx-f16c-avx2-avx512-avx512_skylake-avx512_cannonlake-avx512_sapphirerapids"),
});
}
