Provide robust and efficient variance computation for online statistics.

This CL implements Welford's algorithm for a
numerically stable computation of the variance.
This implementation is plugged in SamplesStatsCounter class (adapter pattern).

A 'NumericalStability' unit test has been added,
whose previous implementation of SamplesStatsCounter failed to pass.

Follow-up CLs will factorize more occurences of duplicated and misbehaved
computations.

Bug: webrtc:10412
Change-Id: Id807c3d34e9c780fb1cbd769d30b655c575c88ac
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/131394
Commit-Queue: Yves Gerey <yvesg@google.com>
Reviewed-by: Artem Titov <titovartem@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Reviewed-by: Mirko Bonadei <mbonadei@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#27547}
This commit is contained in:
Yves Gerey 2019-04-10 17:18:48 +02:00 committed by Commit Bot
parent fb20afd38c
commit 890f62b8fe
7 changed files with 388 additions and 25 deletions

View File

@ -588,10 +588,12 @@ rtc_static_library("rtc_numerics") {
sources = [
"numerics/exp_filter.cc",
"numerics/exp_filter.h",
"numerics/math_utils.h",
"numerics/moving_average.cc",
"numerics/moving_average.h",
"numerics/moving_median_filter.h",
"numerics/percentile_filter.h",
"numerics/running_statistics.h",
"numerics/samples_stats_counter.cc",
"numerics/samples_stats_counter.h",
"numerics/sequence_number_util.h",
@ -1297,6 +1299,7 @@ if (rtc_include_tests) {
"numerics/moving_average_unittest.cc",
"numerics/moving_median_filter_unittest.cc",
"numerics/percentile_filter_unittest.cc",
"numerics/running_statistics_unittest.cc",
"numerics/samples_stats_counter_unittest.cc",
"numerics/sequence_number_util_unittest.cc",
]

View File

@ -36,4 +36,39 @@ typename std::make_unsigned<T>::type unsigned_difference(T x, T y) {
return static_cast<unsigned_type>(x) - static_cast<unsigned_type>(y);
}
// Provide neutral element with respect to min().
// Typically used as an initial value for running minimum.
template <typename T,
typename std::enable_if<std::numeric_limits<T>::has_infinity>::type* =
nullptr>
constexpr T infinity_or_max() {
return std::numeric_limits<T>::infinity();
}
template <typename T,
typename std::enable_if<
!std::numeric_limits<T>::has_infinity>::type* = nullptr>
constexpr T infinity_or_max() {
// Fallback to max().
return std::numeric_limits<T>::max();
}
// Provide neutral element with respect to max().
// Typically used as an initial value for running maximum.
template <typename T,
typename std::enable_if<std::numeric_limits<T>::has_infinity>::type* =
nullptr>
constexpr T minus_infinity_or_min() {
static_assert(std::is_signed<T>::value, "Unsupported. Please open a bug.");
return -std::numeric_limits<T>::infinity();
}
template <typename T,
typename std::enable_if<
!std::numeric_limits<T>::has_infinity>::type* = nullptr>
constexpr T minus_infinity_or_min() {
// Fallback to min().
return std::numeric_limits<T>::min();
}
#endif // RTC_BASE_NUMERICS_MATH_UTILS_H_

View File

@ -0,0 +1,135 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
#define RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_
#include <algorithm>
#include <cmath>
#include <limits>
#include "absl/types/optional.h"
#include "rtc_base/numerics/math_utils.h"
namespace webrtc {
// tl;dr: Robust and efficient online computation of statistics,
// using Welford's method for variance. [1]
//
// This should be your go-to class if you ever need to compute
// min, max, mean, variance and standard deviation.
// If you need to get percentiles, please use webrtc::SamplesStatsCounter.
//
// The measures return absl::nullopt if no samples were fed (Size() == 0),
// otherwise the returned optional is guaranteed to contain a value.
//
// [1]
// https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm
// The type T is a scalar which must be convertible to double.
// Rationale: we often need greater precision for measures
// than for the samples themselves.
template <typename T>
class RunningStatistics {
public:
// Update stats ////////////////////////////////////////////
// Add a value participating in the statistics in O(1) time.
void AddSample(T sample) {
max_ = std::max(max_, sample);
min_ = std::min(min_, sample);
++size_;
// Welford's incremental update.
const double delta = sample - mean_;
mean_ += delta / size_;
const double delta2 = sample - mean_;
cumul_ += delta * delta2;
}
// Merge other stats, as if samples were added one by one, but in O(1).
void MergeStatistics(const RunningStatistics<T>& other) {
if (other.size_ == 0) {
return;
}
max_ = std::max(max_, other.max_);
min_ = std::min(min_, other.min_);
const int64_t new_size = size_ + other.size_;
const double new_mean =
(mean_ * size_ + other.mean_ * other.size_) / new_size;
// Each cumulant must be corrected.
// * from: sum((x_i - mean_)²)
// * to: sum((x_i - new_mean)²)
auto delta = [new_mean](const RunningStatistics<T>& stats) {
return stats.size_ * (new_mean * (new_mean - 2 * stats.mean_) +
stats.mean_ * stats.mean_);
};
cumul_ = cumul_ + delta(*this) + other.cumul_ + delta(other);
mean_ = new_mean;
size_ = new_size;
}
// Get Measures ////////////////////////////////////////////
// Returns number of samples involved,
// that is number of times AddSample() was called.
int64_t Size() const { return size_; }
// Returns min in O(1) time.
absl::optional<T> GetMin() const {
if (size_ == 0) {
return absl::nullopt;
}
return min_;
}
// Returns max in O(1) time.
absl::optional<T> GetMax() const {
if (size_ == 0) {
return absl::nullopt;
}
return max_;
}
// Returns mean in O(1) time.
absl::optional<double> GetMean() const {
if (size_ == 0) {
return absl::nullopt;
}
return mean_;
}
// Returns unbiased sample variance in O(1) time.
absl::optional<double> GetVariance() const {
if (size_ == 0) {
return absl::nullopt;
}
return cumul_ / size_;
}
// Returns unbiased standard deviation in O(1) time.
absl::optional<double> GetStandardDeviation() const {
if (size_ == 0) {
return absl::nullopt;
}
return std::sqrt(*GetVariance());
}
private:
int64_t size_ = 0; // Samples seen.
T min_ = infinity_or_max<T>();
T max_ = minus_infinity_or_min<T>();
double mean_ = 0;
double cumul_ = 0; // Variance * size_, sometimes noted m2.
};
} // namespace webrtc
#endif // RTC_BASE_NUMERICS_RUNNING_STATISTICS_H_

View File

@ -0,0 +1,131 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "rtc_base/numerics/running_statistics.h"
#include <math.h>
#include <random>
#include <vector>
#include "absl/algorithm/container.h"
#include "test/gtest.h"
// Tests were copied from samples_stats_counter_unittest.cc.
namespace webrtc {
namespace {
RunningStatistics<double> CreateStatsFilledWithIntsFrom1ToN(int n) {
std::vector<double> data;
for (int i = 1; i <= n; i++) {
data.push_back(i);
}
absl::c_shuffle(data, std::mt19937(std::random_device()()));
RunningStatistics<double> stats;
for (double v : data) {
stats.AddSample(v);
}
return stats;
}
// Add n samples drawn from uniform distribution in [a;b].
RunningStatistics<double> CreateStatsFromUniformDistribution(int n,
double a,
double b) {
std::mt19937 gen{std::random_device()()};
std::uniform_real_distribution<> dis(a, b);
RunningStatistics<double> stats;
for (int i = 1; i <= n; i++) {
stats.AddSample(dis(gen));
}
return stats;
}
class RunningStatisticsTest : public ::testing::TestWithParam<int> {};
constexpr int SIZE_FOR_MERGE = 5;
} // namespace
TEST(RunningStatisticsTest, FullSimpleTest) {
auto stats = CreateStatsFilledWithIntsFrom1ToN(100);
EXPECT_DOUBLE_EQ(*stats.GetMin(), 1.0);
EXPECT_DOUBLE_EQ(*stats.GetMax(), 100.0);
EXPECT_DOUBLE_EQ(*stats.GetMean(), 50.5);
}
TEST(RunningStatistics, VarianceAndDeviation) {
RunningStatistics<int> stats;
stats.AddSample(2);
stats.AddSample(2);
stats.AddSample(-1);
stats.AddSample(5);
EXPECT_DOUBLE_EQ(*stats.GetMean(), 2.0);
EXPECT_DOUBLE_EQ(*stats.GetVariance(), 4.5);
EXPECT_DOUBLE_EQ(*stats.GetStandardDeviation(), sqrt(4.5));
}
TEST(RunningStatisticsTest, VarianceFromUniformDistribution) {
// Check variance converge to 1/12 for [0;1) uniform distribution.
// Acts as a sanity check for NumericStabilityForVariance test.
auto stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
}
TEST(RunningStatisticsTest, NumericStabilityForVariance) {
// Same test as VarianceFromUniformDistribution,
// except the range is shifted to [1e9;1e9+1).
// Variance should also converge to 1/12.
// NB: Although we lose precision for the samples themselves, the fractional
// part still enjoys 22 bits of mantissa and errors should even out,
// so that couldn't explain a mismatch.
auto stats = CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1);
EXPECT_NEAR(*stats.GetVariance(), 1. / 12, 1e-3);
}
TEST_P(RunningStatisticsTest, MergeStatistics) {
int data[SIZE_FOR_MERGE] = {2, 2, -1, 5, 10};
// Split the data in different partitions.
// We have 6 distinct tests:
// * Empty merged with full sequence.
// * 1 sample merged with 4 last.
// * 2 samples merged with 3 last.
// [...]
// * Full merged with empty sequence.
// All must lead to the same result.
// I miss QuickCheck so much.
RunningStatistics<int> stats0, stats1;
for (int i = 0; i < GetParam(); ++i) {
stats0.AddSample(data[i]);
}
for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) {
stats1.AddSample(data[i]);
}
stats0.MergeStatistics(stats1);
EXPECT_EQ(stats0.Size(), SIZE_FOR_MERGE);
EXPECT_DOUBLE_EQ(*stats0.GetMin(), -1);
EXPECT_DOUBLE_EQ(*stats0.GetMax(), 10);
EXPECT_DOUBLE_EQ(*stats0.GetMean(), 3.6);
EXPECT_DOUBLE_EQ(*stats0.GetVariance(), 13.84);
EXPECT_DOUBLE_EQ(*stats0.GetStandardDeviation(), sqrt(13.84));
}
INSTANTIATE_TEST_SUITE_P(RunningStatisticsTests,
RunningStatisticsTest,
::testing::Range(0, SIZE_FOR_MERGE + 1));
} // namespace webrtc

View File

@ -26,26 +26,15 @@ SamplesStatsCounter& SamplesStatsCounter::operator=(SamplesStatsCounter&&) =
default;
void SamplesStatsCounter::AddSample(double value) {
stats_.AddSample(value);
samples_.push_back(value);
sorted_ = false;
if (value > max_) {
max_ = value;
}
if (value < min_) {
min_ = value;
}
sum_ += value;
sum_squared_ += value * value;
}
void SamplesStatsCounter::AddSamples(const SamplesStatsCounter& other) {
for (double sample : other.samples_)
samples_.push_back(sample);
stats_.MergeStatistics(other.stats_);
samples_.insert(samples_.end(), other.samples_.begin(), other.samples_.end());
sorted_ = false;
max_ = std::max(max_, other.max_);
min_ = std::min(min_, other.min_);
sum_ += other.sum_;
sum_squared_ += other.sum_squared_;
}
double SamplesStatsCounter::GetPercentile(double percentile) {

View File

@ -11,14 +11,15 @@
#ifndef RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_
#define RTC_BASE_NUMERICS_SAMPLES_STATS_COUNTER_H_
#include <math.h>
#include <limits>
#include <vector>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/running_statistics.h"
namespace webrtc {
// This class extends RunningStatistics by providing GetPercentile() method,
// while slightly adapting the interface.
class SamplesStatsCounter {
public:
SamplesStatsCounter();
@ -41,31 +42,31 @@ class SamplesStatsCounter {
// samples.
double GetMin() const {
RTC_DCHECK(!IsEmpty());
return min_;
return *stats_.GetMin();
}
// Returns max in O(1) time. This function may not be called if there are no
// samples.
double GetMax() const {
RTC_DCHECK(!IsEmpty());
return max_;
return *stats_.GetMax();
}
// Returns average in O(1) time. This function may not be called if there are
// no samples.
double GetAverage() const {
RTC_DCHECK(!IsEmpty());
return sum_ / samples_.size();
return *stats_.GetMean();
}
// Returns variance in O(1) time. This function may not be called if there are
// no samples.
double GetVariance() const {
RTC_DCHECK(!IsEmpty());
return sum_squared_ / samples_.size() - GetAverage() * GetAverage();
return *stats_.GetVariance();
}
// Returns standard deviation in O(1) time. This function may not be called if
// there are no samples.
double GetStandardDeviation() const {
RTC_DCHECK(!IsEmpty());
return sqrt(GetVariance());
return *stats_.GetStandardDeviation();
}
// Returns percentile in O(nlogn) on first call and in O(1) after, if no
// additions were done. This function may not be called if there are no
@ -76,11 +77,8 @@ class SamplesStatsCounter {
double GetPercentile(double percentile);
private:
RunningStatistics<double> stats_;
std::vector<double> samples_;
double min_ = std::numeric_limits<double>::max();
double max_ = std::numeric_limits<double>::min();
double sum_ = 0;
double sum_squared_ = 0;
bool sorted_ = false;
};

View File

@ -34,6 +34,24 @@ SamplesStatsCounter CreateStatsFilledWithIntsFrom1ToN(int n) {
return stats;
}
// Add n samples drawn from uniform distribution in [a;b].
SamplesStatsCounter CreateStatsFromUniformDistribution(int n,
double a,
double b) {
std::mt19937 gen{std::random_device()()};
std::uniform_real_distribution<> dis(a, b);
SamplesStatsCounter stats;
for (int i = 1; i <= n; i++) {
stats.AddSample(dis(gen));
}
return stats;
}
class SamplesStatsCounterTest : public ::testing::TestWithParam<int> {};
constexpr int SIZE_FOR_MERGE = 10;
} // namespace
TEST(SamplesStatsCounter, FullSimpleTest) {
@ -76,4 +94,58 @@ TEST(SamplesStatsCounter, TestBorderValues) {
EXPECT_DOUBLE_EQ(stats.GetPercentile(1.0), 5);
}
TEST(SamplesStatsCounter, VarianceFromUniformDistribution) {
// Check variance converge to 1/12 for [0;1) uniform distribution.
// Acts as a sanity check for NumericStabilityForVariance test.
SamplesStatsCounter stats = CreateStatsFromUniformDistribution(1e6, 0, 1);
EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3);
}
TEST(SamplesStatsCounter, NumericStabilityForVariance) {
// Same test as VarianceFromUniformDistribution,
// except the range is shifted to [1e9;1e9+1).
// Variance should also converge to 1/12.
// NB: Although we lose precision for the samples themselves, the fractional
// part still enjoys 22 bits of mantissa and errors should even out,
// so that couldn't explain a mismatch.
SamplesStatsCounter stats =
CreateStatsFromUniformDistribution(1e6, 1e9, 1e9 + 1);
EXPECT_NEAR(stats.GetVariance(), 1. / 12, 1e-3);
}
TEST_P(SamplesStatsCounterTest, AddSamples) {
int data[SIZE_FOR_MERGE] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
// Split the data in different partitions.
// We have 11 distinct tests:
// * Empty merged with full sequence.
// * 1 sample merged with 9 last.
// * 2 samples merged with 8 last.
// [...]
// * Full merged with empty sequence.
// All must lead to the same result.
SamplesStatsCounter stats0, stats1;
for (int i = 0; i < GetParam(); ++i) {
stats0.AddSample(data[i]);
}
for (int i = GetParam(); i < SIZE_FOR_MERGE; ++i) {
stats1.AddSample(data[i]);
}
stats0.AddSamples(stats1);
EXPECT_EQ(stats0.GetMin(), 0);
EXPECT_EQ(stats0.GetMax(), 9);
EXPECT_DOUBLE_EQ(stats0.GetAverage(), 4.5);
EXPECT_DOUBLE_EQ(stats0.GetVariance(), 8.25);
EXPECT_DOUBLE_EQ(stats0.GetStandardDeviation(), sqrt(8.25));
EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.1), 0.9);
EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.5), 4.5);
EXPECT_DOUBLE_EQ(stats0.GetPercentile(0.9), 8.1);
}
INSTANTIATE_TEST_SUITE_P(SamplesStatsCounterTests,
SamplesStatsCounterTest,
::testing::Range(0, SIZE_FOR_MERGE + 1));
} // namespace webrtc