HADOOP-18426. Use weighted calculation for MutableStat mean/variance to fix accuracy. (#4844). Contributed by Erik Krogen.

Co-authored-by: Shuyan Zhang <zqingchai@gmail.com>
Signed-off-by: He Xiaoqiao <hexiaoqiao@apache.org>
This commit is contained in:
Erik Krogen 2022-09-06 22:49:56 -07:00 committed by GitHub
parent cc41ad63f9
commit c664f953c9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 74 additions and 40 deletions

View File

@ -27,33 +27,29 @@
public class SampleStat {
private final MinMax minmax = new MinMax();
private long numSamples = 0;
private double a0, a1, s0, s1, total;
private double mean, s;
/**
* Construct a new running sample stat
*/
public SampleStat() {
a0 = s0 = 0.0;
total = 0.0;
mean = 0.0;
s = 0.0;
}
public void reset() {
numSamples = 0;
a0 = s0 = 0.0;
total = 0.0;
mean = 0.0;
s = 0.0;
minmax.reset();
}
// We want to reuse the object, sometimes.
void reset(long numSamples, double a0, double a1, double s0, double s1,
double total, MinMax minmax) {
this.numSamples = numSamples;
this.a0 = a0;
this.a1 = a1;
this.s0 = s0;
this.s1 = s1;
this.total = total;
this.minmax.reset(minmax);
void reset(long numSamples1, double mean1, double s1, MinMax minmax1) {
numSamples = numSamples1;
mean = mean1;
s = s1;
minmax.reset(minmax1);
}
/**
@ -61,7 +57,7 @@ void reset(long numSamples, double a0, double a1, double s0, double s1,
* @param other the destination to hold our values
*/
public void copyTo(SampleStat other) {
other.reset(numSamples, a0, a1, s0, s1, total, minmax);
other.reset(numSamples, mean, s, minmax);
}
/**
@ -78,24 +74,22 @@ public SampleStat add(double x) {
* Add some sample and a partial sum to the running stat.
* Note, min/max is not evaluated using this method.
* @param nSamples number of samples
* @param x the partial sum
* @param xTotal the partial sum
* @return self
*/
public SampleStat add(long nSamples, double x) {
public SampleStat add(long nSamples, double xTotal) {
numSamples += nSamples;
total += x;
if (numSamples == 1) {
a0 = a1 = x;
s0 = 0.0;
}
else {
// The Welford method for numerical stability
a1 = a0 + (x - a0) / numSamples;
s1 = s0 + (x - a0) * (x - a1);
a0 = a1;
s0 = s1;
}
// use the weighted incremental version of Welford's algorithm to get
// numerical stability while treating the samples as being weighted
// by nSamples
// see https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
double x = xTotal / nSamples;
double meanOld = mean;
mean += ((double) nSamples / numSamples) * (x - meanOld);
s += nSamples * (x - meanOld) * (x - mean);
return this;
}
@ -110,21 +104,21 @@ public long numSamples() {
* @return the total of all samples added
*/
public double total() {
return total;
return mean * numSamples;
}
/**
* @return the arithmetic mean of the samples
*/
public double mean() {
return numSamples > 0 ? (total / numSamples) : 0.0;
return numSamples > 0 ? mean : 0.0;
}
/**
* @return the variance of the samples
*/
public double variance() {
return numSamples > 1 ? s1 / (numSamples - 1) : 0.0;
return numSamples > 1 ? s / (numSamples - 1) : 0.0;
}
/**

View File

@ -29,6 +29,8 @@
import static org.mockito.Mockito.verify;
import static org.junit.Assert.*;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Random;
@ -36,6 +38,7 @@
import org.apache.hadoop.metrics2.MetricsRecordBuilder;
import org.apache.hadoop.metrics2.util.Quantile;
import org.apache.hadoop.thirdparty.com.google.common.math.Stats;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -47,7 +50,7 @@ public class TestMutableMetrics {
private static final Logger LOG =
LoggerFactory.getLogger(TestMutableMetrics.class);
private final double EPSILON = 1e-42;
private static final double EPSILON = 1e-42;
/**
* Test the snapshot method
@ -306,19 +309,56 @@ public void testDuplicateMetrics() {
/**
* Tests that when using {@link MutableStat#add(long, long)}, even with a high
* sample count, the mean does not lose accuracy.
* sample count, the mean does not lose accuracy. This also validates that
* the std dev is correct, assuming samples of equal value.
*/
@Test public void testMutableStatWithBulkAdd() {
@Test
public void testMutableStatWithBulkAdd() {
List<Long> samples = new ArrayList<>();
for (int i = 0; i < 1000; i++) {
samples.add(1000L);
}
for (int i = 0; i < 1000; i++) {
samples.add(2000L);
}
Stats stats = Stats.of(samples);
for (int bulkSize : new int[] {1, 10, 100, 1000}) {
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
MetricsRegistry registry = new MetricsRegistry("test");
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", false);
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
stat.add(1000, 1000);
stat.add(1000, 2000);
for (int i = 0; i < samples.size(); i += bulkSize) {
stat.add(bulkSize, samples
.subList(i, i + bulkSize)
.stream()
.mapToLong(Long::longValue)
.sum()
);
}
registry.snapshot(rb, false);
assertCounter("TestNumOps", 2000L, rb);
assertGauge("TestAvgVal", 1.5, rb);
assertGauge("TestAvgVal", stats.mean(), rb);
assertGauge("TestStdevVal", stats.sampleStandardDeviation(), rb);
}
}
@Test
public void testLargeMutableStatAdd() {
MetricsRecordBuilder rb = mockMetricsRecordBuilder();
MetricsRegistry registry = new MetricsRegistry("test");
MutableStat stat = registry.newStat("Test", "Test", "Ops", "Val", true);
long sample = 1000000000000009L;
for (int i = 0; i < 100; i++) {
stat.add(1, sample);
}
registry.snapshot(rb, false);
assertCounter("TestNumOps", 100L, rb);
assertGauge("TestAvgVal", (double) sample, rb);
assertGauge("TestStdevVal", 0.0, rb);
}
/**