AGC2 AdaptiveModeLevelEstimator min consecutive speech frames (2/3)

This is the second CL needed to add a new `AdaptiveModeLevelEstimator`
feature that makes AGC2 more robus to VAD mistakes: the level estimator
discards estimation updates when too few consecutive speech frames are
observed.

In this CL, the `SaturationProtector` class has been replaced by a
struct that define the state and two functions to change it.
This is done in order to use the saturation protector state in
`AdaptiveModeLevelEstimator::State` and will allow to add a
temporary state in `AdaptiveModeLevelEstimator` (see the child CL).

Tested: Bit-exactness verified with audioproc_f

Bug: webrtc:7494
Change-Id: Ic5ecd1e174010656ed20664ef7b7e5798ebb7978
Reviewed-on: https://webrtc-review.googlesource.com/c/src/+/185041
Commit-Queue: Alessio Bazzica <alessiob@webrtc.org>
Reviewed-by: Karl Wiberg <kwiberg@webrtc.org>
Cr-Commit-Position: refs/heads/master@{#32226}
This commit is contained in:
Alessio Bazzica 2020-09-29 11:56:38 +02:00 committed by Commit Bot
parent b6f002b55f
commit 56f63c3e7e
5 changed files with 235 additions and 162 deletions

View File

@ -49,9 +49,9 @@ AdaptiveModeLevelEstimator::AdaptiveModeLevelEstimator(
float initial_saturation_margin_db,
float extra_saturation_margin_db)
: apm_data_dumper_(apm_data_dumper),
saturation_protector_(apm_data_dumper, initial_saturation_margin_db),
level_estimator_type_(level_estimator),
use_saturation_protector_(use_saturation_protector),
initial_saturation_margin_db_(initial_saturation_margin_db),
extra_saturation_margin_db_(extra_saturation_margin_db),
last_level_dbfs_(absl::nullopt) {
Reset();
@ -102,11 +102,11 @@ void AdaptiveModeLevelEstimator::Update(
// Cache level estimation.
last_level_dbfs_ = state_.level_dbfs.GetRatio();
// TODO(crbug.com/webrtc/7494): Update saturation protector state in `state`.
if (use_saturation_protector_) {
saturation_protector_.UpdateMargin(
UpdateSaturationProtectorState(
/*speech_peak_dbfs=*/vad_level.peak_dbfs,
/*speech_level_dbfs=*/last_level_dbfs_.value());
/*speech_level_dbfs=*/last_level_dbfs_.value(),
state_.saturation_protector);
}
DebugDumpEstimate();
@ -115,7 +115,7 @@ void AdaptiveModeLevelEstimator::Update(
float AdaptiveModeLevelEstimator::GetLevelDbfs() const {
float level_dbfs = last_level_dbfs_.value_or(kInitialSpeechLevelEstimateDbfs);
if (use_saturation_protector_) {
level_dbfs += saturation_protector_.margin_db();
level_dbfs += state_.saturation_protector.margin_db;
level_dbfs += extra_saturation_margin_db_;
}
return rtc::SafeClamp<float>(level_dbfs, -90.f, 30.f);
@ -127,7 +127,6 @@ bool AdaptiveModeLevelEstimator::IsConfident() const {
}
void AdaptiveModeLevelEstimator::Reset() {
saturation_protector_.Reset();
ResetState(state_);
last_level_dbfs_ = absl::nullopt;
}
@ -136,15 +135,17 @@ void AdaptiveModeLevelEstimator::ResetState(State& state) {
state.time_to_full_buffer_ms = kFullBufferSizeMs;
state.level_dbfs.numerator = 0.f;
state.level_dbfs.denominator = 0.f;
// TODO(crbug.com/webrtc/7494): Reset saturation protector state in `state`.
ResetSaturationProtectorState(initial_saturation_margin_db_,
state.saturation_protector);
}
void AdaptiveModeLevelEstimator::DebugDumpEstimate() {
if (apm_data_dumper_) {
apm_data_dumper_->DumpRaw("agc2_adaptive_level_estimate_dbfs",
GetLevelDbfs());
apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db",
state_.saturation_protector.margin_db);
}
saturation_protector_.DebugDumpEstimate();
}
} // namespace webrtc

View File

@ -62,18 +62,18 @@ class AdaptiveModeLevelEstimator {
};
int time_to_full_buffer_ms;
Ratio level_dbfs;
// TODO(crbug.com/webrtc/7494): Add saturation protector state.
SaturationProtectorState saturation_protector;
};
void ResetState(State& state);
void DebugDumpEstimate();
ApmDataDumper* const apm_data_dumper_;
SaturationProtector saturation_protector_;
const AudioProcessing::Config::GainController2::LevelEstimator
level_estimator_type_;
const bool use_saturation_protector_;
const float initial_saturation_margin_db_;
const float extra_saturation_margin_db_;
// TODO(crbug.com/webrtc/7494): Add temporary state.
State state_;

View File

@ -11,7 +11,6 @@
#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
@ -23,14 +22,31 @@ constexpr float kMinLevelDbfs = -90.f;
constexpr float kMinMarginDb = 12.f;
constexpr float kMaxMarginDb = 25.f;
using saturation_protector_impl::RingBuffer;
} // namespace
void SaturationProtector::RingBuffer::Reset() {
bool RingBuffer::operator==(const RingBuffer& b) const {
RTC_DCHECK_LE(size_, buffer_.size());
RTC_DCHECK_LE(b.size_, b.buffer_.size());
if (size_ != b.size_) {
return false;
}
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
++i, ++i0, ++i1) {
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
return false;
}
}
return true;
}
void RingBuffer::Reset() {
next_ = 0;
size_ = 0;
}
void SaturationProtector::RingBuffer::PushBack(float v) {
void RingBuffer::PushBack(float v) {
RTC_DCHECK_GE(next_, 0);
RTC_DCHECK_GE(size_, 0);
RTC_DCHECK_LT(next_, buffer_.size());
@ -44,71 +60,62 @@ void SaturationProtector::RingBuffer::PushBack(float v) {
}
}
absl::optional<float> SaturationProtector::RingBuffer::Front() const {
absl::optional<float> RingBuffer::Front() const {
if (size_ == 0) {
return absl::nullopt;
}
RTC_DCHECK_LT(next_, buffer_.size());
return buffer_[rtc::SafeEq(size_, buffer_.size()) ? next_ : 0];
RTC_DCHECK_LT(FrontIndex(), buffer_.size());
return buffer_[FrontIndex()];
}
SaturationProtector::SaturationProtector(ApmDataDumper* apm_data_dumper)
: SaturationProtector(apm_data_dumper, GetInitialSaturationMarginDb()) {}
SaturationProtector::SaturationProtector(ApmDataDumper* apm_data_dumper,
float initial_saturation_margin_db)
: apm_data_dumper_(apm_data_dumper),
initial_saturation_margin_db_(initial_saturation_margin_db) {
Reset();
bool SaturationProtectorState::operator==(
const SaturationProtectorState& b) const {
return margin_db == b.margin_db && peak_delay_buffer == b.peak_delay_buffer &&
max_peaks_dbfs == b.max_peaks_dbfs &&
time_since_push_ms == b.time_since_push_ms;
}
void SaturationProtector::Reset() {
margin_db_ = initial_saturation_margin_db_;
peak_delay_buffer_.Reset();
max_peaks_dbfs_ = kMinLevelDbfs;
time_since_push_ms_ = 0;
void ResetSaturationProtectorState(float initial_margin_db,
SaturationProtectorState& state) {
state.margin_db = initial_margin_db;
state.peak_delay_buffer.Reset();
state.max_peaks_dbfs = kMinLevelDbfs;
state.time_since_push_ms = 0;
}
void SaturationProtector::UpdateMargin(float speech_peak_dbfs,
float speech_level_dbfs) {
void UpdateSaturationProtectorState(float speech_peak_dbfs,
float speech_level_dbfs,
SaturationProtectorState& state) {
// Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
max_peaks_dbfs_ = std::max(max_peaks_dbfs_, speech_peak_dbfs);
time_since_push_ms_ += kFrameDurationMs;
if (time_since_push_ms_ >
static_cast<int>(kPeakEnveloperSuperFrameLengthMs)) {
// Push `max_peaks_dbfs_` back into the ring buffer.
peak_delay_buffer_.PushBack(max_peaks_dbfs_);
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, speech_peak_dbfs);
state.time_since_push_ms += kFrameDurationMs;
if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
// Push `max_peaks_dbfs` back into the ring buffer.
state.peak_delay_buffer.PushBack(state.max_peaks_dbfs);
// Reset.
max_peaks_dbfs_ = kMinLevelDbfs;
time_since_push_ms_ = 0;
state.max_peaks_dbfs = kMinLevelDbfs;
state.time_since_push_ms = 0;
}
// Update margin by comparing the estimated speech level and the delayed max
// speech peak power.
// TODO(alessiob): Check with aleloi@ why we use a delay and how to tune it.
const float difference_db = GetDelayedPeakDbfs() - speech_level_dbfs;
if (margin_db_ < difference_db) {
margin_db_ = margin_db_ * kSaturationProtectorAttackConstant +
difference_db * (1.f - kSaturationProtectorAttackConstant);
const float delayed_peak_dbfs =
state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
if (difference_db > state.margin_db) {
// Attack.
state.margin_db =
state.margin_db * kSaturationProtectorAttackConstant +
difference_db * (1.f - kSaturationProtectorAttackConstant);
} else {
margin_db_ = margin_db_ * kSaturationProtectorDecayConstant +
difference_db * (1.f - kSaturationProtectorDecayConstant);
// Decay.
state.margin_db = state.margin_db * kSaturationProtectorDecayConstant +
difference_db * (1.f - kSaturationProtectorDecayConstant);
}
margin_db_ = rtc::SafeClamp<float>(margin_db_, kMinMarginDb, kMaxMarginDb);
}
float SaturationProtector::GetDelayedPeakDbfs() const {
return peak_delay_buffer_.Front().value_or(max_peaks_dbfs_);
}
void SaturationProtector::DebugDumpEstimate() const {
if (apm_data_dumper_) {
apm_data_dumper_->DumpRaw(
"agc2_adaptive_saturation_protector_delayed_peak_dbfs",
GetDelayedPeakDbfs());
apm_data_dumper_->DumpRaw("agc2_adaptive_saturation_margin_db", margin_db_);
}
state.margin_db =
rtc::SafeClamp<float>(state.margin_db, kMinMarginDb, kMaxMarginDb);
}
} // namespace webrtc

View File

@ -15,58 +15,68 @@
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace saturation_protector_impl {
class ApmDataDumper;
class SaturationProtector {
// Ring buffer which only supports (i) push back and (ii) read oldest item.
class RingBuffer {
public:
explicit SaturationProtector(ApmDataDumper* apm_data_dumper);
SaturationProtector(ApmDataDumper* apm_data_dumper,
float initial_saturation_margin_db);
bool operator==(const RingBuffer& b) const;
inline bool operator!=(const RingBuffer& b) const { return !(*this == b); }
// Maximum number of values that the buffer can contain.
int Capacity() const { return buffer_.size(); }
// Number of values in the buffer.
int Size() const { return size_; }
void Reset();
// Updates the margin by analyzing the estimated speech level
// `speech_level_dbfs` and the peak power `speech_peak_dbfs` for an observed
// frame which is reliably classified as "speech".
void UpdateMargin(float speech_peak_dbfs, float speech_level_dbfs);
// Returns latest computed margin.
float margin_db() const { return margin_db_; }
void DebugDumpEstimate() const;
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
void PushBack(float v);
// Returns the oldest item in the buffer. Returns an empty value if the
// buffer is empty.
absl::optional<float> Front() const;
private:
// Ring buffer which only supports (i) push back and (ii) read oldest item.
class RingBuffer {
public:
void Reset();
// Pushes back `v`. If the buffer is full, the oldest item is replaced.
void PushBack(float v);
// Returns the oldest item in the buffer. Returns an empty value if the
// buffer is empty.
absl::optional<float> Front() const;
private:
std::array<float, kPeakEnveloperBufferSize> buffer_;
int next_ = 0;
int size_ = 0;
};
float GetDelayedPeakDbfs() const;
ApmDataDumper* apm_data_dumper_;
// Parameters.
const float initial_saturation_margin_db_;
// State.
float margin_db_;
RingBuffer peak_delay_buffer_;
float max_peaks_dbfs_;
int time_since_push_ms_;
inline int FrontIndex() const {
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
}
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
// the position where the next new value is written in `buffer_`.
std::array<float, kPeakEnveloperBufferSize> buffer_;
int next_ = 0;
int size_ = 0;
};
} // namespace saturation_protector_impl
// Saturation protector state. Exposed publicly for check-pointing and restore
// ops.
struct SaturationProtectorState {
bool operator==(const SaturationProtectorState& s) const;
inline bool operator!=(const SaturationProtectorState& s) const {
return !(*this == s);
}
float margin_db; // Recommended margin.
saturation_protector_impl::RingBuffer peak_delay_buffer;
float max_peaks_dbfs;
int time_since_push_ms; // Time since the last ring buffer push operation.
};
// Resets the saturation protector state.
void ResetSaturationProtectorState(float initial_margin_db,
SaturationProtectorState& state);
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
// and the peak power `speech_peak_dbfs` for an observed frame which is
// reliably classified as "speech". `state` must not be modified without calling
// this function.
void UpdateSaturationProtectorState(float speech_peak_dbfs,
float speech_level_dbfs,
SaturationProtectorState& state);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_

View File

@ -21,111 +21,166 @@ namespace {
constexpr float kInitialMarginDb = 20.f;
using saturation_protector_impl::RingBuffer;
SaturationProtectorState CreateSaturationProtectorState() {
SaturationProtectorState state;
ResetSaturationProtectorState(kInitialMarginDb, state);
return state;
}
// Updates `state` for `num_iterations` times with constant speech level and
// peak powers and returns the maximum margin.
float RunOnConstantLevel(int num_iterations,
float speech_peak_dbfs,
float speech_level_dbfs,
SaturationProtector* saturation_protector) {
float last_margin = saturation_protector->margin_db();
SaturationProtectorState& state) {
float last_margin = state.margin_db;
float max_difference = 0.f;
for (int i = 0; i < num_iterations; ++i) {
saturation_protector->UpdateMargin(speech_peak_dbfs, speech_level_dbfs);
const float new_margin = saturation_protector->margin_db();
UpdateSaturationProtectorState(speech_peak_dbfs, speech_level_dbfs, state);
const float new_margin = state.margin_db;
max_difference =
std::max(max_difference, std::abs(new_margin - last_margin));
last_margin = new_margin;
saturation_protector->DebugDumpEstimate();
}
return max_difference;
}
} // namespace
TEST(AutomaticGainController2SaturationProtector, ProtectorShouldNotCrash) {
ApmDataDumper apm_data_dumper(0);
SaturationProtector saturation_protector(&apm_data_dumper, kInitialMarginDb);
saturation_protector.UpdateMargin(/*speech_peak_dbfs=*/-10.f,
/*speech_level_dbfs=*/-20.f);
static_cast<void>(saturation_protector.margin_db());
saturation_protector.DebugDumpEstimate();
TEST(AutomaticGainController2SaturationProtector, RingBufferInit) {
RingBuffer b;
EXPECT_EQ(b.Size(), 0);
EXPECT_FALSE(b.Front().has_value());
}
// Check that the estimate converges to the ratio between peaks and
// level estimator values after a while.
TEST(AutomaticGainController2SaturationProtector, RingBufferPushBack) {
RingBuffer b;
constexpr float kValue = 123.f;
b.PushBack(kValue);
EXPECT_EQ(b.Size(), 1);
ASSERT_TRUE(b.Front().has_value());
EXPECT_EQ(b.Front().value(), kValue);
}
TEST(AutomaticGainController2SaturationProtector, RingBufferReset) {
RingBuffer b;
b.PushBack(123.f);
b.Reset();
EXPECT_EQ(b.Size(), 0);
EXPECT_FALSE(b.Front().has_value());
}
// Checks that the front value does not change until the ring buffer gets full.
TEST(AutomaticGainController2SaturationProtector,
RingBufferFrontUntilBufferIsFull) {
RingBuffer b;
constexpr float kValue = 123.f;
b.PushBack(kValue);
for (int i = 1; i < b.Capacity(); ++i) {
EXPECT_EQ(b.Front().value(), kValue);
b.PushBack(kValue + i);
}
}
// Checks that when the buffer is full it behaves as a shift register.
TEST(AutomaticGainController2SaturationProtector,
FullRingBufferFrontIsDelayed) {
RingBuffer b;
// Fill the buffer.
for (int i = 0; i < b.Capacity(); ++i) {
b.PushBack(i);
}
// The ring buffer should now behave as a shift register with a delay equal to
// its capacity.
for (int i = b.Capacity(); i < 2 * b.Capacity() + 1; ++i) {
EXPECT_EQ(b.Front().value(), i - b.Capacity());
b.PushBack(i);
}
}
// Checks that a state after reset equals a state after construction.
TEST(AutomaticGainController2SaturationProtector, ResetState) {
SaturationProtectorState init_state;
ResetSaturationProtectorState(kInitialMarginDb, init_state);
SaturationProtectorState state;
ResetSaturationProtectorState(kInitialMarginDb, state);
RunOnConstantLevel(/*num_iterations=*/10, /*speech_level_dbfs=*/-20.f,
/*speech_peak_dbfs=*/-10.f, state);
ASSERT_NE(init_state, state); // Make sure that there are side-effects.
ResetSaturationProtectorState(kInitialMarginDb, state);
EXPECT_EQ(init_state, state);
}
// Checks that the estimate converges to the ratio between peaks and level
// estimator values after a while.
TEST(AutomaticGainController2SaturationProtector,
ProtectorEstimatesCrestRatio) {
ApmDataDumper apm_data_dumper(0);
SaturationProtector saturation_protector(&apm_data_dumper, kInitialMarginDb);
constexpr int kNumIterations = 2000;
constexpr float kPeakLevel = -20.f;
const float kCrestFactor = kInitialMarginDb + 1.f;
const float kSpeechLevel = kPeakLevel - kCrestFactor;
const float kMaxDifference = 0.5 * std::abs(kInitialMarginDb - kCrestFactor);
constexpr float kCrestFactor = kInitialMarginDb + 1.f;
constexpr float kSpeechLevel = kPeakLevel - kCrestFactor;
const float kMaxDifference = 0.5f * std::abs(kInitialMarginDb - kCrestFactor);
static_cast<void>(RunOnConstantLevel(2000, kPeakLevel, kSpeechLevel,
&saturation_protector));
auto state = CreateSaturationProtectorState();
RunOnConstantLevel(kNumIterations, kPeakLevel, kSpeechLevel, state);
EXPECT_NEAR(saturation_protector.margin_db(), kCrestFactor, kMaxDifference);
EXPECT_NEAR(state.margin_db, kCrestFactor, kMaxDifference);
}
TEST(AutomaticGainController2SaturationProtector, ProtectorChangesSlowly) {
ApmDataDumper apm_data_dumper(0);
SaturationProtector saturation_protector(&apm_data_dumper, kInitialMarginDb);
constexpr float kPeakLevel = -20.f;
const float kCrestFactor = kInitialMarginDb - 5.f;
const float kOtherCrestFactor = kInitialMarginDb;
const float kSpeechLevel = kPeakLevel - kCrestFactor;
const float kOtherSpeechLevel = kPeakLevel - kOtherCrestFactor;
// Checks that the margin does not change too quickly.
TEST(AutomaticGainController2SaturationProtector, ChangeSlowly) {
constexpr int kNumIterations = 1000;
float max_difference = RunOnConstantLevel(
kNumIterations, kPeakLevel, kSpeechLevel, &saturation_protector);
constexpr float kPeakLevel = -20.f;
constexpr float kCrestFactor = kInitialMarginDb - 5.f;
constexpr float kOtherCrestFactor = kInitialMarginDb;
constexpr float kSpeechLevel = kPeakLevel - kCrestFactor;
constexpr float kOtherSpeechLevel = kPeakLevel - kOtherCrestFactor;
max_difference =
std::max(RunOnConstantLevel(kNumIterations, kPeakLevel, kOtherSpeechLevel,
&saturation_protector),
max_difference);
constexpr float kMaxChangeSpeedDbPerSecond = 0.5; // 1 db / 2 seconds.
auto state = CreateSaturationProtectorState();
float max_difference =
RunOnConstantLevel(kNumIterations, kPeakLevel, kSpeechLevel, state);
max_difference = std::max(
RunOnConstantLevel(kNumIterations, kPeakLevel, kOtherSpeechLevel, state),
max_difference);
constexpr float kMaxChangeSpeedDbPerSecond = 0.5f; // 1 db / 2 seconds.
EXPECT_LE(max_difference,
kMaxChangeSpeedDbPerSecond / 1000 * kFrameDurationMs);
}
TEST(AutomaticGainController2SaturationProtector,
ProtectorAdaptsToDelayedChanges) {
ApmDataDumper apm_data_dumper(0);
SaturationProtector saturation_protector(&apm_data_dumper, kInitialMarginDb);
// Checks that there is a delay between input change and margin adaptations.
TEST(AutomaticGainController2SaturationProtector, AdaptToDelayedChanges) {
constexpr int kDelayIterations = kFullBufferSizeMs / kFrameDurationMs;
constexpr float kInitialSpeechLevelDbfs = -30;
constexpr float kLaterSpeechLevelDbfs = -15;
constexpr float kInitialSpeechLevelDbfs = -30.f;
constexpr float kLaterSpeechLevelDbfs = -15.f;
auto state = CreateSaturationProtectorState();
// First run on initial level.
float max_difference = RunOnConstantLevel(
kDelayIterations, kInitialSpeechLevelDbfs + kInitialMarginDb,
kInitialSpeechLevelDbfs, &saturation_protector);
kInitialSpeechLevelDbfs, state);
// Then peak changes, but not RMS.
max_difference =
std::max(RunOnConstantLevel(
kDelayIterations, kLaterSpeechLevelDbfs + kInitialMarginDb,
kInitialSpeechLevelDbfs, &saturation_protector),
std::max(RunOnConstantLevel(kDelayIterations,
kLaterSpeechLevelDbfs + kInitialMarginDb,
kInitialSpeechLevelDbfs, state),
max_difference);
// Then both change.
max_difference =
std::max(RunOnConstantLevel(kDelayIterations,
kLaterSpeechLevelDbfs + kInitialMarginDb,
kLaterSpeechLevelDbfs, &saturation_protector),
kLaterSpeechLevelDbfs, state),
max_difference);
// The saturation protector expects that the RMS changes roughly
// 'kFullBufferSizeMs' after peaks change. This is to account for
// delay introduces by the level estimator. Therefore, the input
// above is 'normal' and 'expected', and shouldn't influence the
// margin by much.
const float total_difference =
std::abs(saturation_protector.margin_db() - kInitialMarginDb);
// 'kFullBufferSizeMs' after peaks change. This is to account for delay
// introduced by the level estimator. Therefore, the input above is 'normal'
// and 'expected', and shouldn't influence the margin by much.
const float total_difference = std::abs(state.margin_db - kInitialMarginDb);
EXPECT_LE(total_difference, 0.05f);
EXPECT_LE(max_difference, 0.01f);