Skip to content

Commit

Permalink
Reduce Uniform/Independent-Elements iterator register footprint (#2383)
Browse files Browse the repository at this point in the history
Rewrite `UniformElements` and `IndependentElements` iterators to reduce the register footprint.
  - avoid multiple return statements within a function
  - reduce the iterator state size by one element
  • Loading branch information
psychocoderHPC authored Sep 20, 2024
1 parent 63ad43f commit 11ab218
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 49 deletions.
2 changes: 1 addition & 1 deletion example/bufferCopy/src/bufferCopy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ auto example(TAccTag const&) -> int

// Define the work division for kernels to be run on devAcc and devHost
using Vec = alpaka::Vec<Dim, Idx>;
Vec const elementsPerThread(Vec::all(static_cast<Idx>(1)));
Vec const elementsPerThread(Vec::all(static_cast<Idx>(3)));
Vec const elementsPerGrid(Vec::all(static_cast<Idx>(10)));

// Create host and device buffers
Expand Down
50 changes: 22 additions & 28 deletions include/alpaka/exec/IndependentElements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ namespace alpaka
};

private:
const Idx first_;
const Idx stride_;
const Idx extent_;
Idx const first_;
Idx const stride_;
Idx const extent_;
};

} // namespace detail
Expand Down Expand Up @@ -311,11 +311,12 @@ namespace alpaka

ALPAKA_FN_ACC inline const_iterator(Idx elements, Idx stride, Idx extent, Idx first)
: elements_{elements}
, stride_{stride}
,
// we need to reduce the stride by one element range because index_ is already increased by one with
// each increment
stride_{stride - elements}
, extent_{extent}
, first_{std::min(first, extent)}
, index_{first_}
, range_{std::min(first + elements, extent)}
, index_{std::min(first, extent)}
{
}

Expand All @@ -328,22 +329,16 @@ namespace alpaka
// pre-increment the iterator
ALPAKA_FN_ACC inline const_iterator& operator++()
{
// increment the index along the elements processed by the current thread
++indexElem_;
++index_;
if(index_ < range_)
return *this;

// increment the thread index with the block stride
first_ += stride_;
index_ = first_;
range_ = std::min(first_ + elements_, extent_);
if(index_ < extent_)
return *this;
if(indexElem_ >= elements_)
{
indexElem_ = 0;
index_ += stride_;
}
if(index_ >= extent_)
index_ = extent_;

// the iterator has reached or passed the end of the extent, clamp it to the extent
first_ = extent_;
index_ = extent_;
range_ = extent_;
return *this;
}

Expand All @@ -357,7 +352,7 @@ namespace alpaka

ALPAKA_FN_ACC inline bool operator==(const_iterator const& other) const
{
return (index_ == other.index_) and (first_ == other.first_);
return (*(*this) == *other);
}

ALPAKA_FN_ACC inline bool operator!=(const_iterator const& other) const
Expand All @@ -371,16 +366,15 @@ namespace alpaka
Idx stride_;
Idx extent_;
// modified by the pre/post-increment operator
Idx first_;
Idx index_;
Idx range_;
Idx indexElem_ = 0;
};

private:
const Idx elements_;
const Idx thread_;
const Idx stride_;
const Idx extent_;
Idx const elements_;
Idx const thread_;
Idx const stride_;
Idx const extent_;
};

} // namespace detail
Expand Down
35 changes: 15 additions & 20 deletions include/alpaka/exec/UniformElements.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,11 +130,12 @@ namespace alpaka

ALPAKA_FN_ACC inline const_iterator(Idx elements, Idx stride, Idx extent, Idx first)
: elements_{elements}
, stride_{stride}
,
// we need to reduce the stride by one element range because index_ is already increased by one with
// each increment
stride_{stride - elements}
, extent_{extent}
, first_{std::min(first, extent)}
, index_{first_}
, range_{std::min(first + elements, extent)}
, index_{std::min(first, extent)}
{
}

Expand All @@ -148,21 +149,16 @@ namespace alpaka
ALPAKA_FN_ACC inline const_iterator& operator++()
{
// increment the index along the elements processed by the current thread
++indexElem_;
++index_;
if(index_ < range_)
return *this;

// increment the thread index with the grid stride
first_ += stride_;
index_ = first_;
range_ = std::min(first_ + elements_, extent_);
if(index_ < extent_)
return *this;
if(indexElem_ >= elements_)
{
indexElem_ = 0;
index_ += stride_;
}
if(index_ >= extent_)
index_ = extent_;

// the iterator has reached or passed the end of the extent, clamp it to the extent
first_ = extent_;
index_ = extent_;
range_ = extent_;
return *this;
}

Expand All @@ -176,7 +172,7 @@ namespace alpaka

ALPAKA_FN_ACC inline bool operator==(const_iterator const& other) const
{
return (index_ == other.index_) and (first_ == other.first_);
return (*(*this) == *other);
}

ALPAKA_FN_ACC inline bool operator!=(const_iterator const& other) const
Expand All @@ -190,9 +186,8 @@ namespace alpaka
Idx stride_;
Idx extent_;
// modified by the pre/post-increment operator
Idx first_;
Idx index_;
Idx range_;
Idx indexElem_ = 0;
};

private:
Expand Down

0 comments on commit 11ab218

Please sign in to comment.