Skip to content

Commit

Permalink
Create a new class StatWithPercentiles that inherits Stat.
Browse files Browse the repository at this point in the history
A StatWithPercentiles object keeps track of the values added to it, and supports computing percentile values.

PiperOrigin-RevId: 689222675
  • Loading branch information
Google-ML-Automation committed Oct 24, 2024
1 parent 872eaa0 commit 450b61f
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 0 deletions.
36 changes: 36 additions & 0 deletions xla/tsl/util/stats_calculator.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,42 @@ class Stat {
HighPrecisionValueType squared_sum_ = 0;
};

// A `StatWithPercentiles` inherited from `Stat`, also keeps track of the
// values added and can be used to compute the percentile values.
template <typename ValueType, typename HighPrecisionValueType = double>
class StatWithPercentiles : public Stat<ValueType, HighPrecisionValueType> {
public:
void UpdateStat(ValueType v) {
Stat<ValueType, HighPrecisionValueType>::UpdateStat(v);
values_.push_back(v);
}

// Returns the percentile value.
ValueType percentile(int percentile) const {
if (percentile < 0 || percentile > 100 || values_.empty()) {
return std::numeric_limits<ValueType>::quiet_NaN();
}
std::vector<ValueType> values = values_;
if (percentile == 100) {
return values[values.size() - 1];
} else {
std::nth_element(values.begin(),
values.begin() + values.size() * percentile / 100,
values.end());
return values[values.size() * percentile / 100];
}
}

void OutputToStream(std::ostream* stream) const {
Stat<ValueType, HighPrecisionValueType>::OutputToStream(stream);
*stream << " p5=" << percentile(5) << " median=" << percentile(50)
<< " p95=" << percentile(95);
}

private:
std::vector<ValueType> values_;
};

// A StatsCalculator assists in performance analysis of Graph executions.
//
// It summarizes time spent executing (on GPU/CPU), memory used etc for
Expand Down
36 changes: 36 additions & 0 deletions xla/tsl/util/stats_calculator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ limitations under the License.
#include "xla/tsl/util/stats_calculator.h"

#include <cfloat>
#include <cmath>
#include <cstdint>

#include "tsl/platform/test.h"

Expand Down Expand Up @@ -104,5 +106,39 @@ TEST(StatsCalculatorTest, UpdateStat) {
EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON);
}

TEST(StatsCalculatorTest, StatWithPercentiles) {
StatWithPercentiles<int64_t> stat;
EXPECT_TRUE(stat.empty());
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(1);
EXPECT_TRUE(stat.all_same());
stat.UpdateStat(-1.0);
EXPECT_FALSE(stat.all_same());
stat.UpdateStat(100);
stat.UpdateStat(0);
EXPECT_EQ(4, stat.count());
EXPECT_EQ(-1, stat.min());
EXPECT_EQ(100, stat.max());
EXPECT_EQ(25, stat.avg());
EXPECT_EQ(1, stat.first());
EXPECT_EQ(0, stat.newest());
EXPECT_EQ(10002, stat.squared_sum());
EXPECT_EQ(625, stat.avg() * stat.avg());
// Sample variance
EXPECT_EQ(7502 / 3, stat.sample_variance());
// Sample standard deviation, from WolframAlpha
EXPECT_EQ(50, std::sqrt(stat.sample_variance()));
// Population variance
EXPECT_EQ(7502 / 4, stat.variance());
// Population standard deviation, from WolframAlpha
EXPECT_EQ(43, stat.std_deviation());
EXPECT_EQ(1, stat.percentile(50));
EXPECT_EQ(100, stat.percentile(90));
stat.UpdateStat(150);
EXPECT_EQ(1, stat.percentile(50));
EXPECT_EQ(150, stat.percentile(90));
EXPECT_EQ(150, stat.percentile(100));
}

} // namespace
} // namespace tsl

0 comments on commit 450b61f

Please sign in to comment.