From 9130d919aac9e975e839a1f1992ad50dcc796e04 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 5 Oct 2024 16:45:03 +0000 Subject: [PATCH 1/3] Refactor data_type.h and c_runtime_api.h This commit refactors the `data_type.h` and `c_runtime_api.h` files. It introduces a new function `is_vector()` in the `DataType` class to check if a type is a vector type. Additionally, it adds a new constant `kTVMGridConstant` in the `TVMTypeCode` enum in `c_runtime_api.h`. These changes improve the code organization and provide better support for vector types. --- include/tvm/runtime/c_runtime_api.h | 1 + include/tvm/runtime/data_type.h | 2 ++ include/tvm/topi/elemwise.h | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/intrin_rule_hexagon.cc | 8 ++++---- src/tir/analysis/verify_gpu_code.cc | 8 ++++---- 6 files changed, 13 insertions(+), 10 deletions(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index d26c95e4f53c..f2faae01d305 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -195,6 +195,7 @@ typedef enum { kTVMExtBegin = 16U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, + kTVMGridConstant = 30U, // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h index a330ccbbdf65..c49fde1746bc 100644 --- a/include/tvm/runtime/data_type.h +++ b/include/tvm/runtime/data_type.h @@ -148,6 +148,8 @@ class DataType { bool is_fixed_length_vector() const { return static_cast(data_.lanes) > 1; } /*! \return Whether the type is a scalable vector. */ bool is_scalable_vector() const { return static_cast(data_.lanes) < -1; } + /*! \return whether type is a vector type. */ + bool is_vector() const { return lanes() > 1; } /*! \return whether type is a bool vector type. */ bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; } /*! \return whether type is a Void type. */ diff --git a/include/tvm/topi/elemwise.h b/include/tvm/topi/elemwise.h index 132992c57dc7..806ddcb662f9 100644 --- a/include/tvm/topi/elemwise.h +++ b/include/tvm/topi/elemwise.h @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast", if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) { if (expr.dtype().lanes() == type.lanes()) { return expr; - } else if (expr.dtype().lanes() == 1 && type.lanes() > 1) { + } else if (expr.dtype().lanes() == 1 && type.is_vector()) { return tvm::tir::Broadcast(expr, type.lanes()); } } diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index e21436e556ee..3d6d3a9461d3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper( if (const RampNode* ramp = last_index.as()) { PrimExpr offset = ramp->base + (ramp->stride * i); last_index_value = MakeValue(offset); - } else if (last_index.dtype().lanes() > 1) { + } else if (last_index.dtype().is_vector()) { if (i == 0) { cached_vector_index = MakeValue(last_index); } diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index 7c4b38c1d702..a5a081cf32d9 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr( } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid") const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) { + if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index f012f8a1b35e..16ebd0ad34b2 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - if (op->dtype.lanes() > 1) { + if (op->dtypeis_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtypeis_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtype.lanes() > 1) { + if (op->dtypeis_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtype.lanes() > 1) { + if (op->value->dtypeis_vector()) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; From d9670f90a3cc8c4e292afba99fd81b452f9afdb9 Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 5 Oct 2024 17:02:20 +0000 Subject: [PATCH 2/3] revert kTVMGridConstant --- include/tvm/runtime/c_runtime_api.h | 1 - 1 file changed, 1 deletion(-) diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f2faae01d305..d26c95e4f53c 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -195,7 +195,6 @@ typedef enum { kTVMExtBegin = 16U, kTVMNNVMFirst = 16U, kTVMNNVMLast = 20U, - kTVMGridConstant = 30U, // The following section of code is used for non-reserved types. kTVMExtReserveEnd = 64U, kTVMExtEnd = 128U, From 6523e81de0c3165daca2c4a382393ef620ea3b8d Mon Sep 17 00:00:00 2001 From: LeiWang1999 Date: Sat, 5 Oct 2024 17:04:37 +0000 Subject: [PATCH 3/3] lint fix --- src/target/llvm/intrin_rule_hexagon.cc | 8 ++++---- src/tir/analysis/verify_gpu_code.cc | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/target/llvm/intrin_rule_hexagon.cc b/src/target/llvm/intrin_rule_hexagon.cc index a5a081cf32d9..2661f2fa6591 100644 --- a/src/target/llvm/intrin_rule_hexagon.cc +++ b/src/target/llvm/intrin_rule_hexagon.cc @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) { // Enable QHL library for FP16 data type const PrimExpr& x = call->args[0]; - if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { return TVMExternCall(call, tvm_wrapper); } #endif @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh") } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr( } // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf"); return TVMExternCall(call, tvm_wrapper); } @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid") const tir::Call new_call = tir::Call(call->dtype, call->op, new_args); // Enable QHL library for FP16 data type - if (x->dtype.is_float16() && x->dtypeis_vector() && useqhl) { + if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) { std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf"); return TVMExternCall(new_call.get(), tvm_wrapper); } diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index 16ebd0ad34b2..8eda537579e7 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor { size_t size = static_cast(op->ConstantAllocationSize()); shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes(); } - if (op->dtypeis_vector()) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const CastNode* op) { - if (op->dtypeis_vector()) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitExpr_(const BufferLoadNode* op) { - if (op->dtypeis_vector()) { + if (op->dtype.is_vector()) { if (static_cast(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) { std::stringstream s; s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes (" @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor { } void VisitStmt_(const BufferStoreNode* op) { - if (op->value->dtypeis_vector()) { + if (op->value->dtype.is_vector()) { if (static_cast(op->value->dtype.lanes() * op->value->dtype.bytes()) > max_vector_bytes_) { std::stringstream s;