Skip to content

Commit

Permalink
[TIR][Schedule] Add annotate_buffer_access primitive
Browse files Browse the repository at this point in the history
  • Loading branch information
qsqqsqqsq-intellif committed Oct 12, 2024
1 parent ab64835 commit be963f6
Show file tree
Hide file tree
Showing 12 changed files with 718 additions and 6 deletions.
11 changes: 11 additions & 0 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,6 +834,17 @@ class ScheduleNode : public runtime::Object {
*/
virtual void RollingBuffer(const BlockRV& block_rv, int write_buffer_index) = 0;

/*!
* \brief Annotate the buffer access of a block
* \param block_rv The block to be annotated
* \param buffer_index The index of the buffer in block's read or write region
* \param buffer_index_type The type of the buffer index, kRead or kWrite.
* \param index_map The index map that defines the new read or write region
*/
virtual void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map) = 0;

/******** Schedule: Misc ********/
/*! \brief A no-op that marks the start of postprocessing phase of scheduling */
virtual void EnterPostproc() = 0;
Expand Down
10 changes: 10 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,16 @@ constexpr const char* warp_execution = "warp_execution";
/*! \brief Mark that a block is disallowed in auto inline. */
constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";

/*! \brief Mark that a block has an explicitly specified read region.
* This is used to override the default read region inference in TIR.
*/
constexpr const char* explicit_read_region = "explicit_read_region";

/*! \brief Mark that a block has an explicitly specified write region.
* This is used to override the default write region inference in TIR.
*/
constexpr const char* explicit_write_region = "explicit_write_region";

/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
Expand Down
136 changes: 136 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -3907,3 +3907,139 @@ def unsafe_hide_buffer_access(
buf_type,
buf_index_array,
)

@type_checked
def annotate_buffer_access(
self, block: BlockRV, buffer_index: int, buf_type: str, gen_new_ranges: Callable
) -> None:
"""Annotate the read or write region of a block
Parameters
----------
block : BlockRV
The block to be annotated
buffer_index : int
The index of the buffer in block's read or write region
buf_type : str
The buffer type: "read" or "write"
gen_new_ranges : Callable
A function that takes the block's iter_vars and returns a
Tuple[Union[PrimExpr, Tuple[PrimExpr, PrimExpr]], ...]
which defines the new read or write region for the buffer.
Each element in the tuple can be:
- A single PrimExpr representing the iter_var itself
- A tuple of two PrimExprs representing the range (begin, end)
Examples
--------
Annotate a 2D read region for a buffer.
Before annotate_buffer_access, in TensorIR, the IR is:
.. code-block:: python
@T.prim_func
def before_annotate_buffer_access(
A: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
Create the schedule and do annotate_buffer_access:
.. code-block:: python
sch = tir.Schedule(before_annotate_buffer_access)
block = sch.get_block("B")
sch.annotate_buffer_access(block, 0, "read",
lambda vi, vj: ((vi - 1, vi + 1), (vj - 1, vj + 1)))
print(sch.mod["main"].script())
After applying annotate_buffer_access, the IR becomes:
.. code-block:: python
@T.prim_func
def after_annotate_buffer_access(
A: T.Buffer((128, 128), "float32"),
C: T.Buffer((128, 128), "float32")
) -> None:
B = T.alloc_buffer((128, 128), "float32")
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.reads(A[vi - 1:vi + 1, vj - 1:vj + 1])
T.writes(B[vi, vj])
T.block_attr({"explicit_read_region": 0})
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(128, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
C[vi, vj] = B[vi, vj] + 1.0
This annotates the read region for buffer A (index 0) in block "B" to be
[vi-1:vi+1, vj-1:vj+1] for each (vi, vj) in the block's iteration domain.
Note
----
This function allows manual specification of read or write regions, which
can be useful in cases where the compiler cannot accurately infer the
access pattern, such as complex data-dependent accesses.
It overrides the automatically inferred region for the specified buffer.
The function adds an annotation to the block, indicating that an explicit
region has been provided for the buffer at the given index. This annotation
is used in the CompactBufferAllocation pass to respect the manually specified
region instead of relying on automatic inference.
Caution should be exercised when using this function, as incorrect annotations
may lead to incorrect code generation or runtime errors. It's crucial to
ensure that the specified region covers all actual reads or writes performed
by the block for the given buffer.
"""
block_obj = self.get(block)
iter_vars = [x.var for x in block_obj.iter_vars]
new_ranges_spec = gen_new_ranges(*iter_vars)
if len(iter_vars) != len(new_ranges_spec):
raise ValueError(
f"Number of iter_vars ({len(iter_vars)}) must match "
f"number of new_ranges_spec ({len(new_ranges_spec)})"
)

result = []
for rng in new_ranges_spec:
if isinstance(rng, (tuple, list)):
if len(rng) != 2:
raise ValueError(
f"Tuple must have exactly 2 elements to represent (begin, end)."
)
result.extend(rng)
elif isinstance(rng, PrimExpr):
result.extend([rng, rng + 1]) # Single point represented as (rng, rng + 1)
else:
raise TypeError(f"Expected PrimExpr or tuple of PrimExpr, got {type(rng)}")

# Create index_map using IndexMap constructor
index_map = IndexMap(
initial_indices=iter_vars,
final_indices=result,
inverse_index_map=None,
)

if buf_type == "read":
buffer_index_type = 0
elif buf_type == "write":
buffer_index_type = 1
else:
raise ValueError(f"Invalid buf_type: {buf_type}. Expected 'read' or 'write'.")

return _ffi_api.ScheduleAnnotateBufferAccess(
self, block, buffer_index, buffer_index_type, index_map
)
10 changes: 10 additions & 0 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1059,5 +1059,15 @@ void ConcreteScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const
this->state_->DebugVerify();
}

void ConcreteScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map) {
TVM_TIR_SCHEDULE_BEGIN();
tir::AnnotateBufferAccess(state_, this->GetSRef(block_rv), buffer_index, buffer_index_type,
index_map);
TVM_TIR_SCHEDULE_END("annotate-buffer-access", this->error_render_level_);
this->state_->DebugVerify();
}

} // namespace tir
} // namespace tvm
2 changes: 2 additions & 0 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,8 @@ class ConcreteScheduleNode : public ScheduleNode {
void EnterPostproc() override {}
void UnsafeHideBufferAccess(const BlockRV& block_rv, const String& buf_type,
const Array<IntImm>& buf_index_array) override;
void AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) override;

protected:
/******** Utility functions ********/
Expand Down
10 changes: 10 additions & 0 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -718,6 +718,16 @@ TVM_DLL void RollingBuffer(ScheduleState self, const StmtSRef& block_sref, int w
TVM_DLL void UnsafeHideBufferAccess(ScheduleState self, const StmtSRef& block_sref,
const String& buf_type, const Array<IntImm>& buf_index_array);

/*!
* \brief Annotate the read or write region of a specific buffer in a block
* \param self The state of the schedule
* \param block_sref The sref of the block to be annotated
* \param buffer_index The index of the buffer in block's read or write region
* \param buffer_index_type The type of the buffer index, kRead or kWrite
* \param index_map The IndexMap that defines the new read or write region for the buffer
*/
TVM_DLL void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map);
} // namespace tir
} // namespace tvm

Expand Down
149 changes: 149 additions & 0 deletions src/tir/schedule/primitive/annotate_buffer_access.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
#include "../utils.h"

namespace tvm {
namespace tir {

class AnnotateRegionRewriter : public StmtExprMutator {
public:
AnnotateRegionRewriter(Buffer buffer, int buffer_index, BufferRegion new_region,
BufferIndexType buffer_index_type)
: buffer_(buffer),
buffer_index_(buffer_index),
new_region_(new_region),
buffer_index_type_(buffer_index_type) {}

Stmt VisitStmt_(const BlockNode* op) final {
Block block = Downcast<Block>(StmtExprMutator::VisitStmt_(op));

Array<BufferRegion> regions =
buffer_index_type_ == BufferIndexType::kWrite ? block->writes : block->reads;
ICHECK_GE(buffer_index_, 0) << "Buffer index must be non-negative";
ICHECK_LT(buffer_index_, static_cast<int>(regions.size())) << "Buffer index out of range";
regions.Set(buffer_index_, new_region_);

ObjectPtr<BlockNode> n = CopyOnWrite(block.get());
if (buffer_index_type_ == BufferIndexType::kWrite) {
n->writes = std::move(regions);
} else {
n->reads = std::move(regions);
}

// Annotate the block with explicit_read_region or explicit_write_region
Map<String, ObjectRef> new_annotations = n->annotations;
String annotation_key = buffer_index_type_ == BufferIndexType::kWrite
? attr::explicit_write_region
: attr::explicit_read_region;
if (new_annotations.count(annotation_key)) {
Array<Integer> buffer_indices = Downcast<Array<Integer>>(new_annotations[annotation_key]);
bool found = false;
for (const Integer& index : buffer_indices) {
if (index->value == buffer_index_) {
found = true;
break;
}
}
if (!found) {
buffer_indices.push_back(Integer(buffer_index_));
new_annotations.Set(annotation_key, buffer_indices);
}
} else {
new_annotations.Set(annotation_key, Array<Integer>{Integer(buffer_index_)});
}
n->annotations = std::move(new_annotations);

return Block(n);
}

private:
Buffer buffer_;
int buffer_index_;
BufferRegion new_region_;
BufferIndexType buffer_index_type_;
};

void AnnotateBufferAccess(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block_sref);
Buffer buffer = GetNthAccessBuffer(self, GetRef<Block>(block), buffer_index, buffer_index_type);

arith::Analyzer analyzer;
Array<PrimExpr> block_iter_vars;
for (const IterVar& iter_var : block->iter_vars) {
block_iter_vars.push_back(iter_var->var);
}
Array<PrimExpr> new_indices = index_map->MapIndices(block_iter_vars, &analyzer);
ICHECK_EQ(new_indices.size() % 2, 0) << "The size of new_indices should be even.";
Array<Range> new_ranges;
for (size_t i = 0; i < new_indices.size(); i += 2) {
// (begin, end) represents a region
new_ranges.push_back(Range::FromMinExtent(
new_indices[i], analyzer.Simplify(new_indices[i + 1] - new_indices[i])));
}

BufferRegion new_region(buffer, new_ranges);

AnnotateRegionRewriter mutator(buffer, buffer_index, new_region, buffer_index_type);
Stmt new_stmt = mutator(GetRef<Stmt>(block_sref->stmt));

self->Replace(block_sref, new_stmt, {{GetRef<Block>(block), Downcast<Block>(new_stmt)}});
}

struct AnnotateBufferAccessTraits : public UnpackedInstTraits<AnnotateBufferAccessTraits> {
static constexpr const char* kName = "AnnotateBufferAccess";
static constexpr bool kIsPure = false;

private:
static constexpr size_t kNumInputs = 4;
static constexpr size_t kNumAttrs = 0;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block, Integer buffer_index,
Integer buffer_index_type, IndexMap index_map) {
return sch->AnnotateBufferAccess(block, buffer_index->value,
static_cast<BufferIndexType>(buffer_index_type->value),
index_map);
}

static String IndexMap2GenNewRangesLambda(const IndexMap& index_map) {
std::ostringstream oss;
oss << "lambda ";
for (size_t i = 0; i < index_map->initial_indices.size(); ++i) {
if (i != 0) oss << ", ";
oss << index_map->initial_indices[i];
}
oss << ": [";
for (size_t i = 0; i < index_map->final_indices.size(); i += 2) {
if (i != 0) oss << ", ";
if (index_map->final_indices[i].same_as(index_map->final_indices[i + 1])) {
oss << index_map->final_indices[i];
} else {
oss << "(" << index_map->final_indices[i] << ", " << index_map->final_indices[i + 1] << ")";
}
}
oss << "]";
return String(oss.str());
}

static String UnpackedAsPython(Array<String> outputs, String block, Integer buffer_index,
Integer buffer_index_type, IndexMap index_map) {
PythonAPICall py("annotate_buffer_access");
py.Input("block", block);
py.Input("buffer_index", buffer_index->value);

std::ostringstream os;
os << "\"" << BufferIndexType2Str(static_cast<BufferIndexType>(buffer_index_type->value))
<< "\"";
py.Input("buf_type", os.str());

py.Input("gen_new_ranges", IndexMap2GenNewRangesLambda(index_map));
return py.Str();
}

template <typename>
friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(AnnotateBufferAccessTraits);

} // namespace tir
} // namespace tvm
7 changes: 7 additions & 0 deletions src/tir/schedule/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,13 @@ TVM_REGISTER_GLOBAL("tir.schedule.ScheduleEnterPostproc")
.set_body_method<Schedule>(&ScheduleNode::EnterPostproc);
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleUnsafeHideBufferAccess")
.set_body_method<Schedule>(&ScheduleNode::UnsafeHideBufferAccess);
/******** (FFI) Annotate buffer access ********/
TVM_REGISTER_GLOBAL("tir.schedule.ScheduleAnnotateBufferAccess")
.set_body_typed([](Schedule self, const BlockRV& block_rv, int buffer_index,
int buffer_index_type, const IndexMap& index_map) {
return self->AnnotateBufferAccess(block_rv, buffer_index,
static_cast<BufferIndexType>(buffer_index_type), index_map);
});

} // namespace tir
} // namespace tvm
12 changes: 12 additions & 0 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -769,5 +769,17 @@ void TracedScheduleNode::UnsafeHideBufferAccess(const BlockRV& block_rv, const S
/*outputs=*/{}));
}

void TracedScheduleNode::AnnotateBufferAccess(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type,
const IndexMap& index_map) {
ConcreteScheduleNode::AnnotateBufferAccess(block_rv, buffer_index, buffer_index_type, index_map);
static const InstructionKind& kind = InstructionKind::Get("AnnotateBufferAccess");
trace_->Append(/*inst=*/Instruction(
/*kind=*/kind,
/*inputs=*/{block_rv, Integer(buffer_index), Integer(buffer_index_type), index_map},
/*attrs=*/{},
/*outputs=*/{}));
}

} // namespace tir
} // namespace tvm
Loading

0 comments on commit be963f6

Please sign in to comment.