Skip to content

Commit 86c5f53

Browse files
goutamvenkat-anyscale authored and committed
[Data] - Fix Pushdown Optimizations with Hive Partitioning (ray-project#58723)
## Description When hive partitioned, partition cols don't reside in the physical schema of the table, so you can't do projection and predicate pushdown of that subset of columns into the read layer. Basically we filter those out before pushing down. ## Related issues Fixes ray-project#58714 ## Additional information > Optional: Add implementation details, API changes, usage examples, screenshots, etc. --------- Signed-off-by: Goutam <goutam@anyscale.com>
1 parent 80f8747 commit 86c5f53

File tree

8 files changed

+1129
-15
lines changed

8 files changed

+1129
-15
lines changed

python/ray/data/_internal/datasource/parquet_datasource.py

Lines changed: 268 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
_BATCH_SIZE_PRESERVING_STUB_COL_NAME,
2626
ArrowBlockAccessor,
2727
)
28+
from ray.data._internal.planner.plan_expression.expression_visitors import (
29+
get_column_references,
30+
)
2831
from ray.data._internal.progress_bar import ProgressBar
2932
from ray.data._internal.remote_fn import cached_remote_fn
3033
from ray.data._internal.util import (
@@ -52,6 +55,7 @@
5255
from ray.data.datasource.path_util import (
5356
_resolve_paths_and_filesystem,
5457
)
58+
from ray.data.expressions import BinaryExpr, Expr, Operation
5559
from ray.util.debug import log_once
5660

5761
if TYPE_CHECKING:
@@ -160,6 +164,120 @@ def check_for_legacy_tensor_type(schema):
160164
)
161165

162166

167+
@dataclass
class _SplitPredicateResult:
    """Outcome of partitioning a predicate by the kind of columns it touches.

    Attributes:
        data_predicate: The sub-expression that references only data columns
            (eligible for PyArrow row-level pushdown), or None when no such
            part could be extracted.
        partition_predicate: The sub-expression that references only partition
            columns (eligible for file-level partition pruning), or None when
            no such part could be extracted.
    """

    data_predicate: Optional[Expr]
    partition_predicate: Optional[Expr]
180+
181+
182+
def _split_predicate_by_columns(
    predicate: Expr,
    partition_columns: set,
) -> _SplitPredicateResult:
    """Split a predicate into data-only and partition-only parts.

    This function extracts both data column predicates and partition column
    predicates from AND chains, enabling both PyArrow pushdown (data part) and
    partition pruning (partition part).

    Args:
        predicate: The predicate expression to analyze.
        partition_columns: Set of partition column names.

    Returns:
        _SplitPredicateResult containing:
        - data_predicate: Expression with only data columns (for PyArrow pushdown),
          or None if no data predicates can be extracted.
        - partition_predicate: Expression with only partition columns (for pruning),
          or None if no partition predicates can be extracted.

    Examples:
        >>> from ray.data.expressions import col
        >>> # Pure data predicate:
        >>> result = _split_predicate_by_columns(col("data1") > 5, {"partition_col"})
        >>> result.data_predicate is not None  # Should have data predicate
        True
        >>> result.partition_predicate is None  # Should not have partition predicate
        True

        >>> # Pure partition predicate:
        >>> result = _split_predicate_by_columns(col("partition_col") == "US", {"partition_col"})
        >>> result.data_predicate is None  # Should not have data predicate
        True
        >>> result.partition_predicate is not None  # Should have partition predicate
        True

        >>> # Mixed AND - can split both parts:
        >>> result = _split_predicate_by_columns(
        ...     (col("data1") > 5) & (col("partition_col") == "US"),
        ...     {"partition_col"}
        ... )
        >>> result.data_predicate is not None  # Should have data predicate
        True
        >>> result.partition_predicate is not None  # Should have partition predicate
        True

        >>> # Mixed OR - can't split safely:
        >>> result = _split_predicate_by_columns(
        ...     (col("data1") > 5) | (col("partition_col") == "US"),
        ...     {"partition_col"}
        ... )
        >>> result.data_predicate is None  # Should not have data predicate
        True
        >>> result.partition_predicate is None  # Should not have partition predicate
        True
    """
    referenced_cols = set(get_column_references(predicate))
    data_cols = referenced_cols - partition_columns
    partition_cols_in_predicate = referenced_cols & partition_columns

    if not partition_cols_in_predicate:
        # Pure data predicate
        return _SplitPredicateResult(data_predicate=predicate, partition_predicate=None)

    if not data_cols:
        # Pure partition predicate
        return _SplitPredicateResult(data_predicate=None, partition_predicate=predicate)

    # Mixed predicate - try to split if it's an AND chain
    if isinstance(predicate, BinaryExpr) and predicate.op == Operation.AND:
        # Recursively split left and right sides
        left_result = _split_predicate_by_columns(predicate.left, partition_columns)
        right_result = _split_predicate_by_columns(predicate.right, partition_columns)

        # Helper to combine predicates from both sides.
        # NOTE: Compare against None explicitly rather than relying on
        # truthiness (`left and right`): Expr overloads operators (e.g. `&`),
        # so evaluating an expression object in boolean context is fragile and
        # may not mean "is present".
        def combine_predicates(
            left: Optional[Expr], right: Optional[Expr]
        ) -> Optional[Expr]:
            if left is not None and right is not None:
                return left & right
            return left if left is not None else right

        data_predicate = combine_predicates(
            left_result.data_predicate, right_result.data_predicate
        )
        partition_predicate = combine_predicates(
            left_result.partition_predicate, right_result.partition_predicate
        )

        return _SplitPredicateResult(
            data_predicate=data_predicate, partition_predicate=partition_predicate
        )

    # For OR, NOT, or other operations with mixed columns,
    # we can't safely split - must evaluate the full predicate together
    return _SplitPredicateResult(data_predicate=None, partition_predicate=None)
279+
280+
163281
class ParquetDatasource(Datasource):
164282
"""Parquet datasource, for reading and writing Parquet files.
165283
@@ -255,9 +373,13 @@ def __init__(
255373
# columns manually.
256374
data_columns, partition_columns = None, None
257375
if columns is not None:
258-
data_columns, partition_columns = _infer_data_and_partition_columns(
259-
columns, pq_ds.fragments[0], partitioning
260-
)
376+
if pq_ds.fragments:
377+
data_columns, partition_columns = _infer_data_and_partition_columns(
378+
columns, pq_ds.fragments[0], partitioning
379+
)
380+
else:
381+
# Empty dataset - can't infer columns without fragments
382+
data_columns, partition_columns = [], []
261383

262384
if to_batch_kwargs is None:
263385
to_batch_kwargs = {}
@@ -274,11 +396,35 @@ def __init__(
274396
self._to_batches_kwargs = to_batch_kwargs
275397
# Store as projection_map (identity mapping if columns specified, None otherwise)
276398
# Note: Empty list [] means no columns, None means all columns
277-
if data_columns is None:
399+
# Include partition columns in projection_map if they were requested, so that
400+
# projection pushdown can properly track them
401+
if data_columns is None and partition_columns is None:
278402
self._projection_map = None
279403
else:
280-
self._projection_map = {col: col for col in data_columns}
281-
self._partition_columns = partition_columns
404+
self._projection_map = {}
405+
if data_columns is not None:
406+
self._projection_map.update({col: col for col in data_columns})
407+
if partition_columns is not None:
408+
self._projection_map.update({col: col for col in partition_columns})
409+
410+
# Eagerly compute the actual partition columns for _partition_columns.
411+
# This ensures _partition_columns is always a list (never None).
412+
actual_partition_columns = partition_columns
413+
if partition_columns is None and partitioning is not None and pq_ds.fragments:
414+
parse = PathPartitionParser(partitioning)
415+
parsed_partitions = parse(pq_ds.fragments[0].path)
416+
if parsed_partitions:
417+
actual_partition_columns = list(parsed_partitions.keys())
418+
419+
# Store selected partition columns. Always a list (never None) representing
420+
# the actual partition columns to include.
421+
self._partition_columns = (
422+
actual_partition_columns if actual_partition_columns is not None else []
423+
)
424+
# Track whether partition columns were explicitly part of the user's column selection
425+
self._partition_columns_selected = (
426+
partition_columns is not None and len(self._partition_columns) > 0
427+
)
282428
self._read_schema = schema
283429
self._file_schema = pq_ds.schema
284430
self._partition_schema = _get_partition_columns_schema(
@@ -390,7 +536,7 @@ def get_read_tasks(
390536
self._default_batch_size,
391537
self._get_data_columns(),
392538
self.get_column_renames(),
393-
self._partition_columns,
539+
self._get_partition_columns(),
394540
self._read_schema,
395541
self._include_paths,
396542
self._partitioning,
@@ -441,10 +587,123 @@ def get_current_projection(self) -> Optional[List[str]]:
441587
# NOTE: In case there's no projection both file and partition columns
442588
# will be none
443589
data_columns = self._get_data_columns()
444-
if data_columns is None and self._partition_columns is None:
590+
partition_columns = self._get_partition_columns()
591+
if data_columns is None and partition_columns is None:
445592
return None
446593

447-
return (data_columns or []) + (self._partition_columns or [])
594+
return (data_columns or []) + (partition_columns or [])
595+
596+
def _get_partition_columns(self) -> Optional[List[str]]:
    """Extract partition columns from the projection map.

    ``_projection_map`` is the source of truth after projection pushdown, and
    partition columns are added to it at initialization when the user selects
    them, so they can be reliably read back out of the map here.

    Returns:
        The partition column names present in the projection; None when there
        is no projection at all (include all partition columns) or when no
        partition columns exist; [] when a projection exists but selects no
        partition columns (include none).
    """
    # No projection at all -> include every partition column.
    if self._projection_map is None:
        return None

    # No known partition columns -> nothing to extract.
    if not self._partition_columns:
        return None

    # Keep only the projected columns that are partition columns, preserving
    # the projection map's ordering. An empty result means the projection
    # deliberately excludes partition columns.
    known_partition_cols = set(self._partition_columns)
    return [name for name in self._projection_map if name in known_partition_cols]
629+
630+
def _get_data_columns(self) -> Optional[List[str]]:
    """Extract data columns from the projection map, excluding partition columns.

    Partition columns are not part of the physical file schema, so they must
    be stripped before the column list is handed to PyArrow's ``to_batches()``.

    Returns:
        Data column names to read from files, or None when there is no
        projection. May be an empty list when only partition columns are
        projected.
    """
    if self._projection_map is None:
        return None

    # Drop any projected column that is actually a partition column.
    excluded = set(self._partition_columns)
    return [name for name in self._projection_map if name not in excluded]
650+
651+
def apply_predicate(
    self,
    predicate_expr: Expr,
) -> "ParquetDatasource":
    """Apply a predicate with data pushdown and partition pruning.

    The predicate is optimized in three ways:
    1. Data predicates -> pushed to PyArrow (row-level filtering).
    2. Partition predicates -> used for partition pruning (file-level filtering).
    3. Mixed predicates -> both optimizations applied together.
    """
    partition_cols = set(self._partition_columns)

    # Without partition columns every predicate can be pushed down as-is.
    if not partition_cols:
        return super().apply_predicate(predicate_expr)

    # Separate the predicate into its data-column and partition-column parts.
    split_result = _split_predicate_by_columns(predicate_expr, partition_cols)

    # Prune fragments whose partition values fail the partition predicate.
    # NOTE(review): pruning mutates self._pq_fragments/_pq_paths in place
    # (before the shallow copy below), so the original datasource is affected
    # as well -- confirm this is intentional.
    if (
        split_result.partition_predicate is not None
        and self._partitioning is not None
    ):
        parser = PathPartitionParser(self._partitioning)
        surviving_fragments = []
        surviving_paths = []

        for fragment, path in zip(self._pq_fragments, self._pq_paths):
            # Keep the fragment only if its partition values satisfy the
            # partition predicate.
            if parser.evaluate_predicate_on_partition(
                path, split_result.partition_predicate
            ):
                surviving_fragments.append(fragment)
                surviving_paths.append(path)

        self._pq_fragments = surviving_fragments
        self._pq_paths = surviving_paths

    # Work on a shallow copy so the data-predicate pushdown does not alter
    # this instance.
    import copy

    datasource = copy.copy(self)

    # Push the data part down to PyArrow only when one exists; a pure
    # partition predicate must not produce expressions like `expr & None`.
    if split_result.data_predicate is not None:
        return super(ParquetDatasource, datasource).apply_predicate(
            split_result.data_predicate
        )

    return datasource
448707

449708
def _estimate_in_mem_size(self, fragments: List[_ParquetFragment]) -> int:
450709
in_mem_size = sum([f.file_size for f in fragments]) * self._encoding_ratio

python/ray/data/_internal/logical/operators/read_operator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,9 @@ def get_current_predicate(self) -> Optional[Expr]:
187187
return self._datasource.get_current_predicate()
188188

189189
def apply_predicate(self, predicate_expr: Expr) -> "Read":
190-
clone = copy.copy(self)
191-
192190
predicated_datasource = self._datasource.apply_predicate(predicate_expr)
191+
192+
clone = copy.copy(self)
193193
clone._datasource = predicated_datasource
194194
clone._datasource_or_legacy_reader = predicated_datasource
195195

python/ray/data/_internal/logical/rules/predicate_pushdown.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,16 @@ def _try_push_down_predicate(cls, op: LogicalOperator) -> LogicalOperator:
193193
predicate_expr, rename_map
194194
)
195195

196-
# Push the predicate down and return the result without the filter
197-
return input_op.apply_predicate(predicate_expr)
196+
# Push the predicate down
197+
result_op = input_op.apply_predicate(predicate_expr)
198+
199+
# If the operator is unchanged (e.g., predicate references partition columns
200+
# that can't be pushed down), keep the Filter operator
201+
if result_op is input_op:
202+
return filter_op
203+
204+
# Otherwise, return the result without the filter (predicate was pushed down)
205+
return result_op
198206

199207
# Case 2: Check if operator allows predicates to pass through
200208
if isinstance(input_op, LogicalOperatorSupportsPredicatePassThrough):

python/ray/data/_internal/planner/plan_expression/expression_visitors.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,3 +347,26 @@ def visit_download(self, expr: "DownloadExpr") -> str:
347347

348348
def visit_star(self, expr: "StarExpr") -> str:
349349
return self._make_tree_lines("COL(*)", expr=expr)
350+
351+
352+
def get_column_references(expr: Expr) -> List[str]:
    """Return the names of all columns referenced by an expression.

    Convenience wrapper that walks the expression tree with a
    _ColumnReferenceCollector and reports the referenced column names.

    Args:
        expr: The expression to extract column references from.

    Returns:
        List of column names referenced in the expression, in order of appearance.

    Example:
        >>> from ray.data.expressions import col
        >>> expr = (col("a") > 5) & (col("b") == "test")
        >>> get_column_references(expr)
        ['a', 'b']
    """
    visitor = _ColumnReferenceCollector()
    visitor.visit(expr)
    return visitor.get_column_refs()

0 commit comments

Comments
 (0)