diff --git a/xla/tsl/util/stats_calculator.h b/xla/tsl/util/stats_calculator.h index 84045fb6ceece..253895ca605fa 100644 --- a/xla/tsl/util/stats_calculator.h +++ b/xla/tsl/util/stats_calculator.h @@ -117,6 +117,42 @@ class Stat { HighPrecisionValueType squared_sum_ = 0; }; +// A `StatWithPercentiles` inherited from `Stat`, also keeps track of the +// values added and can be used to compute the percentile values. +template +class StatWithPercentiles : public Stat { + public: + void UpdateStat(ValueType v) { + Stat::UpdateStat(v); + values_.push_back(v); + } + + // Returns the percentile value. + ValueType percentile(int percentile) const { + if (percentile < 0 || percentile > 100 || values_.empty()) { + return std::numeric_limits::quiet_NaN(); + } + std::vector values = values_; + if (percentile == 100) { + return values[values.size() - 1]; + } else { + std::nth_element(values.begin(), + values.begin() + values.size() * percentile / 100, + values.end()); + return values[values.size() * percentile / 100]; + } + } + + void OutputToStream(std::ostream* stream) const { + Stat::OutputToStream(stream); + *stream << " p5=" << percentile(5) << " median=" << percentile(50) + << " p95=" << percentile(95); + } + + private: + std::vector values_; +}; + // A StatsCalculator assists in performance analysis of Graph executions. // // It summarizes time spent executing (on GPU/CPU), memory used etc for diff --git a/xla/tsl/util/stats_calculator_test.cc b/xla/tsl/util/stats_calculator_test.cc index d58186630598f..bab88a0236fe7 100644 --- a/xla/tsl/util/stats_calculator_test.cc +++ b/xla/tsl/util/stats_calculator_test.cc @@ -16,6 +16,8 @@ limitations under the License. #include "xla/tsl/util/stats_calculator.h" #include +#include +#include #include "tsl/platform/test.h" @@ -104,5 +106,39 @@ TEST(StatsCalculatorTest, UpdateStat) { EXPECT_NEAR(43.30704330706496060826769, stat.std_deviation(), FLT_EPSILON); } +TEST(StatsCalculatorTest, StatWithPercentiles) { + StatWithPercentiles stat; + EXPECT_TRUE(stat.empty()); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(1); + EXPECT_TRUE(stat.all_same()); + stat.UpdateStat(-1.0); + EXPECT_FALSE(stat.all_same()); + stat.UpdateStat(100); + stat.UpdateStat(0); + EXPECT_EQ(4, stat.count()); + EXPECT_EQ(-1, stat.min()); + EXPECT_EQ(100, stat.max()); + EXPECT_EQ(25, stat.avg()); + EXPECT_EQ(1, stat.first()); + EXPECT_EQ(0, stat.newest()); + EXPECT_EQ(10002, stat.squared_sum()); + EXPECT_EQ(625, stat.avg() * stat.avg()); + // Sample variance + EXPECT_EQ(7502 / 3, stat.sample_variance()); + // Sample standard deviation, from WolframAlpha + EXPECT_EQ(50, std::sqrt(stat.sample_variance())); + // Population variance + EXPECT_EQ(7502 / 4, stat.variance()); + // Population standard deviation, from WolframAlpha + EXPECT_EQ(43, stat.std_deviation()); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(100, stat.percentile(90)); + stat.UpdateStat(150); + EXPECT_EQ(1, stat.percentile(50)); + EXPECT_EQ(150, stat.percentile(90)); + EXPECT_EQ(150, stat.percentile(100)); +} + } // namespace } // namespace tsl