1 change: 1 addition & 0 deletions doc/source/data/api/_autogen.rst
@@ -18,6 +18,7 @@
DataIterator
Dataset
Schema
stats.DatasetSummary
grouped_data.GroupedData
aggregate.AggregateFn
aggregate.AggregateFnV2
5 changes: 5 additions & 0 deletions doc/source/data/api/dataset.rst
@@ -22,7 +22,12 @@ Schema
.. autoclass:: Schema
:members:

DatasetSummary
--------------
.. currentmodule:: ray.data.stats

.. autoclass:: DatasetSummary
:members:

Developer API
-------------
6 changes: 6 additions & 0 deletions doc/source/data/api/datatype.rst
@@ -10,3 +10,9 @@ Class

.. autoclass:: DataType
:members:

Enumeration
-----------

.. autoclass:: TypeCategory
:members:
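
For orientation, a minimal import sketch — assuming TypeCategory lives alongside DataType in ray.data.datatype, as the autoclass directive above implies:

from ray.data.datatype import DataType, TypeCategory  # TypeCategory is newly documented here

# DataType constructors such as DataType.int64() are used elsewhere in this
# PR (see the summary() docstring below) to key custom aggregation mappings.
int_dtype = DataType.int64()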
2 changes: 1 addition & 1 deletion python/ray/data/BUILD.bazel
@@ -1777,7 +1777,7 @@ py_test(

py_test(
name = "test_dataset_stats",
-    size = "small",
+    size = "large",
srcs = ["tests/test_dataset_stats.py"],
tags = [
"exclusive",
2 changes: 2 additions & 0 deletions python/ray/data/__init__.py
@@ -15,6 +15,7 @@
from ray.data._internal.logging import configure_logging
from ray.data.context import DataContext, DatasetContext
from ray.data.dataset import Dataset, Schema, SinkMode, ClickHouseTableSettings
from ray.data.stats import DatasetSummary
from ray.data.datasource import (
BlockBasedFileDatasink,
Datasink,
@@ -123,6 +124,7 @@
"Dataset",
"DataContext",
"DatasetContext", # Backwards compatibility alias.
"DatasetSummary",
"DataIterator",
"DatasetIterator", # Backwards compatibility alias.
"Datasink",
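As a quick sketch of what the re-export above means: after this change the class is importable from the top-level package as well as from its home module, and both names resolve to the same object.

from ray.data import DatasetSummary
from ray.data.stats import DatasetSummary as _StatsDatasetSummary

# The top-level name is just a re-export of ray.data.stats.DatasetSummary.
assert DatasetSummary is _StatsDatasetSummary
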
19 changes: 19 additions & 0 deletions python/ray/data/aggregate.py
@@ -1,6 +1,7 @@
import abc
import math
import pickle
import re
from typing import (
TYPE_CHECKING,
Any,
@@ -51,6 +52,8 @@ def __ge__(self, other: Any) -> bool:
)
AggOutputType = TypeVar("AggOutputType")

_AGGREGATION_NAME_PATTERN = re.compile(r"^([^(]+)(?:\(.*\))?$")


@Deprecated(message="AggregateFn is deprecated, please use AggregateFnV2")
@PublicAPI
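
A quick sketch of what the new pattern extracts — it captures everything before the first opening parenthesis and also matches bare names with no parentheses:

import re

_AGGREGATION_NAME_PATTERN = re.compile(r"^([^(]+)(?:\(.*\))?$")

# "sum(col)" -> "sum"; a bare name such as "count" is captured unchanged.
assert _AGGREGATION_NAME_PATTERN.match("sum(col)").group(1) == "sum"
assert _AGGREGATION_NAME_PATTERN.match("count").group(1) == "count"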
@@ -199,6 +202,14 @@ def __init__(
self._target_col_name = on
self._ignore_nulls = ignore_nulls

# Extract and store the agg name (e.g., "sum" from "sum(col)")
# so the string doesn't have to be re-parsed later.
match = _AGGREGATION_NAME_PATTERN.match(name)
if match:
self._agg_name = match.group(1)
else:
self._agg_name = name

_safe_combine = _null_safe_combine(self.combine, ignore_nulls)
_safe_aggregate = _null_safe_aggregate(self.aggregate_block, ignore_nulls)
_safe_finalize = _null_safe_finalize(self.finalize)
@@ -216,6 +227,14 @@ def __init__(
def get_target_column(self) -> Optional[str]:
return self._target_col_name

def get_agg_name(self) -> str:
"""Return the agg name (e.g., 'sum', 'mean', 'count').

Returns the aggregation type extracted from the name during initialization.
For example, returns 'sum' for an aggregator named 'sum(col)'.
"""
return self._agg_name

@abc.abstractmethod
def combine(
self, current_accumulator: AccumulatorType, new: AccumulatorType
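A usage sketch for the new accessor, assuming the built-in Sum aggregator names itself in the "sum(<column>)" form described by the comment above:

from ray.data.aggregate import Sum

agg = Sum(on="salary")
# Returns the name parsed at construction time, e.g. "sum" for an
# aggregator named "sum(salary)"; falls back to the full name otherwise.
print(agg.get_agg_name())
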
150 changes: 149 additions & 1 deletion python/ray/data/dataset.py
@@ -96,7 +96,16 @@
get_compute_strategy,
merge_resources_to_ray_remote_args,
)
from ray.data.aggregate import AggregateFn, Max, Mean, Min, Std, Sum, Unique
from ray.data.aggregate import (
AggregateFn,
AggregateFnV2,
Max,
Mean,
Min,
Std,
Sum,
Unique,
)
from ray.data.block import (
VALID_BATCH_FORMATS,
Block,
@@ -112,6 +121,7 @@
from ray.data.datasource import Connection, Datasink, FilenameProvider, SaveMode
from ray.data.datasource.datasink import WriteResult, _gen_datasink_write_result
from ray.data.datasource.file_datasink import _FileDatasink
from ray.data.datatype import DataType
from ray.data.iterator import DataIterator
from ray.data.random_access_dataset import RandomAccessDataset
from ray.types import ObjectRef
@@ -135,6 +145,7 @@

from ray.data._internal.execution.interfaces import Executor, NodeIdStr
from ray.data.grouped_data import GroupedData
from ray.data.stats import DatasetSummary

from ray.data.expressions import Expr, StarExpr, col

@@ -3254,6 +3265,143 @@ def std(
ret = self._aggregate_on(Std, on, ignore_nulls=ignore_nulls, ddof=ddof)
return self._aggregate_result(ret)

@AllToAllAPI
@ConsumptionAPI
@PublicAPI(api_group=GGA_API_GROUP, stability="alpha")
def summary(
self,
columns: Optional[List[str]] = None,
override_dtype_agg_mapping: Optional[
Dict[DataType, Callable[[str], List[AggregateFnV2]]]
] = None,
) -> "DatasetSummary":
"""Generate a statistical summary of the dataset, organized by data type.

This method computes various statistics for different column dtypes:

- For numerical dtypes (int*, float*, decimal, bool): count, mean, min, max, std, approx_quantile (median), missing%, zero%
- For string and binary dtypes: count, missing%, approx_top_k (top 10 values)
- For temporal dtypes (timestamp, date, time, duration): count, min, max, missing%
- For other dtypes: count, missing%, approx_top_k

You can customize the aggregations performed for specific data types using the
`override_dtype_agg_mapping` parameter.

The summary separates statistics into two tables:
- Schema-matching stats: Statistics that preserve the original column type (e.g., min/max for integers)
- Schema-changing stats: Statistics that change the type (e.g., mean converts int to float)

Examples:
>>> import ray
>>> ds = ray.data.from_items([
... {"age": 25, "salary": 50000, "name": "Alice", "city": "NYC"},
... {"age": 30, "salary": 60000, "name": None, "city": "LA"},
... {"age": 0, "salary": None, "name": "Bob", "city": None},
... ])
>>> summary = ds.summary()
>>> # Get combined pandas DataFrame with all statistics
>>> summary.to_pandas() # doctest: +SKIP
statistic age city name salary
0 approx_quantile[0] 25.000000 None None 60000.000000
1 approx_topk[0] NaN {'city': 'LA', 'count': 1} {'count': 1, 'name': 'Bob'} NaN
2 approx_topk[1] NaN {'city': 'NYC', 'count': 1} {'count': 1, 'name': 'Alice'} NaN
3 count 3.000000 3 3 3.000000
4 max 30.000000 NaN NaN 60000.000000
5 mean 18.333333 None None 55000.000000
6 min 0.000000 NaN NaN 50000.000000
7 missing_pct 0.000000 33.333333 33.333333 33.333333
8 std 13.123346 None None 5000.000000
9 zero_pct 33.333333 None None 0.000000

>>> # Access individual column statistics
>>> summary.get_column_stats("age") # doctest: +SKIP
statistic value
0 approx_quantile[0] 25.000000
1 approx_topk[0] NaN
2 approx_topk[1] NaN
3 count 3.000000
4 max 30.000000
5 mean 18.333333
6 min 0.000000
7 missing_pct 0.000000
8 std 13.123346
9 zero_pct 33.333333

Custom aggregations for specific types:

>>> from ray.data.datatype import DataType
>>> from ray.data.aggregate import Sum, Count
>>> # Override aggregations for int64 columns
>>> custom_mapping = {
... DataType.int64(): lambda col: [Count(on=col), Sum(on=col)]
... }
>>> summary = ds.summary(override_dtype_agg_mapping=custom_mapping)

Args:
columns: Optional list of column names to include in the summary.
If None, all columns will be included.
override_dtype_agg_mapping: Optional mapping from DataType to factory
functions. Each factory function takes a column name and returns a
list of aggregators for that column. This will be merged with the
default mapping, with user-provided mappings taking precedence.

Returns:
A DatasetSummary object with methods to access statistics and the
original dataset schema. Use `to_pandas()` to get all statistics
as a DataFrame, or `get_column_stats(col)` for a specific column.
"""
from ray.data.stats import (
DatasetSummary,
_build_summary_table,
_dtype_aggregators_for_dataset,
_parse_summary_stats,
)

# Compute aggregations
dtype_aggs = _dtype_aggregators_for_dataset(
self.schema(),
columns=columns,
dtype_agg_mapping=override_dtype_agg_mapping,
)

if not dtype_aggs.aggregators:
raise ValueError(
"summary() requires at least one column with a supported type. "
f"Columns provided: {columns if columns is not None else 'all'}. "
"Check that the specified columns exist and have supported types "
"(numeric, string, binary, or temporal). Columns with None or "
"object types are skipped."
)

aggs_dataset = self.groupby(None).aggregate(*dtype_aggs.aggregators)
agg_result = aggs_dataset.take(1)[0]

# Separate statistics by whether they preserve original column types
original_schema = self.schema().base_schema
agg_schema = aggs_dataset.schema().base_schema
(
schema_matching_stats,
schema_changing_stats,
all_columns,
) = _parse_summary_stats(
agg_result, original_schema, agg_schema, dtype_aggs.aggregators
)

# Build PyArrow tables
schema_matching_table = _build_summary_table(
schema_matching_stats, all_columns, original_schema, preserve_types=True
)
schema_changing_table = _build_summary_table(
schema_changing_stats, all_columns, original_schema, preserve_types=False
)

return DatasetSummary(
_stats_matching_column_dtype=schema_matching_table,
_stats_mismatching_column_dtype=schema_changing_table,
dataset_schema=original_schema,
columns=list(all_columns),
)

@AllToAllAPI
@PublicAPI(api_group=SSR_API_GROUP)
def sort(
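Putting the new API together — a minimal end-to-end sketch that uses only the calls shown in the summary() docstring above (summary(), to_pandas(), and get_column_stats()):

import ray

ds = ray.data.from_items(
    [
        {"age": 25, "salary": 50000, "name": "Alice"},
        {"age": 30, "salary": 60000, "name": None},
    ]
)

summary = ds.summary(columns=["age", "salary"])  # restrict to two columns
df = summary.to_pandas()                         # all statistics in one DataFrame
age_stats = summary.get_column_stats("age")      # single-column view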