
Commit 81cf351

[Data] Add approximate quantile to aggregator (#57598)
## Why are these changes needed?

Add an `ApproximateQuantile` aggregator to Ray Data using the DataSketches KLL sketch.

Reasons:
- Enables efficient support for the summary API.
- More scalable than the exact `Quantile` on large datasets.

Note:
- DataSketches is not added as a Ray dependency; if it is missing, users are prompted to install it.

---

Here's a simple test showing the efficiency difference between `ApproximateQuantile` and `Quantile`:

```py
import time

import ray
import ray.data
from ray.data.aggregate import ApproximateQuantile, Quantile

ray.init(num_cpus=16)

ds = ray.data.range(10**8)
start_time = time.time()
print(ds.aggregate(ApproximateQuantile(on="id", quantiles=[0.5])))
print(f"Time taken ApproximateQuantile: {time.time() - start_time} seconds")

ds = ray.data.range(10**8)
start_time = time.time()
print(ds.aggregate(Quantile(on="id", q=0.5)))
print(f"Time taken Quantile: {time.time() - start_time} seconds")
```

In this run with 1e8 rows, the approximate median returned 49,979,428.0 in ~12.46 s, while the exact `Quantile` returned 49,999,999.5 in ~163.33 s. The difference reflects the sketch's accuracy trade-off in exchange for significant speed and scalability gains. With k=800 (the default), the error rate is guaranteed to be < 0.45%. In this test the error rate is `(49,999,999.5 - 49,979,428.0) / 49,999,999.5` = 0.00041143 = 0.041143%, which is < 0.45%, and the approximate median is computed **13.11x** faster.

```
{'approx_quantile(id)': [49979428.0]}
Time taken ApproximateQuantile: 12.457247257232666 seconds
{'quantile(id)': 49999999.5}
Time taken Quantile: 163.32705521583557 seconds
```

## Related issue number

## Checks

- [ ] I've signed off every commit (by using the `-s` flag, i.e., `git commit -s`) in this PR.
- [ ] I've run pre-commit jobs to lint the changes in this PR. ([pre-commit setup](https://docs.ray.io/en/latest/ray-contribute/getting-involved.html#lint-and-formatting))
- [ ] I've included any doc changes needed for https://docs.ray.io/en/master/.
- [ ] I've added any new APIs to the API Reference. For example, if I added a method in Tune, I've added it in `doc/source/tune/api/` under the corresponding `.rst` file.
- [ ] I've made sure the tests are passing. Note that there might be a few flaky tests; see the recent failures at https://flakey-tests.ray.io/
- Testing Strategy
  - [ ] Unit tests
  - [ ] Release tests
  - [ ] This PR is not tested :(

---------

Signed-off-by: You-Cheng Lin (Owen) <mses010108@gmail.com>
Signed-off-by: You-Cheng Lin <106612301+owenowenisme@users.noreply.github.com>
Co-authored-by: Balaji Veeramani <balaji@anyscale.com>
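The error-rate and speedup figures quoted above can be reproduced with a few lines of arithmetic; this is just a sanity check on the reported numbers, not part of the change:

```python
# Sanity-checking the accuracy/speed trade-off reported in the benchmark run above.
exact_median = 49_999_999.5    # result of Quantile
approx_median = 49_979_428.0   # result of ApproximateQuantile
t_exact = 163.32705521583557   # seconds
t_approx = 12.457247257232666  # seconds

rel_error = (exact_median - approx_median) / exact_median
speedup = t_exact / t_approx

assert rel_error < 0.0045  # within the 0.45% error bound for k=800
print(f"error = {rel_error:.6%}")   # → error = 0.041143%
print(f"speedup = {speedup:.2f}x")  # → speedup = 13.11x
```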
1 parent 3668836 commit 81cf351

File tree

7 files changed (+214, -19 lines)


doc/source/data/api/aggregate.rst

Lines changed: 2 additions & 0 deletions
```diff
@@ -27,3 +27,5 @@ compute aggregations.
     Unique
     MissingValuePercentage
     ZeroPercentage
+    ApproximateQuantile
+
```

python/ray/data/aggregate.py

Lines changed: 98 additions & 0 deletions
```diff
@@ -1189,3 +1189,101 @@ def finalize(self, accumulator: List[int]) -> Optional[float]:
         if accumulator[1] == 0:
             return None
         return (accumulator[0] / accumulator[1]) * 100.0
+
+
+@PublicAPI(stability="alpha")
+class ApproximateQuantile(AggregateFnV2):
+    def _require_datasketches(self):
+        try:
+            from datasketches import kll_floats_sketch  # type: ignore[import]
+        except ImportError as exc:
+            raise ImportError(
+                "ApproximateQuantile requires the `datasketches` package. "
+                "Install it with `pip install datasketches`."
+            ) from exc
+        return kll_floats_sketch
+
+    def __init__(
+        self,
+        on: str,
+        quantiles: List[float],
+        quantile_precision: int = 800,
+        alias_name: Optional[str] = None,
+    ):
+        """
+        Computes the approximate quantiles of a column by using a datasketches kll_floats_sketch.
+        https://datasketches.apache.org/docs/KLL/KLLSketch.html
+
+        The accuracy of the KLL quantile sketch is a function of the configured quantile precision, which also affects
+        the overall size of the sketch.
+        The KLL sketch has absolute error. For example, a specified rank accuracy of 1% at the
+        median (rank = 0.50) means that the true quantile (if you could extract it from the set)
+        should be between getQuantile(0.49) and getQuantile(0.51). This same 1% error applied at a
+        rank of 0.95 means that the true quantile should be between getQuantile(0.94) and getQuantile(0.96).
+        In other words, the error is a fixed +/- epsilon for the entire range of ranks.
+
+        Typical single-sided rank error by quantile_precision (use for getQuantile/getRank):
+
+        - quantile_precision=100 → ~2.61%
+        - quantile_precision=200 → ~1.33%
+        - quantile_precision=400 → ~0.68%
+        - quantile_precision=800 → ~0.35%
+
+        See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for details on accuracy and size.
+
+        Null values in the target column are ignored when constructing the sketch.
+
+        Example:
+
+            .. testcode::
+
+                import ray
+                from ray.data.aggregate import ApproximateQuantile
+
+                # Create a dataset with some values
+                ds = ray.data.from_items(
+                    [{"value": 20.0}, {"value": 40.0}, {"value": 60.0},
+                     {"value": 80.0}, {"value": 100.0}]
+                )
+
+                result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9]))
+                # Result: {'approx_quantile(value)': [20.0, 60.0, 100.0]}
+
+        Args:
+            on: The name of the column to calculate the quantile on. Must be a numeric column.
+            quantiles: The list of quantiles to compute. Must be between 0 and 1 inclusive.
+                For example, quantiles=[0.5] computes the median. Null entries in the source
+                column are skipped.
+            quantile_precision: Controls the accuracy and memory footprint of the sketch
+                (K in KLL); higher values yield lower error but use more memory. Defaults to 800.
+                See https://datasketches.apache.org/docs/KLL/KLLAccuracyAndSize.html for details
+                on accuracy and size.
+            alias_name: Optional name for the resulting column. If not provided, defaults to
+                "approx_quantile({column_name})".
+        """
+        self._sketch_cls = self._require_datasketches()
+        self._quantiles = quantiles
+        self._quantile_precision = quantile_precision
+        super().__init__(
+            alias_name if alias_name else f"approx_quantile({str(on)})",
+            on=on,
+            ignore_nulls=True,
+            zero_factory=lambda: self.zero(quantile_precision).serialize(),
+        )
+
+    def zero(self, quantile_precision: int):
+        return self._sketch_cls(k=quantile_precision)
+
+    def aggregate_block(self, block: Block) -> bytes:
+        block_acc = BlockAccessor.for_block(block)
+        table = block_acc.to_arrow()
+        column = table.column(self.get_target_column())
+        sketch = self.zero(self._quantile_precision)
+        for value in column:
+            # we ignore nulls here
+            if value.as_py() is not None:
+                sketch.update(float(value.as_py()))
+        return sketch.serialize()
+
+    def combine(self, current_accumulator: bytes, new: bytes) -> bytes:
+        combined = self.zero(self._quantile_precision)
+        combined.merge(self._sketch_cls.deserialize(current_accumulator))
+        combined.merge(self._sketch_cls.deserialize(new))
+        return combined.serialize()
+
+    def finalize(self, accumulator: bytes) -> List[float]:
+        return self._sketch_cls.deserialize(accumulator).get_quantiles(self._quantiles)
```
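For readers unfamiliar with `AggregateFnV2`, the class above follows a serialize/merge pattern: each block produces a serialized sketch, `combine` merges two serialized accumulators, and `finalize` extracts the quantiles. Below is a minimal pure-Python sketch of that same contract, substituting exact sorted samples for the KLL sketch; the function names are illustrative, not Ray APIs:

```python
import pickle
from typing import List, Optional

def aggregate_block(values: List[Optional[float]]) -> bytes:
    # Per-block accumulator, serialized so it can cross task boundaries.
    # Nulls are skipped, mirroring ignore_nulls=True in the aggregator above.
    return pickle.dumps(sorted(v for v in values if v is not None))

def combine(acc_a: bytes, acc_b: bytes) -> bytes:
    # Merging is associative and commutative, so partial results
    # from different blocks can be combined in any order.
    return pickle.dumps(sorted(pickle.loads(acc_a) + pickle.loads(acc_b)))

def finalize(acc: bytes, quantiles: List[float]) -> List[float]:
    # Nearest-rank lookup on the exact samples; the real aggregator
    # delegates to kll_floats_sketch.get_quantiles instead.
    data = pickle.loads(acc)
    return [data[min(int(q * len(data)), len(data) - 1)] for q in quantiles]

acc = combine(aggregate_block([3.0, 1.0, None]), aggregate_block([2.0, 4.0]))
print(finalize(acc, [0.5]))  # → [3.0]
```

The key property the KLL sketch adds over this exact version is bounded memory: the accumulator stays small no matter how many rows each block contributes.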

python/ray/data/stats.py

Lines changed: 3 additions & 0 deletions
```diff
@@ -6,6 +6,7 @@
 
 from ray.data.aggregate import (
     AggregateFnV2,
+    ApproximateQuantile,
     Count,
     Max,
     Mean,
@@ -31,6 +32,7 @@ def numerical_aggregators(column: str) -> List[AggregateFnV2]:
     - min
     - max
     - std
+    - approximate_quantile
     - missing_value_percentage
     - zero_percentage
@@ -46,6 +48,7 @@ def numerical_aggregators(column: str) -> List[AggregateFnV2]:
         Min(on=column, ignore_nulls=True),
         Max(on=column, ignore_nulls=True),
         Std(on=column, ignore_nulls=True, ddof=0),
+        ApproximateQuantile(on=column, quantiles=[0.5]),
         MissingValuePercentage(on=column),
         ZeroPercentage(on=column, ignore_nulls=True),
     ]
```

python/ray/data/tests/test_custom_agg.py

Lines changed: 86 additions & 1 deletion
```diff
@@ -2,7 +2,11 @@
 import pytest
 
 import ray
-from ray.data.aggregate import MissingValuePercentage, ZeroPercentage
+from ray.data.aggregate import (
+    ApproximateQuantile,
+    MissingValuePercentage,
+    ZeroPercentage,
+)
 from ray.data.tests.conftest import *  # noqa
 from ray.tests.conftest import *  # noqa
 
@@ -276,6 +280,87 @@ def test_zero_percentage_negative_values(self, ray_start_regular_shared_2_cpus):
         assert result["zero_pct(value)"] == expected
 
 
+class TestApproximateQuantile:
+    """Test cases for ApproximateQuantile aggregation."""
+
+    def test_approximate_quantile_basic(self, ray_start_regular_shared_2_cpus):
+        """Test basic approximate quantile calculation."""
+        data = [
+            {
+                "id": 1,
+                "value": 10,
+            },
+            {"id": 2, "value": 0},
+            {"id": 3, "value": 30},
+            {"id": 4, "value": 0},
+            {"id": 5, "value": 50},
+        ]
+        ds = ray.data.from_items(data)
+
+        result = ds.aggregate(
+            ApproximateQuantile(on="value", quantiles=[0.1, 0.5, 0.9])
+        )
+        expected = [0.0, 10.0, 50.0]
+        assert result["approx_quantile(value)"] == expected
+
+    def test_approximate_quantile_ignores_nulls(self, ray_start_regular_shared_2_cpus):
+        data = [
+            {"id": 1, "value": 5.0},
+            {"id": 2, "value": None},
+            {"id": 3, "value": 15.0},
+            {"id": 4, "value": None},
+            {"id": 5, "value": 25.0},
+        ]
+        ds = ray.data.from_items(data)
+
+        result = ds.aggregate(ApproximateQuantile(on="value", quantiles=[0.5]))
+        assert result["approx_quantile(value)"] == [15.0]
+
+    def test_approximate_quantile_custom_alias(self, ray_start_regular_shared_2_cpus):
+        data = [
+            {"id": 1, "value": 1.0},
+            {"id": 2, "value": 3.0},
+            {"id": 3, "value": 5.0},
+            {"id": 4, "value": 7.0},
+            {"id": 5, "value": 9.0},
+        ]
+        ds = ray.data.from_items(data)
+
+        quantiles = [0.0, 1.0]
+        result = ds.aggregate(
+            ApproximateQuantile(
+                on="value", quantiles=quantiles, alias_name="value_range"
+            )
+        )
+
+        assert result["value_range"] == [1.0, 9.0]
+        assert len(result["value_range"]) == len(quantiles)
+
+    def test_approximate_quantile_groupby(self, ray_start_regular_shared_2_cpus):
+        data = [
+            {"group": "A", "value": 1.0},
+            {"group": "A", "value": 2.0},
+            {"group": "A", "value": 3.0},
+            {"group": "B", "value": 10.0},
+            {"group": "B", "value": 20.0},
+            {"group": "B", "value": 30.0},
+        ]
+        ds = ray.data.from_items(data)
+
+        result = (
+            ds.groupby("group")
+            .aggregate(ApproximateQuantile(on="value", quantiles=[0.5]))
+            .take_all()
+        )
+
+        result_by_group = {
+            row["group"]: row["approx_quantile(value)"] for row in result
+        }
+
+        assert result_by_group["A"] == [2.0]
+        assert result_by_group["B"] == [20.0]
+
+
 if __name__ == "__main__":
     import sys
```

python/ray/data/tests/test_dataset_stats.py

Lines changed: 21 additions & 18 deletions
```diff
@@ -4,6 +4,7 @@
 
 import ray
 from ray.data.aggregate import (
+    ApproximateQuantile,
     Count,
     Max,
     Mean,
@@ -51,8 +52,8 @@ def test_numerical_columns_detection(self):
         assert len(feature_aggs.vector_columns) == 0
 
         # Check that we have the right number of aggregators
-        # 3 numerical columns * 7 aggregators each + 1 string column * 2 aggregators = 23 total
-        assert len(feature_aggs.aggregators) == 23
+        # 3 numerical columns * 8 aggregators each + 1 string column * 2 aggregators = 26 total
+        assert len(feature_aggs.aggregators) == 26
 
     def test_categorical_columns_detection(self):
         """Test that string columns are correctly identified as categorical."""
@@ -74,8 +75,8 @@ def test_categorical_columns_detection(self):
         assert "value" in feature_aggs.numerical_columns
         assert "category" not in feature_aggs.numerical_columns
 
-        # Check aggregator count: 1 numerical * 7 + 2 categorical * 2 = 11
-        assert len(feature_aggs.aggregators) == 11
+        # Check aggregator count: 1 numerical * 8 + 2 categorical * 2 = 12
+        assert len(feature_aggs.aggregators) == 12
 
     def test_vector_columns_detection(self):
         """Test that list columns are correctly identified as vector columns."""
@@ -97,8 +98,8 @@ def test_vector_columns_detection(self):
         assert "scalar" in feature_aggs.numerical_columns
         assert "text" in feature_aggs.str_columns
 
-        # Check aggregator count: 1 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 11
-        assert len(feature_aggs.aggregators) == 11
+        # Check aggregator count: 1 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 12
+        assert len(feature_aggs.aggregators) == 12
 
     def test_mixed_column_types(self):
         """Test dataset with all column types mixed together."""
@@ -130,8 +131,8 @@ def test_mixed_column_types(self):
         # bool_val should be treated as numerical (integer-like)
         assert "bool_val" in feature_aggs.numerical_columns
 
-        # Check aggregator count: 3 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 25
-        assert len(feature_aggs.aggregators) == 25
+        # Check aggregator count: 3 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 28
+        assert len(feature_aggs.aggregators) == 28
 
     def test_column_filtering(self):
         """Test that only specified columns are included when columns parameter is provided."""
@@ -151,8 +152,8 @@ def test_column_filtering(self):
         assert "col3" in feature_aggs.vector_columns
         assert "col4" not in feature_aggs.numerical_columns
 
-        # Check aggregator count: 1 numerical * 7 + 1 vector * 2 = 9
-        assert len(feature_aggs.aggregators) == 9
+        # Check aggregator count: 1 numerical * 8 + 1 vector * 2 = 10
+        assert len(feature_aggs.aggregators) == 10
 
     def test_empty_dataset_schema(self):
         """Test behavior with empty dataset that has no schema."""
@@ -199,8 +200,8 @@ def test_unsupported_column_types(self):
         assert "unsupported_binary" not in feature_aggs.str_columns
         assert "unsupported_binary" not in feature_aggs.vector_columns
 
-        # Check aggregator count: 1 numerical * 7 + 1 categorical * 2 = 9
-        assert len(feature_aggs.aggregators) == 9
+        # Check aggregator count: 1 numerical * 8 + 1 categorical * 2 = 10
+        assert len(feature_aggs.aggregators) == 10
 
     def test_aggregator_types_verification(self):
         """Test that the correct aggregator types are generated for each column type."""
@@ -215,16 +216,17 @@ def test_aggregator_types_verification(self):
         # Check that we have the right types of aggregators
         agg_names = [agg.name for agg in feature_aggs.aggregators]
 
-        # Numerical aggregators should include all 7 types
+        # Numerical aggregators should include all 8 types
         num_agg_names = [name for name in agg_names if "num" in name]
-        assert len(num_agg_names) == 7
+        assert len(num_agg_names) == 8
         assert any("count" in name.lower() for name in num_agg_names)
         assert any("mean" in name.lower() for name in num_agg_names)
         assert any("min" in name.lower() for name in num_agg_names)
         assert any("max" in name.lower() for name in num_agg_names)
         assert any("std" in name.lower() for name in num_agg_names)
         assert any("missing" in name.lower() for name in num_agg_names)
         assert any("zero" in name.lower() for name in num_agg_names)
+        assert any("approx_quantile" in name.lower() for name in num_agg_names)
 
         # Categorical aggregators should include count and missing percentage
         cat_agg_names = [name for name in agg_names if "cat" in name]
@@ -246,7 +248,7 @@ def test_aggregator_instances_verification(self):
 
         # Find aggregators for the numerical column
         num_aggs = [agg for agg in feature_aggs.aggregators if "num" in agg.name]
-        assert len(num_aggs) == 7
+        assert len(num_aggs) == 8
 
         # Check that we have the right aggregator types
         agg_types = [type(agg) for agg in num_aggs]
@@ -257,6 +259,7 @@ def test_aggregator_instances_verification(self):
         assert Std in agg_types
         assert MissingValuePercentage in agg_types
         assert ZeroPercentage in agg_types
+        assert ApproximateQuantile in agg_types
 
         # Find aggregators for the categorical column
         cat_aggs = [agg for agg in feature_aggs.aggregators if "cat" in agg.name]
@@ -352,8 +355,8 @@ def test_large_dataset_performance(self):
         assert "category" in feature_aggs.str_columns
         assert "vector" in feature_aggs.vector_columns
 
-        # Check aggregator count: 2 numerical * 7 + 1 categorical * 2 + 1 vector * 2 = 18
-        assert len(feature_aggs.aggregators) == 18
+        # Check aggregator count: 2 numerical * 8 + 1 categorical * 2 + 1 vector * 2 = 20
+        assert len(feature_aggs.aggregators) == 20
 
 
 class TestIndividualAggregatorFunctions:
@@ -363,7 +366,7 @@ def test_numerical_aggregators(self):
         """Test numerical_aggregators function."""
         aggs = numerical_aggregators("test_column")
 
-        assert len(aggs) == 7
+        assert len(aggs) == 8
         assert all(hasattr(agg, "get_target_column") for agg in aggs)
         assert all(agg.get_target_column() == "test_column" for agg in aggs)
```

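The aggregator-count updates throughout this test file all follow one formula: 8 aggregators per numerical column (the new `ApproximateQuantile` plus the existing 7), and 2 each per categorical and vector column. A quick check of the updated totals (the helper name here is illustrative):

```python
# 8 aggregators per numerical column, 2 per categorical, 2 per vector column.
def expected_count(numerical: int, categorical: int = 0, vector: int = 0) -> int:
    return numerical * 8 + categorical * 2 + vector * 2

assert expected_count(3, categorical=1) == 26            # numerical columns detection
assert expected_count(1, categorical=2) == 12            # categorical columns detection
assert expected_count(1, categorical=1, vector=1) == 12  # vector columns detection
assert expected_count(3, categorical=1, vector=1) == 28  # mixed column types
assert expected_count(1, vector=1) == 10                 # column filtering
assert expected_count(2, categorical=1, vector=1) == 20  # large dataset performance
```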
python/requirements/ml/data-test-requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -23,3 +23,4 @@ pyiceberg[sql-sqlite]==0.9.0
 clickhouse-connect
 pybase64
 hudi==0.4.0
+datasketches
```

python/requirements_compiled.txt

Lines changed: 3 additions & 0 deletions
```diff
@@ -421,6 +421,8 @@ datasets==3.6.0
     # -r python/requirements/ml/data-test-requirements.txt
     # -r python/requirements/ml/train-requirements.txt
     # evaluate
+datasketches==5.2.0
+    # via -r python/requirements/ml/data-test-requirements.txt
 debugpy==1.8.0
     # via ipykernel
 decorator==5.1.1
@@ -1247,6 +1249,7 @@ numpy==1.26.4
     # cupy-cuda12x
     # dask
     # datasets
+    # datasketches
     # decord
     # deepspeed
     # dm-control
```
