Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/main/java/com/tdunning/math/stats/AVLGroupTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ final class AVLGroupTree extends AbstractCollection<Centroid> implements Seriali
private double[] centroids;
private int[] counts;
private List<Double>[] datas;
private int[] aggregatedCounts;
private long[] aggregatedCounts;
private final IntAVLTree tree;

AVLGroupTree() {
Expand Down Expand Up @@ -99,7 +99,7 @@ protected void fixAggregates(int node) {
};
centroids = new double[tree.capacity()];
counts = new int[tree.capacity()];
aggregatedCounts = new int[tree.capacity()];
aggregatedCounts = new long[tree.capacity()];
if (record) {
@SuppressWarnings("unchecked")
final List<Double>[] datas = new List[tree.capacity()];
Expand All @@ -110,6 +110,7 @@ protected void fixAggregates(int node) {
/**
* Return the number of centroids in the tree.
*/
@Override
public int size() {
return tree.size();
}
Expand Down Expand Up @@ -274,7 +275,7 @@ public void remove() {
/**
* Return the total count of points that have been added to the tree.
*/
public int sum() {
public long sum() {
return aggregatedCounts[tree.root()];
}

Expand Down
30 changes: 30 additions & 0 deletions src/test/java/com/tdunning/math/stats/TDigestTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,36 @@ public void testMoreThan2BValues() {
for (double q : quantiles) {
final double v = digest.quantile(q);
assertTrue(v >= prev);
assertTrue("Unexpectedly low value: " + v, v >= 0.0);
assertTrue("Unexpectedly high value: " + v, v <= 1.0);
prev = v;
}
}

@Test
public void testMoreThan4BValues() {
final TDigest digest = factory().create();
Random gen = getRandom();
for (int i = 0; i < 1000; ++i) {
final double next = gen.nextDouble();
digest.add(next);
}
for (int i = 0; i < 10; ++i) {
final double next = gen.nextDouble();
final int count = 1 << 29;
digest.add(next, count);
}
assertEquals(1000 + 10L * (1 << 29), digest.size());
assertTrue(digest.size() > 2 * Integer.MAX_VALUE);
final double[] quantiles = new double[] { 0, 0.1, 0.5, 0.9, 1, gen.nextDouble() };
Arrays.sort(quantiles);
double prev = Double.NEGATIVE_INFINITY;
for (double q : quantiles) {
final double v = digest.quantile(q);
System.out.println("q=" + q + ", v=" + v);
assertTrue(v >= prev);
assertTrue("Unexpectedly low value: " + v, v >= 0.0);
assertTrue("Unexpectedly high value: " + v, v <= 1.0);
prev = v;
}
}
Expand Down