Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR] Add is_vector Method to DataType class and update usages across Codebase #17443

Merged
merged 3 commits into from
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ class DataType {
bool is_fixed_length_vector() const { return static_cast<int16_t>(data_.lanes) > 1; }
/*! \return Whether the type is a scalable vector. */
bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) < -1; }
/*! \return whether type is a vector type. */
bool is_vector() const { return lanes() > 1; }
/*! \return whether type is a bool vector type. */
bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() && bits() == 1; }
/*! \return whether type is a Void type. */
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/topi/elemwise.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ inline Tensor cast(const Tensor& x, DataType type, std::string name = "T_cast",
if (expr.dtype().code() == type.code() && expr.dtype().bits() == type.bits()) {
if (expr.dtype().lanes() == type.lanes()) {
return expr;
} else if (expr.dtype().lanes() == 1 && type.lanes() > 1) {
} else if (expr.dtype().lanes() == 1 && type.is_vector()) {
return tvm::tir::Broadcast(expr, type.lanes());
}
}
Expand Down
2 changes: 1 addition & 1 deletion src/target/llvm/codegen_llvm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1737,7 +1737,7 @@ void CodeGenLLVM::BufferAccessHelper(
if (const RampNode* ramp = last_index.as<RampNode>()) {
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
} else if (last_index.dtype().is_vector()) {
if (i == 0) {
cached_vector_index = MakeValue(last_index);
}
Expand Down
8 changes: 4 additions & 4 deletions src/target/llvm/intrin_rule_hexagon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ inline PrimExpr DispatchTVMQHLWrapperFp16(const PrimExpr& e) {

// Enable QHL library for FP16 data type
const PrimExpr& x = call->args[0];
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
return TVMExternCall(call, tvm_wrapper);
}
#endif
Expand Down Expand Up @@ -116,7 +116,7 @@ TVM_REGISTER_OP("tir.tanh")
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tanh_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -152,7 +152,7 @@ TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>(
}

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_tan_ahf");
return TVMExternCall(call, tvm_wrapper);
}
Expand Down Expand Up @@ -191,7 +191,7 @@ TVM_REGISTER_OP("tir.sigmoid")
const tir::Call new_call = tir::Call(call->dtype, call->op, new_args);

// Enable QHL library for FP16 data type
if (x->dtype.is_float16() && x->dtype.lanes() > 1 && useqhl) {
if (x->dtype.is_float16() && x->dtype.is_vector() && useqhl) {
std::string tvm_wrapper("tvm_vect_qhmath_hvx_sigmoid_ahf");
return TVMExternCall(new_call.get(), tvm_wrapper);
}
Expand Down
8 changes: 4 additions & 4 deletions src/tir/analysis/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
size_t size = static_cast<size_t>(op->ConstantAllocationSize());
shared_memory_per_block_ += size * op->dtype.bytes() * op->dtype.lanes();
}
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand Down Expand Up @@ -202,7 +202,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const CastNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -215,7 +215,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitExpr_(const BufferLoadNode* op) {
if (op->dtype.lanes() > 1) {
if (op->dtype.is_vector()) {
if (static_cast<size_t>(op->dtype.lanes() * op->dtype.bytes()) > max_vector_bytes_) {
std::stringstream s;
s << "Number of lanes (" << op->dtype.lanes() << ") times number of bytes ("
Expand All @@ -229,7 +229,7 @@ class GPUCodeVerifier : public StmtExprVisitor {
}

void VisitStmt_(const BufferStoreNode* op) {
if (op->value->dtype.lanes() > 1) {
if (op->value->dtype.is_vector()) {
if (static_cast<size_t>(op->value->dtype.lanes() * op->value->dtype.bytes()) >
max_vector_bytes_) {
std::stringstream s;
Expand Down
Loading